Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- config.json +2885 -0
- generation_config.json +9 -0
- lib/__init__.py +1 -0
- lib/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/__pycache__/config.cpython-311.pyc +0 -0
- lib/algo/__init__.py +0 -0
- lib/algo/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/algo/__pycache__/ldlq.cpython-311.pyc +0 -0
- lib/algo/ldlq.py +203 -0
- lib/algo/ldlq_beam_cd.py +209 -0
- lib/codebook/__pycache__/bitshift.cpython-311.pyc +0 -0
- lib/codebook/__pycache__/vq_codebook.cpython-311.pyc +0 -0
- lib/codebook/bitshift.py +486 -0
- lib/codebook/vq_codebook.py +56 -0
- lib/config.py +6 -0
- lib/linear/__init__.py +430 -0
- lib/linear/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/linear/__pycache__/comb_linear.cpython-311.pyc +0 -0
- lib/linear/__pycache__/incoherent_linear.cpython-311.pyc +0 -0
- lib/linear/__pycache__/quantized_linear.cpython-311.pyc +0 -0
- lib/linear/__pycache__/tcq_linear.cpython-311.pyc +0 -0
- lib/linear/__pycache__/vq_linear.cpython-311.pyc +0 -0
- lib/linear/comb_linear.py +325 -0
- lib/linear/incoherent_linear.py +639 -0
- lib/linear/quantized_linear.py +154 -0
- lib/linear/rotation.py +16 -0
- lib/linear/tcq_linear.py +122 -0
- lib/linear/vq_linear.py +208 -0
- lib/quantizer/__pycache__/comb_quant.cpython-311.pyc +0 -0
- lib/quantizer/__pycache__/nuq_op.cpython-311.pyc +0 -0
- lib/quantizer/__pycache__/pack_op.cpython-311.pyc +0 -0
- lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.1.nbc +0 -0
- lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.nbi +0 -0
- lib/quantizer/__pycache__/pack_op.pack_32-242.py311.1.nbc +0 -0
- lib/quantizer/__pycache__/pack_op.pack_32-242.py311.nbi +0 -0
- lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.1.nbc +3 -0
- lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.nbi +0 -0
- lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.1.nbc +0 -0
- lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.nbi +0 -0
- lib/quantizer/__pycache__/quant_op.cpython-311.pyc +0 -0
- lib/quantizer/__pycache__/tcq_quant.cpython-311.pyc +0 -0
- lib/quantizer/__pycache__/vq_quant.cpython-311.pyc +0 -0
- lib/quantizer/__pycache__/vq_quant_ldlq.cpython-311.pyc +0 -0
- lib/quantizer/comb_quant.py +201 -0
- lib/quantizer/nuq_op.py +431 -0
- lib/quantizer/pack_op.py +335 -0
- lib/quantizer/quant_op.py +277 -0
- lib/quantizer/tcq_quant.py +160 -0
- lib/quantizer/vq_quant.py +149 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.1.nbc filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
lib/utils/__pycache__/matmul_had.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
config.json
ADDED
|
@@ -0,0 +1,2885 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "meta-llama/Llama-3.2-1B",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"LlamaForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_bias": false,
|
| 7 |
+
"attention_dropout": 0.0,
|
| 8 |
+
"auto_map": {
|
| 9 |
+
"AutoModelForCausalLM": "qpal_modelling_llama.QPalLlamaForCausalLM"
|
| 10 |
+
},
|
| 11 |
+
"bos_token_id": 128000,
|
| 12 |
+
"eos_token_id": 128001,
|
| 13 |
+
"head_dim": 64,
|
| 14 |
+
"hidden_act": "silu",
|
| 15 |
+
"hidden_size": 2048,
|
| 16 |
+
"initializer_range": 0.02,
|
| 17 |
+
"intermediate_size": 8192,
|
| 18 |
+
"max_position_embeddings": 131072,
|
| 19 |
+
"mlp_bias": false,
|
| 20 |
+
"model_type": "llama",
|
| 21 |
+
"num_attention_heads": 32,
|
| 22 |
+
"num_hidden_layers": 16,
|
| 23 |
+
"num_key_value_heads": 8,
|
| 24 |
+
"pretraining_tp": 1,
|
| 25 |
+
"qpal_quant_config": {
|
| 26 |
+
"modules": {
|
| 27 |
+
"model.layers.0.mlp.down_proj": {
|
| 28 |
+
"bias": false,
|
| 29 |
+
"dtype": "float32",
|
| 30 |
+
"hadU": 8192,
|
| 31 |
+
"hadV": 2048,
|
| 32 |
+
"in_features": 8192,
|
| 33 |
+
"linear": {
|
| 34 |
+
"KV": 7,
|
| 35 |
+
"L": 16,
|
| 36 |
+
"V": 2,
|
| 37 |
+
"bias": false,
|
| 38 |
+
"in_features": 8192,
|
| 39 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 40 |
+
"linear_dtype": "float32",
|
| 41 |
+
"out_features": 2048,
|
| 42 |
+
"td_x": 16,
|
| 43 |
+
"td_y": 16,
|
| 44 |
+
"tlut_bits": 9
|
| 45 |
+
},
|
| 46 |
+
"module_type": "IncoherentLinear",
|
| 47 |
+
"out_features": 2048,
|
| 48 |
+
"rot_info": "skip_r",
|
| 49 |
+
"scale": 32.0
|
| 50 |
+
},
|
| 51 |
+
"model.layers.0.mlp.gate_proj": {
|
| 52 |
+
"bias": false,
|
| 53 |
+
"dtype": "float32",
|
| 54 |
+
"hadU": 2048,
|
| 55 |
+
"hadV": 8192,
|
| 56 |
+
"in_features": 2048,
|
| 57 |
+
"linear": {
|
| 58 |
+
"KV": 5,
|
| 59 |
+
"L": 16,
|
| 60 |
+
"V": 2,
|
| 61 |
+
"bias": false,
|
| 62 |
+
"in_features": 2048,
|
| 63 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 64 |
+
"linear_dtype": "float32",
|
| 65 |
+
"out_features": 8192,
|
| 66 |
+
"td_x": 16,
|
| 67 |
+
"td_y": 16,
|
| 68 |
+
"tlut_bits": 9
|
| 69 |
+
},
|
| 70 |
+
"module_type": "IncoherentLinear",
|
| 71 |
+
"out_features": 8192,
|
| 72 |
+
"rot_info": "skip_r",
|
| 73 |
+
"scale": 32.0
|
| 74 |
+
},
|
| 75 |
+
"model.layers.0.mlp.up_proj": {
|
| 76 |
+
"bias": false,
|
| 77 |
+
"dtype": "float32",
|
| 78 |
+
"hadU": 2048,
|
| 79 |
+
"hadV": 8192,
|
| 80 |
+
"in_features": 2048,
|
| 81 |
+
"linear": {
|
| 82 |
+
"KV": 6,
|
| 83 |
+
"L": 16,
|
| 84 |
+
"V": 2,
|
| 85 |
+
"bias": false,
|
| 86 |
+
"in_features": 2048,
|
| 87 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 88 |
+
"linear_dtype": "float32",
|
| 89 |
+
"out_features": 8192,
|
| 90 |
+
"td_x": 16,
|
| 91 |
+
"td_y": 16,
|
| 92 |
+
"tlut_bits": 9
|
| 93 |
+
},
|
| 94 |
+
"module_type": "IncoherentLinear",
|
| 95 |
+
"out_features": 8192,
|
| 96 |
+
"rot_info": "skip_r",
|
| 97 |
+
"scale": 32.0
|
| 98 |
+
},
|
| 99 |
+
"model.layers.0.self_attn.k_proj": {
|
| 100 |
+
"bias": false,
|
| 101 |
+
"dtype": "float32",
|
| 102 |
+
"hadU": 2048,
|
| 103 |
+
"hadV": 512,
|
| 104 |
+
"in_features": 2048,
|
| 105 |
+
"linear": {
|
| 106 |
+
"KV": 7,
|
| 107 |
+
"L": 16,
|
| 108 |
+
"V": 2,
|
| 109 |
+
"bias": false,
|
| 110 |
+
"in_features": 2048,
|
| 111 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 112 |
+
"linear_dtype": "float32",
|
| 113 |
+
"out_features": 512,
|
| 114 |
+
"td_x": 16,
|
| 115 |
+
"td_y": 16,
|
| 116 |
+
"tlut_bits": 9
|
| 117 |
+
},
|
| 118 |
+
"module_type": "IncoherentLinear",
|
| 119 |
+
"out_features": 512,
|
| 120 |
+
"rot_info": "skip_r",
|
| 121 |
+
"scale": 32.0
|
| 122 |
+
},
|
| 123 |
+
"model.layers.0.self_attn.o_proj": {
|
| 124 |
+
"bias": false,
|
| 125 |
+
"dtype": "float32",
|
| 126 |
+
"hadU": 2048,
|
| 127 |
+
"hadV": 2048,
|
| 128 |
+
"in_features": 2048,
|
| 129 |
+
"linear": {
|
| 130 |
+
"KV": 7,
|
| 131 |
+
"L": 16,
|
| 132 |
+
"V": 2,
|
| 133 |
+
"bias": false,
|
| 134 |
+
"in_features": 2048,
|
| 135 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 136 |
+
"linear_dtype": "float32",
|
| 137 |
+
"out_features": 2048,
|
| 138 |
+
"td_x": 16,
|
| 139 |
+
"td_y": 16,
|
| 140 |
+
"tlut_bits": 9
|
| 141 |
+
},
|
| 142 |
+
"module_type": "IncoherentLinear",
|
| 143 |
+
"out_features": 2048,
|
| 144 |
+
"rot_info": "skip_r",
|
| 145 |
+
"scale": 32.0
|
| 146 |
+
},
|
| 147 |
+
"model.layers.0.self_attn.q_proj": {
|
| 148 |
+
"bias": false,
|
| 149 |
+
"dtype": "float32",
|
| 150 |
+
"hadU": 2048,
|
| 151 |
+
"hadV": 2048,
|
| 152 |
+
"in_features": 2048,
|
| 153 |
+
"linear": {
|
| 154 |
+
"KV": 4,
|
| 155 |
+
"L": 16,
|
| 156 |
+
"V": 2,
|
| 157 |
+
"bias": false,
|
| 158 |
+
"in_features": 2048,
|
| 159 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 160 |
+
"linear_dtype": "float32",
|
| 161 |
+
"out_features": 2048,
|
| 162 |
+
"td_x": 16,
|
| 163 |
+
"td_y": 16,
|
| 164 |
+
"tlut_bits": 9
|
| 165 |
+
},
|
| 166 |
+
"module_type": "IncoherentLinear",
|
| 167 |
+
"out_features": 2048,
|
| 168 |
+
"rot_info": "skip_r",
|
| 169 |
+
"scale": 32.0
|
| 170 |
+
},
|
| 171 |
+
"model.layers.0.self_attn.v_proj": {
|
| 172 |
+
"bias": false,
|
| 173 |
+
"dtype": "float32",
|
| 174 |
+
"hadU": 2048,
|
| 175 |
+
"hadV": 512,
|
| 176 |
+
"in_features": 2048,
|
| 177 |
+
"linear": {
|
| 178 |
+
"KV": [
|
| 179 |
+
9,
|
| 180 |
+
10
|
| 181 |
+
],
|
| 182 |
+
"L": 16,
|
| 183 |
+
"V": 2,
|
| 184 |
+
"bias": false,
|
| 185 |
+
"in_features": 2048,
|
| 186 |
+
"in_part": [
|
| 187 |
+
1024,
|
| 188 |
+
1024
|
| 189 |
+
],
|
| 190 |
+
"linear_cls": "CombtLinearTCQ",
|
| 191 |
+
"linear_dtype": "float32",
|
| 192 |
+
"out_features": 512,
|
| 193 |
+
"td_x": 16,
|
| 194 |
+
"td_y": 16,
|
| 195 |
+
"tlut_bits": 11
|
| 196 |
+
},
|
| 197 |
+
"module_type": "IncoherentLinear",
|
| 198 |
+
"out_features": 512,
|
| 199 |
+
"rot_info": "skip_r",
|
| 200 |
+
"scale": 32.0
|
| 201 |
+
},
|
| 202 |
+
"model.layers.1.mlp.down_proj": {
|
| 203 |
+
"bias": false,
|
| 204 |
+
"dtype": "float32",
|
| 205 |
+
"hadU": 8192,
|
| 206 |
+
"hadV": 2048,
|
| 207 |
+
"in_features": 8192,
|
| 208 |
+
"linear": {
|
| 209 |
+
"KV": 10,
|
| 210 |
+
"L": 16,
|
| 211 |
+
"V": 2,
|
| 212 |
+
"bias": false,
|
| 213 |
+
"in_features": 8192,
|
| 214 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 215 |
+
"linear_dtype": "float32",
|
| 216 |
+
"out_features": 2048,
|
| 217 |
+
"td_x": 16,
|
| 218 |
+
"td_y": 16,
|
| 219 |
+
"tlut_bits": 11
|
| 220 |
+
},
|
| 221 |
+
"module_type": "IncoherentLinear",
|
| 222 |
+
"out_features": 2048,
|
| 223 |
+
"rot_info": "skip_r",
|
| 224 |
+
"scale": 32.0
|
| 225 |
+
},
|
| 226 |
+
"model.layers.1.mlp.gate_proj": {
|
| 227 |
+
"bias": false,
|
| 228 |
+
"dtype": "float32",
|
| 229 |
+
"hadU": 2048,
|
| 230 |
+
"hadV": 8192,
|
| 231 |
+
"in_features": 2048,
|
| 232 |
+
"linear": {
|
| 233 |
+
"KV": 5,
|
| 234 |
+
"L": 16,
|
| 235 |
+
"V": 2,
|
| 236 |
+
"bias": false,
|
| 237 |
+
"in_features": 2048,
|
| 238 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 239 |
+
"linear_dtype": "float32",
|
| 240 |
+
"out_features": 8192,
|
| 241 |
+
"td_x": 16,
|
| 242 |
+
"td_y": 16,
|
| 243 |
+
"tlut_bits": 9
|
| 244 |
+
},
|
| 245 |
+
"module_type": "IncoherentLinear",
|
| 246 |
+
"out_features": 8192,
|
| 247 |
+
"rot_info": "skip_r",
|
| 248 |
+
"scale": 32.0
|
| 249 |
+
},
|
| 250 |
+
"model.layers.1.mlp.up_proj": {
|
| 251 |
+
"bias": false,
|
| 252 |
+
"dtype": "float32",
|
| 253 |
+
"hadU": 2048,
|
| 254 |
+
"hadV": 8192,
|
| 255 |
+
"in_features": 2048,
|
| 256 |
+
"linear": {
|
| 257 |
+
"KV": 5,
|
| 258 |
+
"L": 16,
|
| 259 |
+
"V": 2,
|
| 260 |
+
"bias": false,
|
| 261 |
+
"in_features": 2048,
|
| 262 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 263 |
+
"linear_dtype": "float32",
|
| 264 |
+
"out_features": 8192,
|
| 265 |
+
"td_x": 16,
|
| 266 |
+
"td_y": 16,
|
| 267 |
+
"tlut_bits": 9
|
| 268 |
+
},
|
| 269 |
+
"module_type": "IncoherentLinear",
|
| 270 |
+
"out_features": 8192,
|
| 271 |
+
"rot_info": "skip_r",
|
| 272 |
+
"scale": 32.0
|
| 273 |
+
},
|
| 274 |
+
"model.layers.1.self_attn.k_proj": {
|
| 275 |
+
"bias": false,
|
| 276 |
+
"dtype": "float32",
|
| 277 |
+
"hadU": 2048,
|
| 278 |
+
"hadV": 512,
|
| 279 |
+
"in_features": 2048,
|
| 280 |
+
"linear": {
|
| 281 |
+
"KV": 9,
|
| 282 |
+
"L": 16,
|
| 283 |
+
"V": 2,
|
| 284 |
+
"bias": false,
|
| 285 |
+
"in_features": 2048,
|
| 286 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 287 |
+
"linear_dtype": "float32",
|
| 288 |
+
"out_features": 512,
|
| 289 |
+
"td_x": 16,
|
| 290 |
+
"td_y": 16,
|
| 291 |
+
"tlut_bits": 10
|
| 292 |
+
},
|
| 293 |
+
"module_type": "IncoherentLinear",
|
| 294 |
+
"out_features": 512,
|
| 295 |
+
"rot_info": "skip_r",
|
| 296 |
+
"scale": 32.0
|
| 297 |
+
},
|
| 298 |
+
"model.layers.1.self_attn.o_proj": {
|
| 299 |
+
"bias": false,
|
| 300 |
+
"dtype": "float32",
|
| 301 |
+
"hadU": 2048,
|
| 302 |
+
"hadV": 2048,
|
| 303 |
+
"in_features": 2048,
|
| 304 |
+
"linear": {
|
| 305 |
+
"KV": 7,
|
| 306 |
+
"L": 16,
|
| 307 |
+
"V": 2,
|
| 308 |
+
"bias": false,
|
| 309 |
+
"in_features": 2048,
|
| 310 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 311 |
+
"linear_dtype": "float32",
|
| 312 |
+
"out_features": 2048,
|
| 313 |
+
"td_x": 16,
|
| 314 |
+
"td_y": 16,
|
| 315 |
+
"tlut_bits": 9
|
| 316 |
+
},
|
| 317 |
+
"module_type": "IncoherentLinear",
|
| 318 |
+
"out_features": 2048,
|
| 319 |
+
"rot_info": "skip_r",
|
| 320 |
+
"scale": 32.0
|
| 321 |
+
},
|
| 322 |
+
"model.layers.1.self_attn.q_proj": {
|
| 323 |
+
"bias": false,
|
| 324 |
+
"dtype": "float32",
|
| 325 |
+
"hadU": 2048,
|
| 326 |
+
"hadV": 2048,
|
| 327 |
+
"in_features": 2048,
|
| 328 |
+
"linear": {
|
| 329 |
+
"KV": 7,
|
| 330 |
+
"L": 16,
|
| 331 |
+
"V": 2,
|
| 332 |
+
"bias": false,
|
| 333 |
+
"in_features": 2048,
|
| 334 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 335 |
+
"linear_dtype": "float32",
|
| 336 |
+
"out_features": 2048,
|
| 337 |
+
"td_x": 16,
|
| 338 |
+
"td_y": 16,
|
| 339 |
+
"tlut_bits": 9
|
| 340 |
+
},
|
| 341 |
+
"module_type": "IncoherentLinear",
|
| 342 |
+
"out_features": 2048,
|
| 343 |
+
"rot_info": "skip_r",
|
| 344 |
+
"scale": 32.0
|
| 345 |
+
},
|
| 346 |
+
"model.layers.1.self_attn.v_proj": {
|
| 347 |
+
"bias": false,
|
| 348 |
+
"dtype": "float32",
|
| 349 |
+
"hadU": 2048,
|
| 350 |
+
"hadV": 512,
|
| 351 |
+
"in_features": 2048,
|
| 352 |
+
"linear": {
|
| 353 |
+
"KV": [
|
| 354 |
+
9,
|
| 355 |
+
10
|
| 356 |
+
],
|
| 357 |
+
"L": 16,
|
| 358 |
+
"V": 2,
|
| 359 |
+
"bias": false,
|
| 360 |
+
"in_features": 2048,
|
| 361 |
+
"in_part": [
|
| 362 |
+
1024,
|
| 363 |
+
1024
|
| 364 |
+
],
|
| 365 |
+
"linear_cls": "CombtLinearTCQ",
|
| 366 |
+
"linear_dtype": "float32",
|
| 367 |
+
"out_features": 512,
|
| 368 |
+
"td_x": 16,
|
| 369 |
+
"td_y": 16,
|
| 370 |
+
"tlut_bits": 11
|
| 371 |
+
},
|
| 372 |
+
"module_type": "IncoherentLinear",
|
| 373 |
+
"out_features": 512,
|
| 374 |
+
"rot_info": "skip_r",
|
| 375 |
+
"scale": 32.0
|
| 376 |
+
},
|
| 377 |
+
"model.layers.10.mlp.down_proj": {
|
| 378 |
+
"bias": false,
|
| 379 |
+
"dtype": "float32",
|
| 380 |
+
"hadU": 8192,
|
| 381 |
+
"hadV": 2048,
|
| 382 |
+
"in_features": 8192,
|
| 383 |
+
"linear": {
|
| 384 |
+
"KV": 7,
|
| 385 |
+
"L": 16,
|
| 386 |
+
"V": 2,
|
| 387 |
+
"bias": false,
|
| 388 |
+
"in_features": 8192,
|
| 389 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 390 |
+
"linear_dtype": "float32",
|
| 391 |
+
"out_features": 2048,
|
| 392 |
+
"td_x": 16,
|
| 393 |
+
"td_y": 16,
|
| 394 |
+
"tlut_bits": 9
|
| 395 |
+
},
|
| 396 |
+
"module_type": "IncoherentLinear",
|
| 397 |
+
"out_features": 2048,
|
| 398 |
+
"rot_info": "skip_r",
|
| 399 |
+
"scale": 32.0
|
| 400 |
+
},
|
| 401 |
+
"model.layers.10.mlp.gate_proj": {
|
| 402 |
+
"bias": false,
|
| 403 |
+
"dtype": "float32",
|
| 404 |
+
"hadU": 2048,
|
| 405 |
+
"hadV": 8192,
|
| 406 |
+
"in_features": 2048,
|
| 407 |
+
"linear": {
|
| 408 |
+
"KV": 6,
|
| 409 |
+
"L": 16,
|
| 410 |
+
"V": 2,
|
| 411 |
+
"bias": false,
|
| 412 |
+
"in_features": 2048,
|
| 413 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 414 |
+
"linear_dtype": "float32",
|
| 415 |
+
"out_features": 8192,
|
| 416 |
+
"td_x": 16,
|
| 417 |
+
"td_y": 16,
|
| 418 |
+
"tlut_bits": 9
|
| 419 |
+
},
|
| 420 |
+
"module_type": "IncoherentLinear",
|
| 421 |
+
"out_features": 8192,
|
| 422 |
+
"rot_info": "skip_r",
|
| 423 |
+
"scale": 32.0
|
| 424 |
+
},
|
| 425 |
+
"model.layers.10.mlp.up_proj": {
|
| 426 |
+
"bias": false,
|
| 427 |
+
"dtype": "float32",
|
| 428 |
+
"hadU": 2048,
|
| 429 |
+
"hadV": 8192,
|
| 430 |
+
"in_features": 2048,
|
| 431 |
+
"linear": {
|
| 432 |
+
"KV": 6,
|
| 433 |
+
"L": 16,
|
| 434 |
+
"V": 2,
|
| 435 |
+
"bias": false,
|
| 436 |
+
"in_features": 2048,
|
| 437 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 438 |
+
"linear_dtype": "float32",
|
| 439 |
+
"out_features": 8192,
|
| 440 |
+
"td_x": 16,
|
| 441 |
+
"td_y": 16,
|
| 442 |
+
"tlut_bits": 9
|
| 443 |
+
},
|
| 444 |
+
"module_type": "IncoherentLinear",
|
| 445 |
+
"out_features": 8192,
|
| 446 |
+
"rot_info": "skip_r",
|
| 447 |
+
"scale": 32.0
|
| 448 |
+
},
|
| 449 |
+
"model.layers.10.self_attn.k_proj": {
|
| 450 |
+
"bias": false,
|
| 451 |
+
"dtype": "float32",
|
| 452 |
+
"hadU": 2048,
|
| 453 |
+
"hadV": 512,
|
| 454 |
+
"in_features": 2048,
|
| 455 |
+
"linear": {
|
| 456 |
+
"KV": 9,
|
| 457 |
+
"L": 16,
|
| 458 |
+
"V": 2,
|
| 459 |
+
"bias": false,
|
| 460 |
+
"in_features": 2048,
|
| 461 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 462 |
+
"linear_dtype": "float32",
|
| 463 |
+
"out_features": 512,
|
| 464 |
+
"td_x": 16,
|
| 465 |
+
"td_y": 16,
|
| 466 |
+
"tlut_bits": 10
|
| 467 |
+
},
|
| 468 |
+
"module_type": "IncoherentLinear",
|
| 469 |
+
"out_features": 512,
|
| 470 |
+
"rot_info": "skip_r",
|
| 471 |
+
"scale": 32.0
|
| 472 |
+
},
|
| 473 |
+
"model.layers.10.self_attn.o_proj": {
|
| 474 |
+
"bias": false,
|
| 475 |
+
"dtype": "float32",
|
| 476 |
+
"hadU": 2048,
|
| 477 |
+
"hadV": 2048,
|
| 478 |
+
"in_features": 2048,
|
| 479 |
+
"linear": {
|
| 480 |
+
"KV": 8,
|
| 481 |
+
"L": 16,
|
| 482 |
+
"V": 2,
|
| 483 |
+
"bias": false,
|
| 484 |
+
"in_features": 2048,
|
| 485 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 486 |
+
"linear_dtype": "float32",
|
| 487 |
+
"out_features": 2048,
|
| 488 |
+
"td_x": 16,
|
| 489 |
+
"td_y": 16,
|
| 490 |
+
"tlut_bits": 9
|
| 491 |
+
},
|
| 492 |
+
"module_type": "IncoherentLinear",
|
| 493 |
+
"out_features": 2048,
|
| 494 |
+
"rot_info": "skip_r",
|
| 495 |
+
"scale": 32.0
|
| 496 |
+
},
|
| 497 |
+
"model.layers.10.self_attn.q_proj": {
|
| 498 |
+
"bias": false,
|
| 499 |
+
"dtype": "float32",
|
| 500 |
+
"hadU": 2048,
|
| 501 |
+
"hadV": 2048,
|
| 502 |
+
"in_features": 2048,
|
| 503 |
+
"linear": {
|
| 504 |
+
"KV": 8,
|
| 505 |
+
"L": 16,
|
| 506 |
+
"V": 2,
|
| 507 |
+
"bias": false,
|
| 508 |
+
"in_features": 2048,
|
| 509 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 510 |
+
"linear_dtype": "float32",
|
| 511 |
+
"out_features": 2048,
|
| 512 |
+
"td_x": 16,
|
| 513 |
+
"td_y": 16,
|
| 514 |
+
"tlut_bits": 9
|
| 515 |
+
},
|
| 516 |
+
"module_type": "IncoherentLinear",
|
| 517 |
+
"out_features": 2048,
|
| 518 |
+
"rot_info": "skip_r",
|
| 519 |
+
"scale": 32.0
|
| 520 |
+
},
|
| 521 |
+
"model.layers.10.self_attn.v_proj": {
|
| 522 |
+
"bias": false,
|
| 523 |
+
"dtype": "float32",
|
| 524 |
+
"hadU": 2048,
|
| 525 |
+
"hadV": 512,
|
| 526 |
+
"in_features": 2048,
|
| 527 |
+
"linear": {
|
| 528 |
+
"KV": [
|
| 529 |
+
9,
|
| 530 |
+
10
|
| 531 |
+
],
|
| 532 |
+
"L": 16,
|
| 533 |
+
"V": 2,
|
| 534 |
+
"bias": false,
|
| 535 |
+
"in_features": 2048,
|
| 536 |
+
"in_part": [
|
| 537 |
+
1024,
|
| 538 |
+
1024
|
| 539 |
+
],
|
| 540 |
+
"linear_cls": "CombtLinearTCQ",
|
| 541 |
+
"linear_dtype": "float32",
|
| 542 |
+
"out_features": 512,
|
| 543 |
+
"td_x": 16,
|
| 544 |
+
"td_y": 16,
|
| 545 |
+
"tlut_bits": 11
|
| 546 |
+
},
|
| 547 |
+
"module_type": "IncoherentLinear",
|
| 548 |
+
"out_features": 512,
|
| 549 |
+
"rot_info": "skip_r",
|
| 550 |
+
"scale": 32.0
|
| 551 |
+
},
|
| 552 |
+
"model.layers.11.mlp.down_proj": {
|
| 553 |
+
"bias": false,
|
| 554 |
+
"dtype": "float32",
|
| 555 |
+
"hadU": 8192,
|
| 556 |
+
"hadV": 2048,
|
| 557 |
+
"in_features": 8192,
|
| 558 |
+
"linear": {
|
| 559 |
+
"KV": 7,
|
| 560 |
+
"L": 16,
|
| 561 |
+
"V": 2,
|
| 562 |
+
"bias": false,
|
| 563 |
+
"in_features": 8192,
|
| 564 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 565 |
+
"linear_dtype": "float32",
|
| 566 |
+
"out_features": 2048,
|
| 567 |
+
"td_x": 16,
|
| 568 |
+
"td_y": 16,
|
| 569 |
+
"tlut_bits": 9
|
| 570 |
+
},
|
| 571 |
+
"module_type": "IncoherentLinear",
|
| 572 |
+
"out_features": 2048,
|
| 573 |
+
"rot_info": "skip_r",
|
| 574 |
+
"scale": 32.0
|
| 575 |
+
},
|
| 576 |
+
"model.layers.11.mlp.gate_proj": {
|
| 577 |
+
"bias": false,
|
| 578 |
+
"dtype": "float32",
|
| 579 |
+
"hadU": 2048,
|
| 580 |
+
"hadV": 8192,
|
| 581 |
+
"in_features": 2048,
|
| 582 |
+
"linear": {
|
| 583 |
+
"KV": 6,
|
| 584 |
+
"L": 16,
|
| 585 |
+
"V": 2,
|
| 586 |
+
"bias": false,
|
| 587 |
+
"in_features": 2048,
|
| 588 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 589 |
+
"linear_dtype": "float32",
|
| 590 |
+
"out_features": 8192,
|
| 591 |
+
"td_x": 16,
|
| 592 |
+
"td_y": 16,
|
| 593 |
+
"tlut_bits": 9
|
| 594 |
+
},
|
| 595 |
+
"module_type": "IncoherentLinear",
|
| 596 |
+
"out_features": 8192,
|
| 597 |
+
"rot_info": "skip_r",
|
| 598 |
+
"scale": 32.0
|
| 599 |
+
},
|
| 600 |
+
"model.layers.11.mlp.up_proj": {
|
| 601 |
+
"bias": false,
|
| 602 |
+
"dtype": "float32",
|
| 603 |
+
"hadU": 2048,
|
| 604 |
+
"hadV": 8192,
|
| 605 |
+
"in_features": 2048,
|
| 606 |
+
"linear": {
|
| 607 |
+
"KV": 6,
|
| 608 |
+
"L": 16,
|
| 609 |
+
"V": 2,
|
| 610 |
+
"bias": false,
|
| 611 |
+
"in_features": 2048,
|
| 612 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 613 |
+
"linear_dtype": "float32",
|
| 614 |
+
"out_features": 8192,
|
| 615 |
+
"td_x": 16,
|
| 616 |
+
"td_y": 16,
|
| 617 |
+
"tlut_bits": 9
|
| 618 |
+
},
|
| 619 |
+
"module_type": "IncoherentLinear",
|
| 620 |
+
"out_features": 8192,
|
| 621 |
+
"rot_info": "skip_r",
|
| 622 |
+
"scale": 32.0
|
| 623 |
+
},
|
| 624 |
+
"model.layers.11.self_attn.k_proj": {
|
| 625 |
+
"bias": false,
|
| 626 |
+
"dtype": "float32",
|
| 627 |
+
"hadU": 2048,
|
| 628 |
+
"hadV": 512,
|
| 629 |
+
"in_features": 2048,
|
| 630 |
+
"linear": {
|
| 631 |
+
"KV": [
|
| 632 |
+
8,
|
| 633 |
+
9
|
| 634 |
+
],
|
| 635 |
+
"L": 16,
|
| 636 |
+
"V": 2,
|
| 637 |
+
"bias": false,
|
| 638 |
+
"in_features": 2048,
|
| 639 |
+
"in_part": [
|
| 640 |
+
1024,
|
| 641 |
+
1024
|
| 642 |
+
],
|
| 643 |
+
"linear_cls": "CombtLinearTCQ",
|
| 644 |
+
"linear_dtype": "float32",
|
| 645 |
+
"out_features": 512,
|
| 646 |
+
"td_x": 16,
|
| 647 |
+
"td_y": 16,
|
| 648 |
+
"tlut_bits": 10
|
| 649 |
+
},
|
| 650 |
+
"module_type": "IncoherentLinear",
|
| 651 |
+
"out_features": 512,
|
| 652 |
+
"rot_info": "skip_r",
|
| 653 |
+
"scale": 32.0
|
| 654 |
+
},
|
| 655 |
+
"model.layers.11.self_attn.o_proj": {
|
| 656 |
+
"bias": false,
|
| 657 |
+
"dtype": "float32",
|
| 658 |
+
"hadU": 2048,
|
| 659 |
+
"hadV": 2048,
|
| 660 |
+
"in_features": 2048,
|
| 661 |
+
"linear": {
|
| 662 |
+
"KV": 7,
|
| 663 |
+
"L": 16,
|
| 664 |
+
"V": 2,
|
| 665 |
+
"bias": false,
|
| 666 |
+
"in_features": 2048,
|
| 667 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 668 |
+
"linear_dtype": "float32",
|
| 669 |
+
"out_features": 2048,
|
| 670 |
+
"td_x": 16,
|
| 671 |
+
"td_y": 16,
|
| 672 |
+
"tlut_bits": 9
|
| 673 |
+
},
|
| 674 |
+
"module_type": "IncoherentLinear",
|
| 675 |
+
"out_features": 2048,
|
| 676 |
+
"rot_info": "skip_r",
|
| 677 |
+
"scale": 32.0
|
| 678 |
+
},
|
| 679 |
+
"model.layers.11.self_attn.q_proj": {
|
| 680 |
+
"bias": false,
|
| 681 |
+
"dtype": "float32",
|
| 682 |
+
"hadU": 2048,
|
| 683 |
+
"hadV": 2048,
|
| 684 |
+
"in_features": 2048,
|
| 685 |
+
"linear": {
|
| 686 |
+
"KV": 7,
|
| 687 |
+
"L": 16,
|
| 688 |
+
"V": 2,
|
| 689 |
+
"bias": false,
|
| 690 |
+
"in_features": 2048,
|
| 691 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 692 |
+
"linear_dtype": "float32",
|
| 693 |
+
"out_features": 2048,
|
| 694 |
+
"td_x": 16,
|
| 695 |
+
"td_y": 16,
|
| 696 |
+
"tlut_bits": 9
|
| 697 |
+
},
|
| 698 |
+
"module_type": "IncoherentLinear",
|
| 699 |
+
"out_features": 2048,
|
| 700 |
+
"rot_info": "skip_r",
|
| 701 |
+
"scale": 32.0
|
| 702 |
+
},
|
| 703 |
+
"model.layers.11.self_attn.v_proj": {
|
| 704 |
+
"bias": false,
|
| 705 |
+
"dtype": "float32",
|
| 706 |
+
"hadU": 2048,
|
| 707 |
+
"hadV": 512,
|
| 708 |
+
"in_features": 2048,
|
| 709 |
+
"linear": {
|
| 710 |
+
"KV": 9,
|
| 711 |
+
"L": 16,
|
| 712 |
+
"V": 2,
|
| 713 |
+
"bias": false,
|
| 714 |
+
"in_features": 2048,
|
| 715 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 716 |
+
"linear_dtype": "float32",
|
| 717 |
+
"out_features": 512,
|
| 718 |
+
"td_x": 16,
|
| 719 |
+
"td_y": 16,
|
| 720 |
+
"tlut_bits": 10
|
| 721 |
+
},
|
| 722 |
+
"module_type": "IncoherentLinear",
|
| 723 |
+
"out_features": 512,
|
| 724 |
+
"rot_info": "skip_r",
|
| 725 |
+
"scale": 32.0
|
| 726 |
+
},
|
| 727 |
+
"model.layers.12.mlp.down_proj": {
|
| 728 |
+
"bias": false,
|
| 729 |
+
"dtype": "float32",
|
| 730 |
+
"hadU": 8192,
|
| 731 |
+
"hadV": 2048,
|
| 732 |
+
"in_features": 8192,
|
| 733 |
+
"linear": {
|
| 734 |
+
"KV": 6,
|
| 735 |
+
"L": 16,
|
| 736 |
+
"V": 2,
|
| 737 |
+
"bias": false,
|
| 738 |
+
"in_features": 8192,
|
| 739 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 740 |
+
"linear_dtype": "float32",
|
| 741 |
+
"out_features": 2048,
|
| 742 |
+
"td_x": 16,
|
| 743 |
+
"td_y": 16,
|
| 744 |
+
"tlut_bits": 9
|
| 745 |
+
},
|
| 746 |
+
"module_type": "IncoherentLinear",
|
| 747 |
+
"out_features": 2048,
|
| 748 |
+
"rot_info": "skip_r",
|
| 749 |
+
"scale": 32.0
|
| 750 |
+
},
|
| 751 |
+
"model.layers.12.mlp.gate_proj": {
|
| 752 |
+
"bias": false,
|
| 753 |
+
"dtype": "float32",
|
| 754 |
+
"hadU": 2048,
|
| 755 |
+
"hadV": 8192,
|
| 756 |
+
"in_features": 2048,
|
| 757 |
+
"linear": {
|
| 758 |
+
"KV": 6,
|
| 759 |
+
"L": 16,
|
| 760 |
+
"V": 2,
|
| 761 |
+
"bias": false,
|
| 762 |
+
"in_features": 2048,
|
| 763 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 764 |
+
"linear_dtype": "float32",
|
| 765 |
+
"out_features": 8192,
|
| 766 |
+
"td_x": 16,
|
| 767 |
+
"td_y": 16,
|
| 768 |
+
"tlut_bits": 9
|
| 769 |
+
},
|
| 770 |
+
"module_type": "IncoherentLinear",
|
| 771 |
+
"out_features": 8192,
|
| 772 |
+
"rot_info": "skip_r",
|
| 773 |
+
"scale": 32.0
|
| 774 |
+
},
|
| 775 |
+
"model.layers.12.mlp.up_proj": {
|
| 776 |
+
"bias": false,
|
| 777 |
+
"dtype": "float32",
|
| 778 |
+
"hadU": 2048,
|
| 779 |
+
"hadV": 8192,
|
| 780 |
+
"in_features": 2048,
|
| 781 |
+
"linear": {
|
| 782 |
+
"KV": 6,
|
| 783 |
+
"L": 16,
|
| 784 |
+
"V": 2,
|
| 785 |
+
"bias": false,
|
| 786 |
+
"in_features": 2048,
|
| 787 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 788 |
+
"linear_dtype": "float32",
|
| 789 |
+
"out_features": 8192,
|
| 790 |
+
"td_x": 16,
|
| 791 |
+
"td_y": 16,
|
| 792 |
+
"tlut_bits": 9
|
| 793 |
+
},
|
| 794 |
+
"module_type": "IncoherentLinear",
|
| 795 |
+
"out_features": 8192,
|
| 796 |
+
"rot_info": "skip_r",
|
| 797 |
+
"scale": 32.0
|
| 798 |
+
},
|
| 799 |
+
"model.layers.12.self_attn.k_proj": {
|
| 800 |
+
"bias": false,
|
| 801 |
+
"dtype": "float32",
|
| 802 |
+
"hadU": 2048,
|
| 803 |
+
"hadV": 512,
|
| 804 |
+
"in_features": 2048,
|
| 805 |
+
"linear": {
|
| 806 |
+
"KV": 9,
|
| 807 |
+
"L": 16,
|
| 808 |
+
"V": 2,
|
| 809 |
+
"bias": false,
|
| 810 |
+
"in_features": 2048,
|
| 811 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 812 |
+
"linear_dtype": "float32",
|
| 813 |
+
"out_features": 512,
|
| 814 |
+
"td_x": 16,
|
| 815 |
+
"td_y": 16,
|
| 816 |
+
"tlut_bits": 10
|
| 817 |
+
},
|
| 818 |
+
"module_type": "IncoherentLinear",
|
| 819 |
+
"out_features": 512,
|
| 820 |
+
"rot_info": "skip_r",
|
| 821 |
+
"scale": 32.0
|
| 822 |
+
},
|
| 823 |
+
"model.layers.12.self_attn.o_proj": {
|
| 824 |
+
"bias": false,
|
| 825 |
+
"dtype": "float32",
|
| 826 |
+
"hadU": 2048,
|
| 827 |
+
"hadV": 2048,
|
| 828 |
+
"in_features": 2048,
|
| 829 |
+
"linear": {
|
| 830 |
+
"KV": 7,
|
| 831 |
+
"L": 16,
|
| 832 |
+
"V": 2,
|
| 833 |
+
"bias": false,
|
| 834 |
+
"in_features": 2048,
|
| 835 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 836 |
+
"linear_dtype": "float32",
|
| 837 |
+
"out_features": 2048,
|
| 838 |
+
"td_x": 16,
|
| 839 |
+
"td_y": 16,
|
| 840 |
+
"tlut_bits": 9
|
| 841 |
+
},
|
| 842 |
+
"module_type": "IncoherentLinear",
|
| 843 |
+
"out_features": 2048,
|
| 844 |
+
"rot_info": "skip_r",
|
| 845 |
+
"scale": 32.0
|
| 846 |
+
},
|
| 847 |
+
"model.layers.12.self_attn.q_proj": {
|
| 848 |
+
"bias": false,
|
| 849 |
+
"dtype": "float32",
|
| 850 |
+
"hadU": 2048,
|
| 851 |
+
"hadV": 2048,
|
| 852 |
+
"in_features": 2048,
|
| 853 |
+
"linear": {
|
| 854 |
+
"KV": 7,
|
| 855 |
+
"L": 16,
|
| 856 |
+
"V": 2,
|
| 857 |
+
"bias": false,
|
| 858 |
+
"in_features": 2048,
|
| 859 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 860 |
+
"linear_dtype": "float32",
|
| 861 |
+
"out_features": 2048,
|
| 862 |
+
"td_x": 16,
|
| 863 |
+
"td_y": 16,
|
| 864 |
+
"tlut_bits": 9
|
| 865 |
+
},
|
| 866 |
+
"module_type": "IncoherentLinear",
|
| 867 |
+
"out_features": 2048,
|
| 868 |
+
"rot_info": "skip_r",
|
| 869 |
+
"scale": 32.0
|
| 870 |
+
},
|
| 871 |
+
"model.layers.12.self_attn.v_proj": {
|
| 872 |
+
"bias": false,
|
| 873 |
+
"dtype": "float32",
|
| 874 |
+
"hadU": 2048,
|
| 875 |
+
"hadV": 512,
|
| 876 |
+
"in_features": 2048,
|
| 877 |
+
"linear": {
|
| 878 |
+
"KV": 9,
|
| 879 |
+
"L": 16,
|
| 880 |
+
"V": 2,
|
| 881 |
+
"bias": false,
|
| 882 |
+
"in_features": 2048,
|
| 883 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 884 |
+
"linear_dtype": "float32",
|
| 885 |
+
"out_features": 512,
|
| 886 |
+
"td_x": 16,
|
| 887 |
+
"td_y": 16,
|
| 888 |
+
"tlut_bits": 10
|
| 889 |
+
},
|
| 890 |
+
"module_type": "IncoherentLinear",
|
| 891 |
+
"out_features": 512,
|
| 892 |
+
"rot_info": "skip_r",
|
| 893 |
+
"scale": 32.0
|
| 894 |
+
},
|
| 895 |
+
"model.layers.13.mlp.down_proj": {
|
| 896 |
+
"bias": false,
|
| 897 |
+
"dtype": "float32",
|
| 898 |
+
"hadU": 8192,
|
| 899 |
+
"hadV": 2048,
|
| 900 |
+
"in_features": 8192,
|
| 901 |
+
"linear": {
|
| 902 |
+
"KV": 6,
|
| 903 |
+
"L": 16,
|
| 904 |
+
"V": 2,
|
| 905 |
+
"bias": false,
|
| 906 |
+
"in_features": 8192,
|
| 907 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 908 |
+
"linear_dtype": "float32",
|
| 909 |
+
"out_features": 2048,
|
| 910 |
+
"td_x": 16,
|
| 911 |
+
"td_y": 16,
|
| 912 |
+
"tlut_bits": 9
|
| 913 |
+
},
|
| 914 |
+
"module_type": "IncoherentLinear",
|
| 915 |
+
"out_features": 2048,
|
| 916 |
+
"rot_info": "skip_r",
|
| 917 |
+
"scale": 32.0
|
| 918 |
+
},
|
| 919 |
+
"model.layers.13.mlp.gate_proj": {
|
| 920 |
+
"bias": false,
|
| 921 |
+
"dtype": "float32",
|
| 922 |
+
"hadU": 2048,
|
| 923 |
+
"hadV": 8192,
|
| 924 |
+
"in_features": 2048,
|
| 925 |
+
"linear": {
|
| 926 |
+
"KV": 6,
|
| 927 |
+
"L": 16,
|
| 928 |
+
"V": 2,
|
| 929 |
+
"bias": false,
|
| 930 |
+
"in_features": 2048,
|
| 931 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 932 |
+
"linear_dtype": "float32",
|
| 933 |
+
"out_features": 8192,
|
| 934 |
+
"td_x": 16,
|
| 935 |
+
"td_y": 16,
|
| 936 |
+
"tlut_bits": 9
|
| 937 |
+
},
|
| 938 |
+
"module_type": "IncoherentLinear",
|
| 939 |
+
"out_features": 8192,
|
| 940 |
+
"rot_info": "skip_r",
|
| 941 |
+
"scale": 32.0
|
| 942 |
+
},
|
| 943 |
+
"model.layers.13.mlp.up_proj": {
|
| 944 |
+
"bias": false,
|
| 945 |
+
"dtype": "float32",
|
| 946 |
+
"hadU": 2048,
|
| 947 |
+
"hadV": 8192,
|
| 948 |
+
"in_features": 2048,
|
| 949 |
+
"linear": {
|
| 950 |
+
"KV": 6,
|
| 951 |
+
"L": 16,
|
| 952 |
+
"V": 2,
|
| 953 |
+
"bias": false,
|
| 954 |
+
"in_features": 2048,
|
| 955 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 956 |
+
"linear_dtype": "float32",
|
| 957 |
+
"out_features": 8192,
|
| 958 |
+
"td_x": 16,
|
| 959 |
+
"td_y": 16,
|
| 960 |
+
"tlut_bits": 9
|
| 961 |
+
},
|
| 962 |
+
"module_type": "IncoherentLinear",
|
| 963 |
+
"out_features": 8192,
|
| 964 |
+
"rot_info": "skip_r",
|
| 965 |
+
"scale": 32.0
|
| 966 |
+
},
|
| 967 |
+
"model.layers.13.self_attn.k_proj": {
|
| 968 |
+
"bias": false,
|
| 969 |
+
"dtype": "float32",
|
| 970 |
+
"hadU": 2048,
|
| 971 |
+
"hadV": 512,
|
| 972 |
+
"in_features": 2048,
|
| 973 |
+
"linear": {
|
| 974 |
+
"KV": 8,
|
| 975 |
+
"L": 16,
|
| 976 |
+
"V": 2,
|
| 977 |
+
"bias": false,
|
| 978 |
+
"in_features": 2048,
|
| 979 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 980 |
+
"linear_dtype": "float32",
|
| 981 |
+
"out_features": 512,
|
| 982 |
+
"td_x": 16,
|
| 983 |
+
"td_y": 16,
|
| 984 |
+
"tlut_bits": 9
|
| 985 |
+
},
|
| 986 |
+
"module_type": "IncoherentLinear",
|
| 987 |
+
"out_features": 512,
|
| 988 |
+
"rot_info": "skip_r",
|
| 989 |
+
"scale": 32.0
|
| 990 |
+
},
|
| 991 |
+
"model.layers.13.self_attn.o_proj": {
|
| 992 |
+
"bias": false,
|
| 993 |
+
"dtype": "float32",
|
| 994 |
+
"hadU": 2048,
|
| 995 |
+
"hadV": 2048,
|
| 996 |
+
"in_features": 2048,
|
| 997 |
+
"linear": {
|
| 998 |
+
"KV": 7,
|
| 999 |
+
"L": 16,
|
| 1000 |
+
"V": 2,
|
| 1001 |
+
"bias": false,
|
| 1002 |
+
"in_features": 2048,
|
| 1003 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1004 |
+
"linear_dtype": "float32",
|
| 1005 |
+
"out_features": 2048,
|
| 1006 |
+
"td_x": 16,
|
| 1007 |
+
"td_y": 16,
|
| 1008 |
+
"tlut_bits": 9
|
| 1009 |
+
},
|
| 1010 |
+
"module_type": "IncoherentLinear",
|
| 1011 |
+
"out_features": 2048,
|
| 1012 |
+
"rot_info": "skip_r",
|
| 1013 |
+
"scale": 32.0
|
| 1014 |
+
},
|
| 1015 |
+
"model.layers.13.self_attn.q_proj": {
|
| 1016 |
+
"bias": false,
|
| 1017 |
+
"dtype": "float32",
|
| 1018 |
+
"hadU": 2048,
|
| 1019 |
+
"hadV": 2048,
|
| 1020 |
+
"in_features": 2048,
|
| 1021 |
+
"linear": {
|
| 1022 |
+
"KV": 8,
|
| 1023 |
+
"L": 16,
|
| 1024 |
+
"V": 2,
|
| 1025 |
+
"bias": false,
|
| 1026 |
+
"in_features": 2048,
|
| 1027 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1028 |
+
"linear_dtype": "float32",
|
| 1029 |
+
"out_features": 2048,
|
| 1030 |
+
"td_x": 16,
|
| 1031 |
+
"td_y": 16,
|
| 1032 |
+
"tlut_bits": 9
|
| 1033 |
+
},
|
| 1034 |
+
"module_type": "IncoherentLinear",
|
| 1035 |
+
"out_features": 2048,
|
| 1036 |
+
"rot_info": "skip_r",
|
| 1037 |
+
"scale": 32.0
|
| 1038 |
+
},
|
| 1039 |
+
"model.layers.13.self_attn.v_proj": {
|
| 1040 |
+
"bias": false,
|
| 1041 |
+
"dtype": "float32",
|
| 1042 |
+
"hadU": 2048,
|
| 1043 |
+
"hadV": 512,
|
| 1044 |
+
"in_features": 2048,
|
| 1045 |
+
"linear": {
|
| 1046 |
+
"KV": 9,
|
| 1047 |
+
"L": 16,
|
| 1048 |
+
"V": 2,
|
| 1049 |
+
"bias": false,
|
| 1050 |
+
"in_features": 2048,
|
| 1051 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1052 |
+
"linear_dtype": "float32",
|
| 1053 |
+
"out_features": 512,
|
| 1054 |
+
"td_x": 16,
|
| 1055 |
+
"td_y": 16,
|
| 1056 |
+
"tlut_bits": 10
|
| 1057 |
+
},
|
| 1058 |
+
"module_type": "IncoherentLinear",
|
| 1059 |
+
"out_features": 512,
|
| 1060 |
+
"rot_info": "skip_r",
|
| 1061 |
+
"scale": 32.0
|
| 1062 |
+
},
|
| 1063 |
+
"model.layers.14.mlp.down_proj": {
|
| 1064 |
+
"bias": false,
|
| 1065 |
+
"dtype": "float32",
|
| 1066 |
+
"hadU": 8192,
|
| 1067 |
+
"hadV": 2048,
|
| 1068 |
+
"in_features": 8192,
|
| 1069 |
+
"linear": {
|
| 1070 |
+
"KV": 7,
|
| 1071 |
+
"L": 16,
|
| 1072 |
+
"V": 2,
|
| 1073 |
+
"bias": false,
|
| 1074 |
+
"in_features": 8192,
|
| 1075 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1076 |
+
"linear_dtype": "float32",
|
| 1077 |
+
"out_features": 2048,
|
| 1078 |
+
"td_x": 16,
|
| 1079 |
+
"td_y": 16,
|
| 1080 |
+
"tlut_bits": 9
|
| 1081 |
+
},
|
| 1082 |
+
"module_type": "IncoherentLinear",
|
| 1083 |
+
"out_features": 2048,
|
| 1084 |
+
"rot_info": "skip_r",
|
| 1085 |
+
"scale": 32.0
|
| 1086 |
+
},
|
| 1087 |
+
"model.layers.14.mlp.gate_proj": {
|
| 1088 |
+
"bias": false,
|
| 1089 |
+
"dtype": "float32",
|
| 1090 |
+
"hadU": 2048,
|
| 1091 |
+
"hadV": 8192,
|
| 1092 |
+
"in_features": 2048,
|
| 1093 |
+
"linear": {
|
| 1094 |
+
"KV": 6,
|
| 1095 |
+
"L": 16,
|
| 1096 |
+
"V": 2,
|
| 1097 |
+
"bias": false,
|
| 1098 |
+
"in_features": 2048,
|
| 1099 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1100 |
+
"linear_dtype": "float32",
|
| 1101 |
+
"out_features": 8192,
|
| 1102 |
+
"td_x": 16,
|
| 1103 |
+
"td_y": 16,
|
| 1104 |
+
"tlut_bits": 9
|
| 1105 |
+
},
|
| 1106 |
+
"module_type": "IncoherentLinear",
|
| 1107 |
+
"out_features": 8192,
|
| 1108 |
+
"rot_info": "skip_r",
|
| 1109 |
+
"scale": 32.0
|
| 1110 |
+
},
|
| 1111 |
+
"model.layers.14.mlp.up_proj": {
|
| 1112 |
+
"bias": false,
|
| 1113 |
+
"dtype": "float32",
|
| 1114 |
+
"hadU": 2048,
|
| 1115 |
+
"hadV": 8192,
|
| 1116 |
+
"in_features": 2048,
|
| 1117 |
+
"linear": {
|
| 1118 |
+
"KV": 6,
|
| 1119 |
+
"L": 16,
|
| 1120 |
+
"V": 2,
|
| 1121 |
+
"bias": false,
|
| 1122 |
+
"in_features": 2048,
|
| 1123 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1124 |
+
"linear_dtype": "float32",
|
| 1125 |
+
"out_features": 8192,
|
| 1126 |
+
"td_x": 16,
|
| 1127 |
+
"td_y": 16,
|
| 1128 |
+
"tlut_bits": 9
|
| 1129 |
+
},
|
| 1130 |
+
"module_type": "IncoherentLinear",
|
| 1131 |
+
"out_features": 8192,
|
| 1132 |
+
"rot_info": "skip_r",
|
| 1133 |
+
"scale": 32.0
|
| 1134 |
+
},
|
| 1135 |
+
"model.layers.14.self_attn.k_proj": {
|
| 1136 |
+
"bias": false,
|
| 1137 |
+
"dtype": "float32",
|
| 1138 |
+
"hadU": 2048,
|
| 1139 |
+
"hadV": 512,
|
| 1140 |
+
"in_features": 2048,
|
| 1141 |
+
"linear": {
|
| 1142 |
+
"KV": [
|
| 1143 |
+
8,
|
| 1144 |
+
9
|
| 1145 |
+
],
|
| 1146 |
+
"L": 16,
|
| 1147 |
+
"V": 2,
|
| 1148 |
+
"bias": false,
|
| 1149 |
+
"in_features": 2048,
|
| 1150 |
+
"in_part": [
|
| 1151 |
+
1024,
|
| 1152 |
+
1024
|
| 1153 |
+
],
|
| 1154 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1155 |
+
"linear_dtype": "float32",
|
| 1156 |
+
"out_features": 512,
|
| 1157 |
+
"td_x": 16,
|
| 1158 |
+
"td_y": 16,
|
| 1159 |
+
"tlut_bits": 10
|
| 1160 |
+
},
|
| 1161 |
+
"module_type": "IncoherentLinear",
|
| 1162 |
+
"out_features": 512,
|
| 1163 |
+
"rot_info": "skip_r",
|
| 1164 |
+
"scale": 32.0
|
| 1165 |
+
},
|
| 1166 |
+
"model.layers.14.self_attn.o_proj": {
|
| 1167 |
+
"bias": false,
|
| 1168 |
+
"dtype": "float32",
|
| 1169 |
+
"hadU": 2048,
|
| 1170 |
+
"hadV": 2048,
|
| 1171 |
+
"in_features": 2048,
|
| 1172 |
+
"linear": {
|
| 1173 |
+
"KV": 7,
|
| 1174 |
+
"L": 16,
|
| 1175 |
+
"V": 2,
|
| 1176 |
+
"bias": false,
|
| 1177 |
+
"in_features": 2048,
|
| 1178 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1179 |
+
"linear_dtype": "float32",
|
| 1180 |
+
"out_features": 2048,
|
| 1181 |
+
"td_x": 16,
|
| 1182 |
+
"td_y": 16,
|
| 1183 |
+
"tlut_bits": 9
|
| 1184 |
+
},
|
| 1185 |
+
"module_type": "IncoherentLinear",
|
| 1186 |
+
"out_features": 2048,
|
| 1187 |
+
"rot_info": "skip_r",
|
| 1188 |
+
"scale": 32.0
|
| 1189 |
+
},
|
| 1190 |
+
"model.layers.14.self_attn.q_proj": {
|
| 1191 |
+
"bias": false,
|
| 1192 |
+
"dtype": "float32",
|
| 1193 |
+
"hadU": 2048,
|
| 1194 |
+
"hadV": 2048,
|
| 1195 |
+
"in_features": 2048,
|
| 1196 |
+
"linear": {
|
| 1197 |
+
"KV": 6,
|
| 1198 |
+
"L": 16,
|
| 1199 |
+
"V": 2,
|
| 1200 |
+
"bias": false,
|
| 1201 |
+
"in_features": 2048,
|
| 1202 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1203 |
+
"linear_dtype": "float32",
|
| 1204 |
+
"out_features": 2048,
|
| 1205 |
+
"td_x": 16,
|
| 1206 |
+
"td_y": 16,
|
| 1207 |
+
"tlut_bits": 9
|
| 1208 |
+
},
|
| 1209 |
+
"module_type": "IncoherentLinear",
|
| 1210 |
+
"out_features": 2048,
|
| 1211 |
+
"rot_info": "skip_r",
|
| 1212 |
+
"scale": 32.0
|
| 1213 |
+
},
|
| 1214 |
+
"model.layers.14.self_attn.v_proj": {
|
| 1215 |
+
"bias": false,
|
| 1216 |
+
"dtype": "float32",
|
| 1217 |
+
"hadU": 2048,
|
| 1218 |
+
"hadV": 512,
|
| 1219 |
+
"in_features": 2048,
|
| 1220 |
+
"linear": {
|
| 1221 |
+
"KV": [
|
| 1222 |
+
9,
|
| 1223 |
+
10
|
| 1224 |
+
],
|
| 1225 |
+
"L": 16,
|
| 1226 |
+
"V": 2,
|
| 1227 |
+
"bias": false,
|
| 1228 |
+
"in_features": 2048,
|
| 1229 |
+
"in_part": [
|
| 1230 |
+
1024,
|
| 1231 |
+
1024
|
| 1232 |
+
],
|
| 1233 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1234 |
+
"linear_dtype": "float32",
|
| 1235 |
+
"out_features": 512,
|
| 1236 |
+
"td_x": 16,
|
| 1237 |
+
"td_y": 16,
|
| 1238 |
+
"tlut_bits": 11
|
| 1239 |
+
},
|
| 1240 |
+
"module_type": "IncoherentLinear",
|
| 1241 |
+
"out_features": 512,
|
| 1242 |
+
"rot_info": "skip_r",
|
| 1243 |
+
"scale": 32.0
|
| 1244 |
+
},
|
| 1245 |
+
"model.layers.15.mlp.down_proj": {
|
| 1246 |
+
"bias": false,
|
| 1247 |
+
"dtype": "float32",
|
| 1248 |
+
"hadU": 8192,
|
| 1249 |
+
"hadV": 2048,
|
| 1250 |
+
"in_features": 8192,
|
| 1251 |
+
"linear": {
|
| 1252 |
+
"KV": [
|
| 1253 |
+
8,
|
| 1254 |
+
9
|
| 1255 |
+
],
|
| 1256 |
+
"L": 16,
|
| 1257 |
+
"V": 2,
|
| 1258 |
+
"bias": false,
|
| 1259 |
+
"in_features": 8192,
|
| 1260 |
+
"in_part": [
|
| 1261 |
+
4096,
|
| 1262 |
+
4096
|
| 1263 |
+
],
|
| 1264 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1265 |
+
"linear_dtype": "float32",
|
| 1266 |
+
"out_features": 2048,
|
| 1267 |
+
"td_x": 16,
|
| 1268 |
+
"td_y": 16,
|
| 1269 |
+
"tlut_bits": 10
|
| 1270 |
+
},
|
| 1271 |
+
"module_type": "IncoherentLinear",
|
| 1272 |
+
"out_features": 2048,
|
| 1273 |
+
"rot_info": "skip_r",
|
| 1274 |
+
"scale": 32.0
|
| 1275 |
+
},
|
| 1276 |
+
"model.layers.15.mlp.gate_proj": {
|
| 1277 |
+
"bias": false,
|
| 1278 |
+
"dtype": "float32",
|
| 1279 |
+
"hadU": 2048,
|
| 1280 |
+
"hadV": 8192,
|
| 1281 |
+
"in_features": 2048,
|
| 1282 |
+
"linear": {
|
| 1283 |
+
"KV": 6,
|
| 1284 |
+
"L": 16,
|
| 1285 |
+
"V": 2,
|
| 1286 |
+
"bias": false,
|
| 1287 |
+
"in_features": 2048,
|
| 1288 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1289 |
+
"linear_dtype": "float32",
|
| 1290 |
+
"out_features": 8192,
|
| 1291 |
+
"td_x": 16,
|
| 1292 |
+
"td_y": 16,
|
| 1293 |
+
"tlut_bits": 9
|
| 1294 |
+
},
|
| 1295 |
+
"module_type": "IncoherentLinear",
|
| 1296 |
+
"out_features": 8192,
|
| 1297 |
+
"rot_info": "skip_r",
|
| 1298 |
+
"scale": 32.0
|
| 1299 |
+
},
|
| 1300 |
+
"model.layers.15.mlp.up_proj": {
|
| 1301 |
+
"bias": false,
|
| 1302 |
+
"dtype": "float32",
|
| 1303 |
+
"hadU": 2048,
|
| 1304 |
+
"hadV": 8192,
|
| 1305 |
+
"in_features": 2048,
|
| 1306 |
+
"linear": {
|
| 1307 |
+
"KV": 7,
|
| 1308 |
+
"L": 16,
|
| 1309 |
+
"V": 2,
|
| 1310 |
+
"bias": false,
|
| 1311 |
+
"in_features": 2048,
|
| 1312 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1313 |
+
"linear_dtype": "float32",
|
| 1314 |
+
"out_features": 8192,
|
| 1315 |
+
"td_x": 16,
|
| 1316 |
+
"td_y": 16,
|
| 1317 |
+
"tlut_bits": 9
|
| 1318 |
+
},
|
| 1319 |
+
"module_type": "IncoherentLinear",
|
| 1320 |
+
"out_features": 8192,
|
| 1321 |
+
"rot_info": "skip_r",
|
| 1322 |
+
"scale": 32.0
|
| 1323 |
+
},
|
| 1324 |
+
"model.layers.15.self_attn.k_proj": {
|
| 1325 |
+
"bias": false,
|
| 1326 |
+
"dtype": "float32",
|
| 1327 |
+
"hadU": 2048,
|
| 1328 |
+
"hadV": 512,
|
| 1329 |
+
"in_features": 2048,
|
| 1330 |
+
"linear": {
|
| 1331 |
+
"KV": [
|
| 1332 |
+
8,
|
| 1333 |
+
9
|
| 1334 |
+
],
|
| 1335 |
+
"L": 16,
|
| 1336 |
+
"V": 2,
|
| 1337 |
+
"bias": false,
|
| 1338 |
+
"in_features": 2048,
|
| 1339 |
+
"in_part": [
|
| 1340 |
+
1024,
|
| 1341 |
+
1024
|
| 1342 |
+
],
|
| 1343 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1344 |
+
"linear_dtype": "float32",
|
| 1345 |
+
"out_features": 512,
|
| 1346 |
+
"td_x": 16,
|
| 1347 |
+
"td_y": 16,
|
| 1348 |
+
"tlut_bits": 10
|
| 1349 |
+
},
|
| 1350 |
+
"module_type": "IncoherentLinear",
|
| 1351 |
+
"out_features": 512,
|
| 1352 |
+
"rot_info": "skip_r",
|
| 1353 |
+
"scale": 32.0
|
| 1354 |
+
},
|
| 1355 |
+
"model.layers.15.self_attn.o_proj": {
|
| 1356 |
+
"bias": false,
|
| 1357 |
+
"dtype": "float32",
|
| 1358 |
+
"hadU": 2048,
|
| 1359 |
+
"hadV": 2048,
|
| 1360 |
+
"in_features": 2048,
|
| 1361 |
+
"linear": {
|
| 1362 |
+
"KV": [
|
| 1363 |
+
8,
|
| 1364 |
+
9
|
| 1365 |
+
],
|
| 1366 |
+
"L": 16,
|
| 1367 |
+
"V": 2,
|
| 1368 |
+
"bias": false,
|
| 1369 |
+
"in_features": 2048,
|
| 1370 |
+
"in_part": [
|
| 1371 |
+
1024,
|
| 1372 |
+
1024
|
| 1373 |
+
],
|
| 1374 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1375 |
+
"linear_dtype": "float32",
|
| 1376 |
+
"out_features": 2048,
|
| 1377 |
+
"td_x": 16,
|
| 1378 |
+
"td_y": 16,
|
| 1379 |
+
"tlut_bits": 10
|
| 1380 |
+
},
|
| 1381 |
+
"module_type": "IncoherentLinear",
|
| 1382 |
+
"out_features": 2048,
|
| 1383 |
+
"rot_info": "skip_r",
|
| 1384 |
+
"scale": 32.0
|
| 1385 |
+
},
|
| 1386 |
+
"model.layers.15.self_attn.q_proj": {
|
| 1387 |
+
"bias": false,
|
| 1388 |
+
"dtype": "float32",
|
| 1389 |
+
"hadU": 2048,
|
| 1390 |
+
"hadV": 2048,
|
| 1391 |
+
"in_features": 2048,
|
| 1392 |
+
"linear": {
|
| 1393 |
+
"KV": 6,
|
| 1394 |
+
"L": 16,
|
| 1395 |
+
"V": 2,
|
| 1396 |
+
"bias": false,
|
| 1397 |
+
"in_features": 2048,
|
| 1398 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1399 |
+
"linear_dtype": "float32",
|
| 1400 |
+
"out_features": 2048,
|
| 1401 |
+
"td_x": 16,
|
| 1402 |
+
"td_y": 16,
|
| 1403 |
+
"tlut_bits": 9
|
| 1404 |
+
},
|
| 1405 |
+
"module_type": "IncoherentLinear",
|
| 1406 |
+
"out_features": 2048,
|
| 1407 |
+
"rot_info": "skip_r",
|
| 1408 |
+
"scale": 32.0
|
| 1409 |
+
},
|
| 1410 |
+
"model.layers.15.self_attn.v_proj": {
|
| 1411 |
+
"bias": false,
|
| 1412 |
+
"dtype": "float32",
|
| 1413 |
+
"hadU": 2048,
|
| 1414 |
+
"hadV": 512,
|
| 1415 |
+
"in_features": 2048,
|
| 1416 |
+
"linear": {
|
| 1417 |
+
"KV": 9,
|
| 1418 |
+
"L": 16,
|
| 1419 |
+
"V": 2,
|
| 1420 |
+
"bias": false,
|
| 1421 |
+
"in_features": 2048,
|
| 1422 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1423 |
+
"linear_dtype": "float32",
|
| 1424 |
+
"out_features": 512,
|
| 1425 |
+
"td_x": 16,
|
| 1426 |
+
"td_y": 16,
|
| 1427 |
+
"tlut_bits": 10
|
| 1428 |
+
},
|
| 1429 |
+
"module_type": "IncoherentLinear",
|
| 1430 |
+
"out_features": 512,
|
| 1431 |
+
"rot_info": "skip_r",
|
| 1432 |
+
"scale": 32.0
|
| 1433 |
+
},
|
| 1434 |
+
"model.layers.2.mlp.down_proj": {
|
| 1435 |
+
"bias": false,
|
| 1436 |
+
"dtype": "float32",
|
| 1437 |
+
"hadU": 8192,
|
| 1438 |
+
"hadV": 2048,
|
| 1439 |
+
"in_features": 8192,
|
| 1440 |
+
"linear": {
|
| 1441 |
+
"KV": 6,
|
| 1442 |
+
"L": 16,
|
| 1443 |
+
"V": 2,
|
| 1444 |
+
"bias": false,
|
| 1445 |
+
"in_features": 8192,
|
| 1446 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1447 |
+
"linear_dtype": "float32",
|
| 1448 |
+
"out_features": 2048,
|
| 1449 |
+
"td_x": 16,
|
| 1450 |
+
"td_y": 16,
|
| 1451 |
+
"tlut_bits": 9
|
| 1452 |
+
},
|
| 1453 |
+
"module_type": "IncoherentLinear",
|
| 1454 |
+
"out_features": 2048,
|
| 1455 |
+
"rot_info": "skip_r",
|
| 1456 |
+
"scale": 32.0
|
| 1457 |
+
},
|
| 1458 |
+
"model.layers.2.mlp.gate_proj": {
|
| 1459 |
+
"bias": false,
|
| 1460 |
+
"dtype": "float32",
|
| 1461 |
+
"hadU": 2048,
|
| 1462 |
+
"hadV": 8192,
|
| 1463 |
+
"in_features": 2048,
|
| 1464 |
+
"linear": {
|
| 1465 |
+
"KV": 5,
|
| 1466 |
+
"L": 16,
|
| 1467 |
+
"V": 2,
|
| 1468 |
+
"bias": false,
|
| 1469 |
+
"in_features": 2048,
|
| 1470 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1471 |
+
"linear_dtype": "float32",
|
| 1472 |
+
"out_features": 8192,
|
| 1473 |
+
"td_x": 16,
|
| 1474 |
+
"td_y": 16,
|
| 1475 |
+
"tlut_bits": 9
|
| 1476 |
+
},
|
| 1477 |
+
"module_type": "IncoherentLinear",
|
| 1478 |
+
"out_features": 8192,
|
| 1479 |
+
"rot_info": "skip_r",
|
| 1480 |
+
"scale": 32.0
|
| 1481 |
+
},
|
| 1482 |
+
"model.layers.2.mlp.up_proj": {
|
| 1483 |
+
"bias": false,
|
| 1484 |
+
"dtype": "float32",
|
| 1485 |
+
"hadU": 2048,
|
| 1486 |
+
"hadV": 8192,
|
| 1487 |
+
"in_features": 2048,
|
| 1488 |
+
"linear": {
|
| 1489 |
+
"KV": 6,
|
| 1490 |
+
"L": 16,
|
| 1491 |
+
"V": 2,
|
| 1492 |
+
"bias": false,
|
| 1493 |
+
"in_features": 2048,
|
| 1494 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1495 |
+
"linear_dtype": "float32",
|
| 1496 |
+
"out_features": 8192,
|
| 1497 |
+
"td_x": 16,
|
| 1498 |
+
"td_y": 16,
|
| 1499 |
+
"tlut_bits": 9
|
| 1500 |
+
},
|
| 1501 |
+
"module_type": "IncoherentLinear",
|
| 1502 |
+
"out_features": 8192,
|
| 1503 |
+
"rot_info": "skip_r",
|
| 1504 |
+
"scale": 32.0
|
| 1505 |
+
},
|
| 1506 |
+
"model.layers.2.self_attn.k_proj": {
|
| 1507 |
+
"bias": false,
|
| 1508 |
+
"dtype": "float32",
|
| 1509 |
+
"hadU": 2048,
|
| 1510 |
+
"hadV": 512,
|
| 1511 |
+
"in_features": 2048,
|
| 1512 |
+
"linear": {
|
| 1513 |
+
"KV": [
|
| 1514 |
+
8,
|
| 1515 |
+
9
|
| 1516 |
+
],
|
| 1517 |
+
"L": 16,
|
| 1518 |
+
"V": 2,
|
| 1519 |
+
"bias": false,
|
| 1520 |
+
"in_features": 2048,
|
| 1521 |
+
"in_part": [
|
| 1522 |
+
1024,
|
| 1523 |
+
1024
|
| 1524 |
+
],
|
| 1525 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1526 |
+
"linear_dtype": "float32",
|
| 1527 |
+
"out_features": 512,
|
| 1528 |
+
"td_x": 16,
|
| 1529 |
+
"td_y": 16,
|
| 1530 |
+
"tlut_bits": 10
|
| 1531 |
+
},
|
| 1532 |
+
"module_type": "IncoherentLinear",
|
| 1533 |
+
"out_features": 512,
|
| 1534 |
+
"rot_info": "skip_r",
|
| 1535 |
+
"scale": 32.0
|
| 1536 |
+
},
|
| 1537 |
+
"model.layers.2.self_attn.o_proj": {
|
| 1538 |
+
"bias": false,
|
| 1539 |
+
"dtype": "float32",
|
| 1540 |
+
"hadU": 2048,
|
| 1541 |
+
"hadV": 2048,
|
| 1542 |
+
"in_features": 2048,
|
| 1543 |
+
"linear": {
|
| 1544 |
+
"KV": 7,
|
| 1545 |
+
"L": 16,
|
| 1546 |
+
"V": 2,
|
| 1547 |
+
"bias": false,
|
| 1548 |
+
"in_features": 2048,
|
| 1549 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1550 |
+
"linear_dtype": "float32",
|
| 1551 |
+
"out_features": 2048,
|
| 1552 |
+
"td_x": 16,
|
| 1553 |
+
"td_y": 16,
|
| 1554 |
+
"tlut_bits": 9
|
| 1555 |
+
},
|
| 1556 |
+
"module_type": "IncoherentLinear",
|
| 1557 |
+
"out_features": 2048,
|
| 1558 |
+
"rot_info": "skip_r",
|
| 1559 |
+
"scale": 32.0
|
| 1560 |
+
},
|
| 1561 |
+
"model.layers.2.self_attn.q_proj": {
|
| 1562 |
+
"bias": false,
|
| 1563 |
+
"dtype": "float32",
|
| 1564 |
+
"hadU": 2048,
|
| 1565 |
+
"hadV": 2048,
|
| 1566 |
+
"in_features": 2048,
|
| 1567 |
+
"linear": {
|
| 1568 |
+
"KV": 7,
|
| 1569 |
+
"L": 16,
|
| 1570 |
+
"V": 2,
|
| 1571 |
+
"bias": false,
|
| 1572 |
+
"in_features": 2048,
|
| 1573 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1574 |
+
"linear_dtype": "float32",
|
| 1575 |
+
"out_features": 2048,
|
| 1576 |
+
"td_x": 16,
|
| 1577 |
+
"td_y": 16,
|
| 1578 |
+
"tlut_bits": 9
|
| 1579 |
+
},
|
| 1580 |
+
"module_type": "IncoherentLinear",
|
| 1581 |
+
"out_features": 2048,
|
| 1582 |
+
"rot_info": "skip_r",
|
| 1583 |
+
"scale": 32.0
|
| 1584 |
+
},
|
| 1585 |
+
"model.layers.2.self_attn.v_proj": {
|
| 1586 |
+
"bias": false,
|
| 1587 |
+
"dtype": "float32",
|
| 1588 |
+
"hadU": 2048,
|
| 1589 |
+
"hadV": 512,
|
| 1590 |
+
"in_features": 2048,
|
| 1591 |
+
"linear": {
|
| 1592 |
+
"KV": [
|
| 1593 |
+
9,
|
| 1594 |
+
10
|
| 1595 |
+
],
|
| 1596 |
+
"L": 16,
|
| 1597 |
+
"V": 2,
|
| 1598 |
+
"bias": false,
|
| 1599 |
+
"in_features": 2048,
|
| 1600 |
+
"in_part": [
|
| 1601 |
+
1024,
|
| 1602 |
+
1024
|
| 1603 |
+
],
|
| 1604 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1605 |
+
"linear_dtype": "float32",
|
| 1606 |
+
"out_features": 512,
|
| 1607 |
+
"td_x": 16,
|
| 1608 |
+
"td_y": 16,
|
| 1609 |
+
"tlut_bits": 11
|
| 1610 |
+
},
|
| 1611 |
+
"module_type": "IncoherentLinear",
|
| 1612 |
+
"out_features": 512,
|
| 1613 |
+
"rot_info": "skip_r",
|
| 1614 |
+
"scale": 32.0
|
| 1615 |
+
},
|
| 1616 |
+
"model.layers.3.mlp.down_proj": {
|
| 1617 |
+
"bias": false,
|
| 1618 |
+
"dtype": "float32",
|
| 1619 |
+
"hadU": 8192,
|
| 1620 |
+
"hadV": 2048,
|
| 1621 |
+
"in_features": 8192,
|
| 1622 |
+
"linear": {
|
| 1623 |
+
"KV": 6,
|
| 1624 |
+
"L": 16,
|
| 1625 |
+
"V": 2,
|
| 1626 |
+
"bias": false,
|
| 1627 |
+
"in_features": 8192,
|
| 1628 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1629 |
+
"linear_dtype": "float32",
|
| 1630 |
+
"out_features": 2048,
|
| 1631 |
+
"td_x": 16,
|
| 1632 |
+
"td_y": 16,
|
| 1633 |
+
"tlut_bits": 9
|
| 1634 |
+
},
|
| 1635 |
+
"module_type": "IncoherentLinear",
|
| 1636 |
+
"out_features": 2048,
|
| 1637 |
+
"rot_info": "skip_r",
|
| 1638 |
+
"scale": 32.0
|
| 1639 |
+
},
|
| 1640 |
+
"model.layers.3.mlp.gate_proj": {
|
| 1641 |
+
"bias": false,
|
| 1642 |
+
"dtype": "float32",
|
| 1643 |
+
"hadU": 2048,
|
| 1644 |
+
"hadV": 8192,
|
| 1645 |
+
"in_features": 2048,
|
| 1646 |
+
"linear": {
|
| 1647 |
+
"KV": 5,
|
| 1648 |
+
"L": 16,
|
| 1649 |
+
"V": 2,
|
| 1650 |
+
"bias": false,
|
| 1651 |
+
"in_features": 2048,
|
| 1652 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1653 |
+
"linear_dtype": "float32",
|
| 1654 |
+
"out_features": 8192,
|
| 1655 |
+
"td_x": 16,
|
| 1656 |
+
"td_y": 16,
|
| 1657 |
+
"tlut_bits": 9
|
| 1658 |
+
},
|
| 1659 |
+
"module_type": "IncoherentLinear",
|
| 1660 |
+
"out_features": 8192,
|
| 1661 |
+
"rot_info": "skip_r",
|
| 1662 |
+
"scale": 32.0
|
| 1663 |
+
},
|
| 1664 |
+
"model.layers.3.mlp.up_proj": {
|
| 1665 |
+
"bias": false,
|
| 1666 |
+
"dtype": "float32",
|
| 1667 |
+
"hadU": 2048,
|
| 1668 |
+
"hadV": 8192,
|
| 1669 |
+
"in_features": 2048,
|
| 1670 |
+
"linear": {
|
| 1671 |
+
"KV": 6,
|
| 1672 |
+
"L": 16,
|
| 1673 |
+
"V": 2,
|
| 1674 |
+
"bias": false,
|
| 1675 |
+
"in_features": 2048,
|
| 1676 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1677 |
+
"linear_dtype": "float32",
|
| 1678 |
+
"out_features": 8192,
|
| 1679 |
+
"td_x": 16,
|
| 1680 |
+
"td_y": 16,
|
| 1681 |
+
"tlut_bits": 9
|
| 1682 |
+
},
|
| 1683 |
+
"module_type": "IncoherentLinear",
|
| 1684 |
+
"out_features": 8192,
|
| 1685 |
+
"rot_info": "skip_r",
|
| 1686 |
+
"scale": 32.0
|
| 1687 |
+
},
|
| 1688 |
+
"model.layers.3.self_attn.k_proj": {
|
| 1689 |
+
"bias": false,
|
| 1690 |
+
"dtype": "float32",
|
| 1691 |
+
"hadU": 2048,
|
| 1692 |
+
"hadV": 512,
|
| 1693 |
+
"in_features": 2048,
|
| 1694 |
+
"linear": {
|
| 1695 |
+
"KV": [
|
| 1696 |
+
8,
|
| 1697 |
+
9
|
| 1698 |
+
],
|
| 1699 |
+
"L": 16,
|
| 1700 |
+
"V": 2,
|
| 1701 |
+
"bias": false,
|
| 1702 |
+
"in_features": 2048,
|
| 1703 |
+
"in_part": [
|
| 1704 |
+
1024,
|
| 1705 |
+
1024
|
| 1706 |
+
],
|
| 1707 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1708 |
+
"linear_dtype": "float32",
|
| 1709 |
+
"out_features": 512,
|
| 1710 |
+
"td_x": 16,
|
| 1711 |
+
"td_y": 16,
|
| 1712 |
+
"tlut_bits": 10
|
| 1713 |
+
},
|
| 1714 |
+
"module_type": "IncoherentLinear",
|
| 1715 |
+
"out_features": 512,
|
| 1716 |
+
"rot_info": "skip_r",
|
| 1717 |
+
"scale": 32.0
|
| 1718 |
+
},
|
| 1719 |
+
"model.layers.3.self_attn.o_proj": {
|
| 1720 |
+
"bias": false,
|
| 1721 |
+
"dtype": "float32",
|
| 1722 |
+
"hadU": 2048,
|
| 1723 |
+
"hadV": 2048,
|
| 1724 |
+
"in_features": 2048,
|
| 1725 |
+
"linear": {
|
| 1726 |
+
"KV": 7,
|
| 1727 |
+
"L": 16,
|
| 1728 |
+
"V": 2,
|
| 1729 |
+
"bias": false,
|
| 1730 |
+
"in_features": 2048,
|
| 1731 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1732 |
+
"linear_dtype": "float32",
|
| 1733 |
+
"out_features": 2048,
|
| 1734 |
+
"td_x": 16,
|
| 1735 |
+
"td_y": 16,
|
| 1736 |
+
"tlut_bits": 9
|
| 1737 |
+
},
|
| 1738 |
+
"module_type": "IncoherentLinear",
|
| 1739 |
+
"out_features": 2048,
|
| 1740 |
+
"rot_info": "skip_r",
|
| 1741 |
+
"scale": 32.0
|
| 1742 |
+
},
|
| 1743 |
+
"model.layers.3.self_attn.q_proj": {
|
| 1744 |
+
"bias": false,
|
| 1745 |
+
"dtype": "float32",
|
| 1746 |
+
"hadU": 2048,
|
| 1747 |
+
"hadV": 2048,
|
| 1748 |
+
"in_features": 2048,
|
| 1749 |
+
"linear": {
|
| 1750 |
+
"KV": 7,
|
| 1751 |
+
"L": 16,
|
| 1752 |
+
"V": 2,
|
| 1753 |
+
"bias": false,
|
| 1754 |
+
"in_features": 2048,
|
| 1755 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1756 |
+
"linear_dtype": "float32",
|
| 1757 |
+
"out_features": 2048,
|
| 1758 |
+
"td_x": 16,
|
| 1759 |
+
"td_y": 16,
|
| 1760 |
+
"tlut_bits": 9
|
| 1761 |
+
},
|
| 1762 |
+
"module_type": "IncoherentLinear",
|
| 1763 |
+
"out_features": 2048,
|
| 1764 |
+
"rot_info": "skip_r",
|
| 1765 |
+
"scale": 32.0
|
| 1766 |
+
},
|
| 1767 |
+
"model.layers.3.self_attn.v_proj": {
|
| 1768 |
+
"bias": false,
|
| 1769 |
+
"dtype": "float32",
|
| 1770 |
+
"hadU": 2048,
|
| 1771 |
+
"hadV": 512,
|
| 1772 |
+
"in_features": 2048,
|
| 1773 |
+
"linear": {
|
| 1774 |
+
"KV": 10,
|
| 1775 |
+
"L": 16,
|
| 1776 |
+
"V": 2,
|
| 1777 |
+
"bias": false,
|
| 1778 |
+
"in_features": 2048,
|
| 1779 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1780 |
+
"linear_dtype": "float32",
|
| 1781 |
+
"out_features": 512,
|
| 1782 |
+
"td_x": 16,
|
| 1783 |
+
"td_y": 16,
|
| 1784 |
+
"tlut_bits": 11
|
| 1785 |
+
},
|
| 1786 |
+
"module_type": "IncoherentLinear",
|
| 1787 |
+
"out_features": 512,
|
| 1788 |
+
"rot_info": "skip_r",
|
| 1789 |
+
"scale": 32.0
|
| 1790 |
+
},
|
| 1791 |
+
"model.layers.4.mlp.down_proj": {
|
| 1792 |
+
"bias": false,
|
| 1793 |
+
"dtype": "float32",
|
| 1794 |
+
"hadU": 8192,
|
| 1795 |
+
"hadV": 2048,
|
| 1796 |
+
"in_features": 8192,
|
| 1797 |
+
"linear": {
|
| 1798 |
+
"KV": 7,
|
| 1799 |
+
"L": 16,
|
| 1800 |
+
"V": 2,
|
| 1801 |
+
"bias": false,
|
| 1802 |
+
"in_features": 8192,
|
| 1803 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1804 |
+
"linear_dtype": "float32",
|
| 1805 |
+
"out_features": 2048,
|
| 1806 |
+
"td_x": 16,
|
| 1807 |
+
"td_y": 16,
|
| 1808 |
+
"tlut_bits": 9
|
| 1809 |
+
},
|
| 1810 |
+
"module_type": "IncoherentLinear",
|
| 1811 |
+
"out_features": 2048,
|
| 1812 |
+
"rot_info": "skip_r",
|
| 1813 |
+
"scale": 32.0
|
| 1814 |
+
},
|
| 1815 |
+
"model.layers.4.mlp.gate_proj": {
|
| 1816 |
+
"bias": false,
|
| 1817 |
+
"dtype": "float32",
|
| 1818 |
+
"hadU": 2048,
|
| 1819 |
+
"hadV": 8192,
|
| 1820 |
+
"in_features": 2048,
|
| 1821 |
+
"linear": {
|
| 1822 |
+
"KV": 6,
|
| 1823 |
+
"L": 16,
|
| 1824 |
+
"V": 2,
|
| 1825 |
+
"bias": false,
|
| 1826 |
+
"in_features": 2048,
|
| 1827 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1828 |
+
"linear_dtype": "float32",
|
| 1829 |
+
"out_features": 8192,
|
| 1830 |
+
"td_x": 16,
|
| 1831 |
+
"td_y": 16,
|
| 1832 |
+
"tlut_bits": 9
|
| 1833 |
+
},
|
| 1834 |
+
"module_type": "IncoherentLinear",
|
| 1835 |
+
"out_features": 8192,
|
| 1836 |
+
"rot_info": "skip_r",
|
| 1837 |
+
"scale": 32.0
|
| 1838 |
+
},
|
| 1839 |
+
"model.layers.4.mlp.up_proj": {
|
| 1840 |
+
"bias": false,
|
| 1841 |
+
"dtype": "float32",
|
| 1842 |
+
"hadU": 2048,
|
| 1843 |
+
"hadV": 8192,
|
| 1844 |
+
"in_features": 2048,
|
| 1845 |
+
"linear": {
|
| 1846 |
+
"KV": 6,
|
| 1847 |
+
"L": 16,
|
| 1848 |
+
"V": 2,
|
| 1849 |
+
"bias": false,
|
| 1850 |
+
"in_features": 2048,
|
| 1851 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1852 |
+
"linear_dtype": "float32",
|
| 1853 |
+
"out_features": 8192,
|
| 1854 |
+
"td_x": 16,
|
| 1855 |
+
"td_y": 16,
|
| 1856 |
+
"tlut_bits": 9
|
| 1857 |
+
},
|
| 1858 |
+
"module_type": "IncoherentLinear",
|
| 1859 |
+
"out_features": 8192,
|
| 1860 |
+
"rot_info": "skip_r",
|
| 1861 |
+
"scale": 32.0
|
| 1862 |
+
},
|
| 1863 |
+
"model.layers.4.self_attn.k_proj": {
|
| 1864 |
+
"bias": false,
|
| 1865 |
+
"dtype": "float32",
|
| 1866 |
+
"hadU": 2048,
|
| 1867 |
+
"hadV": 512,
|
| 1868 |
+
"in_features": 2048,
|
| 1869 |
+
"linear": {
|
| 1870 |
+
"KV": [
|
| 1871 |
+
8,
|
| 1872 |
+
9
|
| 1873 |
+
],
|
| 1874 |
+
"L": 16,
|
| 1875 |
+
"V": 2,
|
| 1876 |
+
"bias": false,
|
| 1877 |
+
"in_features": 2048,
|
| 1878 |
+
"in_part": [
|
| 1879 |
+
1024,
|
| 1880 |
+
1024
|
| 1881 |
+
],
|
| 1882 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1883 |
+
"linear_dtype": "float32",
|
| 1884 |
+
"out_features": 512,
|
| 1885 |
+
"td_x": 16,
|
| 1886 |
+
"td_y": 16,
|
| 1887 |
+
"tlut_bits": 10
|
| 1888 |
+
},
|
| 1889 |
+
"module_type": "IncoherentLinear",
|
| 1890 |
+
"out_features": 512,
|
| 1891 |
+
"rot_info": "skip_r",
|
| 1892 |
+
"scale": 32.0
|
| 1893 |
+
},
|
| 1894 |
+
"model.layers.4.self_attn.o_proj": {
|
| 1895 |
+
"bias": false,
|
| 1896 |
+
"dtype": "float32",
|
| 1897 |
+
"hadU": 2048,
|
| 1898 |
+
"hadV": 2048,
|
| 1899 |
+
"in_features": 2048,
|
| 1900 |
+
"linear": {
|
| 1901 |
+
"KV": 8,
|
| 1902 |
+
"L": 16,
|
| 1903 |
+
"V": 2,
|
| 1904 |
+
"bias": false,
|
| 1905 |
+
"in_features": 2048,
|
| 1906 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1907 |
+
"linear_dtype": "float32",
|
| 1908 |
+
"out_features": 2048,
|
| 1909 |
+
"td_x": 16,
|
| 1910 |
+
"td_y": 16,
|
| 1911 |
+
"tlut_bits": 9
|
| 1912 |
+
},
|
| 1913 |
+
"module_type": "IncoherentLinear",
|
| 1914 |
+
"out_features": 2048,
|
| 1915 |
+
"rot_info": "skip_r",
|
| 1916 |
+
"scale": 32.0
|
| 1917 |
+
},
|
| 1918 |
+
"model.layers.4.self_attn.q_proj": {
|
| 1919 |
+
"bias": false,
|
| 1920 |
+
"dtype": "float32",
|
| 1921 |
+
"hadU": 2048,
|
| 1922 |
+
"hadV": 2048,
|
| 1923 |
+
"in_features": 2048,
|
| 1924 |
+
"linear": {
|
| 1925 |
+
"KV": 7,
|
| 1926 |
+
"L": 16,
|
| 1927 |
+
"V": 2,
|
| 1928 |
+
"bias": false,
|
| 1929 |
+
"in_features": 2048,
|
| 1930 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1931 |
+
"linear_dtype": "float32",
|
| 1932 |
+
"out_features": 2048,
|
| 1933 |
+
"td_x": 16,
|
| 1934 |
+
"td_y": 16,
|
| 1935 |
+
"tlut_bits": 9
|
| 1936 |
+
},
|
| 1937 |
+
"module_type": "IncoherentLinear",
|
| 1938 |
+
"out_features": 2048,
|
| 1939 |
+
"rot_info": "skip_r",
|
| 1940 |
+
"scale": 32.0
|
| 1941 |
+
},
|
| 1942 |
+
"model.layers.4.self_attn.v_proj": {
|
| 1943 |
+
"bias": false,
|
| 1944 |
+
"dtype": "float32",
|
| 1945 |
+
"hadU": 2048,
|
| 1946 |
+
"hadV": 512,
|
| 1947 |
+
"in_features": 2048,
|
| 1948 |
+
"linear": {
|
| 1949 |
+
"KV": [
|
| 1950 |
+
9,
|
| 1951 |
+
10
|
| 1952 |
+
],
|
| 1953 |
+
"L": 16,
|
| 1954 |
+
"V": 2,
|
| 1955 |
+
"bias": false,
|
| 1956 |
+
"in_features": 2048,
|
| 1957 |
+
"in_part": [
|
| 1958 |
+
1024,
|
| 1959 |
+
1024
|
| 1960 |
+
],
|
| 1961 |
+
"linear_cls": "CombtLinearTCQ",
|
| 1962 |
+
"linear_dtype": "float32",
|
| 1963 |
+
"out_features": 512,
|
| 1964 |
+
"td_x": 16,
|
| 1965 |
+
"td_y": 16,
|
| 1966 |
+
"tlut_bits": 11
|
| 1967 |
+
},
|
| 1968 |
+
"module_type": "IncoherentLinear",
|
| 1969 |
+
"out_features": 512,
|
| 1970 |
+
"rot_info": "skip_r",
|
| 1971 |
+
"scale": 32.0
|
| 1972 |
+
},
|
| 1973 |
+
"model.layers.5.mlp.down_proj": {
|
| 1974 |
+
"bias": false,
|
| 1975 |
+
"dtype": "float32",
|
| 1976 |
+
"hadU": 8192,
|
| 1977 |
+
"hadV": 2048,
|
| 1978 |
+
"in_features": 8192,
|
| 1979 |
+
"linear": {
|
| 1980 |
+
"KV": 6,
|
| 1981 |
+
"L": 16,
|
| 1982 |
+
"V": 2,
|
| 1983 |
+
"bias": false,
|
| 1984 |
+
"in_features": 8192,
|
| 1985 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 1986 |
+
"linear_dtype": "float32",
|
| 1987 |
+
"out_features": 2048,
|
| 1988 |
+
"td_x": 16,
|
| 1989 |
+
"td_y": 16,
|
| 1990 |
+
"tlut_bits": 9
|
| 1991 |
+
},
|
| 1992 |
+
"module_type": "IncoherentLinear",
|
| 1993 |
+
"out_features": 2048,
|
| 1994 |
+
"rot_info": "skip_r",
|
| 1995 |
+
"scale": 32.0
|
| 1996 |
+
},
|
| 1997 |
+
"model.layers.5.mlp.gate_proj": {
|
| 1998 |
+
"bias": false,
|
| 1999 |
+
"dtype": "float32",
|
| 2000 |
+
"hadU": 2048,
|
| 2001 |
+
"hadV": 8192,
|
| 2002 |
+
"in_features": 2048,
|
| 2003 |
+
"linear": {
|
| 2004 |
+
"KV": 6,
|
| 2005 |
+
"L": 16,
|
| 2006 |
+
"V": 2,
|
| 2007 |
+
"bias": false,
|
| 2008 |
+
"in_features": 2048,
|
| 2009 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2010 |
+
"linear_dtype": "float32",
|
| 2011 |
+
"out_features": 8192,
|
| 2012 |
+
"td_x": 16,
|
| 2013 |
+
"td_y": 16,
|
| 2014 |
+
"tlut_bits": 9
|
| 2015 |
+
},
|
| 2016 |
+
"module_type": "IncoherentLinear",
|
| 2017 |
+
"out_features": 8192,
|
| 2018 |
+
"rot_info": "skip_r",
|
| 2019 |
+
"scale": 32.0
|
| 2020 |
+
},
|
| 2021 |
+
"model.layers.5.mlp.up_proj": {
|
| 2022 |
+
"bias": false,
|
| 2023 |
+
"dtype": "float32",
|
| 2024 |
+
"hadU": 2048,
|
| 2025 |
+
"hadV": 8192,
|
| 2026 |
+
"in_features": 2048,
|
| 2027 |
+
"linear": {
|
| 2028 |
+
"KV": 6,
|
| 2029 |
+
"L": 16,
|
| 2030 |
+
"V": 2,
|
| 2031 |
+
"bias": false,
|
| 2032 |
+
"in_features": 2048,
|
| 2033 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2034 |
+
"linear_dtype": "float32",
|
| 2035 |
+
"out_features": 8192,
|
| 2036 |
+
"td_x": 16,
|
| 2037 |
+
"td_y": 16,
|
| 2038 |
+
"tlut_bits": 9
|
| 2039 |
+
},
|
| 2040 |
+
"module_type": "IncoherentLinear",
|
| 2041 |
+
"out_features": 8192,
|
| 2042 |
+
"rot_info": "skip_r",
|
| 2043 |
+
"scale": 32.0
|
| 2044 |
+
},
|
| 2045 |
+
"model.layers.5.self_attn.k_proj": {
|
| 2046 |
+
"bias": false,
|
| 2047 |
+
"dtype": "float32",
|
| 2048 |
+
"hadU": 2048,
|
| 2049 |
+
"hadV": 512,
|
| 2050 |
+
"in_features": 2048,
|
| 2051 |
+
"linear": {
|
| 2052 |
+
"KV": 9,
|
| 2053 |
+
"L": 16,
|
| 2054 |
+
"V": 2,
|
| 2055 |
+
"bias": false,
|
| 2056 |
+
"in_features": 2048,
|
| 2057 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2058 |
+
"linear_dtype": "float32",
|
| 2059 |
+
"out_features": 512,
|
| 2060 |
+
"td_x": 16,
|
| 2061 |
+
"td_y": 16,
|
| 2062 |
+
"tlut_bits": 10
|
| 2063 |
+
},
|
| 2064 |
+
"module_type": "IncoherentLinear",
|
| 2065 |
+
"out_features": 512,
|
| 2066 |
+
"rot_info": "skip_r",
|
| 2067 |
+
"scale": 32.0
|
| 2068 |
+
},
|
| 2069 |
+
"model.layers.5.self_attn.o_proj": {
|
| 2070 |
+
"bias": false,
|
| 2071 |
+
"dtype": "float32",
|
| 2072 |
+
"hadU": 2048,
|
| 2073 |
+
"hadV": 2048,
|
| 2074 |
+
"in_features": 2048,
|
| 2075 |
+
"linear": {
|
| 2076 |
+
"KV": [
|
| 2077 |
+
8,
|
| 2078 |
+
9
|
| 2079 |
+
],
|
| 2080 |
+
"L": 16,
|
| 2081 |
+
"V": 2,
|
| 2082 |
+
"bias": false,
|
| 2083 |
+
"in_features": 2048,
|
| 2084 |
+
"in_part": [
|
| 2085 |
+
1024,
|
| 2086 |
+
1024
|
| 2087 |
+
],
|
| 2088 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2089 |
+
"linear_dtype": "float32",
|
| 2090 |
+
"out_features": 2048,
|
| 2091 |
+
"td_x": 16,
|
| 2092 |
+
"td_y": 16,
|
| 2093 |
+
"tlut_bits": 10
|
| 2094 |
+
},
|
| 2095 |
+
"module_type": "IncoherentLinear",
|
| 2096 |
+
"out_features": 2048,
|
| 2097 |
+
"rot_info": "skip_r",
|
| 2098 |
+
"scale": 32.0
|
| 2099 |
+
},
|
| 2100 |
+
"model.layers.5.self_attn.q_proj": {
|
| 2101 |
+
"bias": false,
|
| 2102 |
+
"dtype": "float32",
|
| 2103 |
+
"hadU": 2048,
|
| 2104 |
+
"hadV": 2048,
|
| 2105 |
+
"in_features": 2048,
|
| 2106 |
+
"linear": {
|
| 2107 |
+
"KV": 8,
|
| 2108 |
+
"L": 16,
|
| 2109 |
+
"V": 2,
|
| 2110 |
+
"bias": false,
|
| 2111 |
+
"in_features": 2048,
|
| 2112 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2113 |
+
"linear_dtype": "float32",
|
| 2114 |
+
"out_features": 2048,
|
| 2115 |
+
"td_x": 16,
|
| 2116 |
+
"td_y": 16,
|
| 2117 |
+
"tlut_bits": 9
|
| 2118 |
+
},
|
| 2119 |
+
"module_type": "IncoherentLinear",
|
| 2120 |
+
"out_features": 2048,
|
| 2121 |
+
"rot_info": "skip_r",
|
| 2122 |
+
"scale": 32.0
|
| 2123 |
+
},
|
| 2124 |
+
"model.layers.5.self_attn.v_proj": {
|
| 2125 |
+
"bias": false,
|
| 2126 |
+
"dtype": "float32",
|
| 2127 |
+
"hadU": 2048,
|
| 2128 |
+
"hadV": 512,
|
| 2129 |
+
"in_features": 2048,
|
| 2130 |
+
"linear": {
|
| 2131 |
+
"KV": [
|
| 2132 |
+
9,
|
| 2133 |
+
10
|
| 2134 |
+
],
|
| 2135 |
+
"L": 16,
|
| 2136 |
+
"V": 2,
|
| 2137 |
+
"bias": false,
|
| 2138 |
+
"in_features": 2048,
|
| 2139 |
+
"in_part": [
|
| 2140 |
+
1024,
|
| 2141 |
+
1024
|
| 2142 |
+
],
|
| 2143 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2144 |
+
"linear_dtype": "float32",
|
| 2145 |
+
"out_features": 512,
|
| 2146 |
+
"td_x": 16,
|
| 2147 |
+
"td_y": 16,
|
| 2148 |
+
"tlut_bits": 11
|
| 2149 |
+
},
|
| 2150 |
+
"module_type": "IncoherentLinear",
|
| 2151 |
+
"out_features": 512,
|
| 2152 |
+
"rot_info": "skip_r",
|
| 2153 |
+
"scale": 32.0
|
| 2154 |
+
},
|
| 2155 |
+
"model.layers.6.mlp.down_proj": {
|
| 2156 |
+
"bias": false,
|
| 2157 |
+
"dtype": "float32",
|
| 2158 |
+
"hadU": 8192,
|
| 2159 |
+
"hadV": 2048,
|
| 2160 |
+
"in_features": 8192,
|
| 2161 |
+
"linear": {
|
| 2162 |
+
"KV": 6,
|
| 2163 |
+
"L": 16,
|
| 2164 |
+
"V": 2,
|
| 2165 |
+
"bias": false,
|
| 2166 |
+
"in_features": 8192,
|
| 2167 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2168 |
+
"linear_dtype": "float32",
|
| 2169 |
+
"out_features": 2048,
|
| 2170 |
+
"td_x": 16,
|
| 2171 |
+
"td_y": 16,
|
| 2172 |
+
"tlut_bits": 9
|
| 2173 |
+
},
|
| 2174 |
+
"module_type": "IncoherentLinear",
|
| 2175 |
+
"out_features": 2048,
|
| 2176 |
+
"rot_info": "skip_r",
|
| 2177 |
+
"scale": 32.0
|
| 2178 |
+
},
|
| 2179 |
+
"model.layers.6.mlp.gate_proj": {
|
| 2180 |
+
"bias": false,
|
| 2181 |
+
"dtype": "float32",
|
| 2182 |
+
"hadU": 2048,
|
| 2183 |
+
"hadV": 8192,
|
| 2184 |
+
"in_features": 2048,
|
| 2185 |
+
"linear": {
|
| 2186 |
+
"KV": 6,
|
| 2187 |
+
"L": 16,
|
| 2188 |
+
"V": 2,
|
| 2189 |
+
"bias": false,
|
| 2190 |
+
"in_features": 2048,
|
| 2191 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2192 |
+
"linear_dtype": "float32",
|
| 2193 |
+
"out_features": 8192,
|
| 2194 |
+
"td_x": 16,
|
| 2195 |
+
"td_y": 16,
|
| 2196 |
+
"tlut_bits": 9
|
| 2197 |
+
},
|
| 2198 |
+
"module_type": "IncoherentLinear",
|
| 2199 |
+
"out_features": 8192,
|
| 2200 |
+
"rot_info": "skip_r",
|
| 2201 |
+
"scale": 32.0
|
| 2202 |
+
},
|
| 2203 |
+
"model.layers.6.mlp.up_proj": {
|
| 2204 |
+
"bias": false,
|
| 2205 |
+
"dtype": "float32",
|
| 2206 |
+
"hadU": 2048,
|
| 2207 |
+
"hadV": 8192,
|
| 2208 |
+
"in_features": 2048,
|
| 2209 |
+
"linear": {
|
| 2210 |
+
"KV": 6,
|
| 2211 |
+
"L": 16,
|
| 2212 |
+
"V": 2,
|
| 2213 |
+
"bias": false,
|
| 2214 |
+
"in_features": 2048,
|
| 2215 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2216 |
+
"linear_dtype": "float32",
|
| 2217 |
+
"out_features": 8192,
|
| 2218 |
+
"td_x": 16,
|
| 2219 |
+
"td_y": 16,
|
| 2220 |
+
"tlut_bits": 9
|
| 2221 |
+
},
|
| 2222 |
+
"module_type": "IncoherentLinear",
|
| 2223 |
+
"out_features": 8192,
|
| 2224 |
+
"rot_info": "skip_r",
|
| 2225 |
+
"scale": 32.0
|
| 2226 |
+
},
|
| 2227 |
+
"model.layers.6.self_attn.k_proj": {
|
| 2228 |
+
"bias": false,
|
| 2229 |
+
"dtype": "float32",
|
| 2230 |
+
"hadU": 2048,
|
| 2231 |
+
"hadV": 512,
|
| 2232 |
+
"in_features": 2048,
|
| 2233 |
+
"linear": {
|
| 2234 |
+
"KV": [
|
| 2235 |
+
8,
|
| 2236 |
+
9
|
| 2237 |
+
],
|
| 2238 |
+
"L": 16,
|
| 2239 |
+
"V": 2,
|
| 2240 |
+
"bias": false,
|
| 2241 |
+
"in_features": 2048,
|
| 2242 |
+
"in_part": [
|
| 2243 |
+
1024,
|
| 2244 |
+
1024
|
| 2245 |
+
],
|
| 2246 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2247 |
+
"linear_dtype": "float32",
|
| 2248 |
+
"out_features": 512,
|
| 2249 |
+
"td_x": 16,
|
| 2250 |
+
"td_y": 16,
|
| 2251 |
+
"tlut_bits": 10
|
| 2252 |
+
},
|
| 2253 |
+
"module_type": "IncoherentLinear",
|
| 2254 |
+
"out_features": 512,
|
| 2255 |
+
"rot_info": "skip_r",
|
| 2256 |
+
"scale": 32.0
|
| 2257 |
+
},
|
| 2258 |
+
"model.layers.6.self_attn.o_proj": {
|
| 2259 |
+
"bias": false,
|
| 2260 |
+
"dtype": "float32",
|
| 2261 |
+
"hadU": 2048,
|
| 2262 |
+
"hadV": 2048,
|
| 2263 |
+
"in_features": 2048,
|
| 2264 |
+
"linear": {
|
| 2265 |
+
"KV": 9,
|
| 2266 |
+
"L": 16,
|
| 2267 |
+
"V": 2,
|
| 2268 |
+
"bias": false,
|
| 2269 |
+
"in_features": 2048,
|
| 2270 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2271 |
+
"linear_dtype": "float32",
|
| 2272 |
+
"out_features": 2048,
|
| 2273 |
+
"td_x": 16,
|
| 2274 |
+
"td_y": 16,
|
| 2275 |
+
"tlut_bits": 10
|
| 2276 |
+
},
|
| 2277 |
+
"module_type": "IncoherentLinear",
|
| 2278 |
+
"out_features": 2048,
|
| 2279 |
+
"rot_info": "skip_r",
|
| 2280 |
+
"scale": 32.0
|
| 2281 |
+
},
|
| 2282 |
+
"model.layers.6.self_attn.q_proj": {
|
| 2283 |
+
"bias": false,
|
| 2284 |
+
"dtype": "float32",
|
| 2285 |
+
"hadU": 2048,
|
| 2286 |
+
"hadV": 2048,
|
| 2287 |
+
"in_features": 2048,
|
| 2288 |
+
"linear": {
|
| 2289 |
+
"KV": 7,
|
| 2290 |
+
"L": 16,
|
| 2291 |
+
"V": 2,
|
| 2292 |
+
"bias": false,
|
| 2293 |
+
"in_features": 2048,
|
| 2294 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2295 |
+
"linear_dtype": "float32",
|
| 2296 |
+
"out_features": 2048,
|
| 2297 |
+
"td_x": 16,
|
| 2298 |
+
"td_y": 16,
|
| 2299 |
+
"tlut_bits": 9
|
| 2300 |
+
},
|
| 2301 |
+
"module_type": "IncoherentLinear",
|
| 2302 |
+
"out_features": 2048,
|
| 2303 |
+
"rot_info": "skip_r",
|
| 2304 |
+
"scale": 32.0
|
| 2305 |
+
},
|
| 2306 |
+
"model.layers.6.self_attn.v_proj": {
|
| 2307 |
+
"bias": false,
|
| 2308 |
+
"dtype": "float32",
|
| 2309 |
+
"hadU": 2048,
|
| 2310 |
+
"hadV": 512,
|
| 2311 |
+
"in_features": 2048,
|
| 2312 |
+
"linear": {
|
| 2313 |
+
"KV": [
|
| 2314 |
+
9,
|
| 2315 |
+
10
|
| 2316 |
+
],
|
| 2317 |
+
"L": 16,
|
| 2318 |
+
"V": 2,
|
| 2319 |
+
"bias": false,
|
| 2320 |
+
"in_features": 2048,
|
| 2321 |
+
"in_part": [
|
| 2322 |
+
1024,
|
| 2323 |
+
1024
|
| 2324 |
+
],
|
| 2325 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2326 |
+
"linear_dtype": "float32",
|
| 2327 |
+
"out_features": 512,
|
| 2328 |
+
"td_x": 16,
|
| 2329 |
+
"td_y": 16,
|
| 2330 |
+
"tlut_bits": 11
|
| 2331 |
+
},
|
| 2332 |
+
"module_type": "IncoherentLinear",
|
| 2333 |
+
"out_features": 512,
|
| 2334 |
+
"rot_info": "skip_r",
|
| 2335 |
+
"scale": 32.0
|
| 2336 |
+
},
|
| 2337 |
+
"model.layers.7.mlp.down_proj": {
|
| 2338 |
+
"bias": false,
|
| 2339 |
+
"dtype": "float32",
|
| 2340 |
+
"hadU": 8192,
|
| 2341 |
+
"hadV": 2048,
|
| 2342 |
+
"in_features": 8192,
|
| 2343 |
+
"linear": {
|
| 2344 |
+
"KV": [
|
| 2345 |
+
6,
|
| 2346 |
+
7
|
| 2347 |
+
],
|
| 2348 |
+
"L": 16,
|
| 2349 |
+
"V": 2,
|
| 2350 |
+
"bias": false,
|
| 2351 |
+
"in_features": 8192,
|
| 2352 |
+
"in_part": [
|
| 2353 |
+
4096,
|
| 2354 |
+
4096
|
| 2355 |
+
],
|
| 2356 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2357 |
+
"linear_dtype": "float32",
|
| 2358 |
+
"out_features": 2048,
|
| 2359 |
+
"td_x": 16,
|
| 2360 |
+
"td_y": 16,
|
| 2361 |
+
"tlut_bits": 9
|
| 2362 |
+
},
|
| 2363 |
+
"module_type": "IncoherentLinear",
|
| 2364 |
+
"out_features": 2048,
|
| 2365 |
+
"rot_info": "skip_r",
|
| 2366 |
+
"scale": 32.0
|
| 2367 |
+
},
|
| 2368 |
+
"model.layers.7.mlp.gate_proj": {
|
| 2369 |
+
"bias": false,
|
| 2370 |
+
"dtype": "float32",
|
| 2371 |
+
"hadU": 2048,
|
| 2372 |
+
"hadV": 8192,
|
| 2373 |
+
"in_features": 2048,
|
| 2374 |
+
"linear": {
|
| 2375 |
+
"KV": 6,
|
| 2376 |
+
"L": 16,
|
| 2377 |
+
"V": 2,
|
| 2378 |
+
"bias": false,
|
| 2379 |
+
"in_features": 2048,
|
| 2380 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2381 |
+
"linear_dtype": "float32",
|
| 2382 |
+
"out_features": 8192,
|
| 2383 |
+
"td_x": 16,
|
| 2384 |
+
"td_y": 16,
|
| 2385 |
+
"tlut_bits": 9
|
| 2386 |
+
},
|
| 2387 |
+
"module_type": "IncoherentLinear",
|
| 2388 |
+
"out_features": 8192,
|
| 2389 |
+
"rot_info": "skip_r",
|
| 2390 |
+
"scale": 32.0
|
| 2391 |
+
},
|
| 2392 |
+
"model.layers.7.mlp.up_proj": {
|
| 2393 |
+
"bias": false,
|
| 2394 |
+
"dtype": "float32",
|
| 2395 |
+
"hadU": 2048,
|
| 2396 |
+
"hadV": 8192,
|
| 2397 |
+
"in_features": 2048,
|
| 2398 |
+
"linear": {
|
| 2399 |
+
"KV": 6,
|
| 2400 |
+
"L": 16,
|
| 2401 |
+
"V": 2,
|
| 2402 |
+
"bias": false,
|
| 2403 |
+
"in_features": 2048,
|
| 2404 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2405 |
+
"linear_dtype": "float32",
|
| 2406 |
+
"out_features": 8192,
|
| 2407 |
+
"td_x": 16,
|
| 2408 |
+
"td_y": 16,
|
| 2409 |
+
"tlut_bits": 9
|
| 2410 |
+
},
|
| 2411 |
+
"module_type": "IncoherentLinear",
|
| 2412 |
+
"out_features": 8192,
|
| 2413 |
+
"rot_info": "skip_r",
|
| 2414 |
+
"scale": 32.0
|
| 2415 |
+
},
|
| 2416 |
+
"model.layers.7.self_attn.k_proj": {
|
| 2417 |
+
"bias": false,
|
| 2418 |
+
"dtype": "float32",
|
| 2419 |
+
"hadU": 2048,
|
| 2420 |
+
"hadV": 512,
|
| 2421 |
+
"in_features": 2048,
|
| 2422 |
+
"linear": {
|
| 2423 |
+
"KV": [
|
| 2424 |
+
8,
|
| 2425 |
+
9
|
| 2426 |
+
],
|
| 2427 |
+
"L": 16,
|
| 2428 |
+
"V": 2,
|
| 2429 |
+
"bias": false,
|
| 2430 |
+
"in_features": 2048,
|
| 2431 |
+
"in_part": [
|
| 2432 |
+
1024,
|
| 2433 |
+
1024
|
| 2434 |
+
],
|
| 2435 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2436 |
+
"linear_dtype": "float32",
|
| 2437 |
+
"out_features": 512,
|
| 2438 |
+
"td_x": 16,
|
| 2439 |
+
"td_y": 16,
|
| 2440 |
+
"tlut_bits": 10
|
| 2441 |
+
},
|
| 2442 |
+
"module_type": "IncoherentLinear",
|
| 2443 |
+
"out_features": 512,
|
| 2444 |
+
"rot_info": "skip_r",
|
| 2445 |
+
"scale": 32.0
|
| 2446 |
+
},
|
| 2447 |
+
"model.layers.7.self_attn.o_proj": {
|
| 2448 |
+
"bias": false,
|
| 2449 |
+
"dtype": "float32",
|
| 2450 |
+
"hadU": 2048,
|
| 2451 |
+
"hadV": 2048,
|
| 2452 |
+
"in_features": 2048,
|
| 2453 |
+
"linear": {
|
| 2454 |
+
"KV": 9,
|
| 2455 |
+
"L": 16,
|
| 2456 |
+
"V": 2,
|
| 2457 |
+
"bias": false,
|
| 2458 |
+
"in_features": 2048,
|
| 2459 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2460 |
+
"linear_dtype": "float32",
|
| 2461 |
+
"out_features": 2048,
|
| 2462 |
+
"td_x": 16,
|
| 2463 |
+
"td_y": 16,
|
| 2464 |
+
"tlut_bits": 10
|
| 2465 |
+
},
|
| 2466 |
+
"module_type": "IncoherentLinear",
|
| 2467 |
+
"out_features": 2048,
|
| 2468 |
+
"rot_info": "skip_r",
|
| 2469 |
+
"scale": 32.0
|
| 2470 |
+
},
|
| 2471 |
+
"model.layers.7.self_attn.q_proj": {
|
| 2472 |
+
"bias": false,
|
| 2473 |
+
"dtype": "float32",
|
| 2474 |
+
"hadU": 2048,
|
| 2475 |
+
"hadV": 2048,
|
| 2476 |
+
"in_features": 2048,
|
| 2477 |
+
"linear": {
|
| 2478 |
+
"KV": 7,
|
| 2479 |
+
"L": 16,
|
| 2480 |
+
"V": 2,
|
| 2481 |
+
"bias": false,
|
| 2482 |
+
"in_features": 2048,
|
| 2483 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2484 |
+
"linear_dtype": "float32",
|
| 2485 |
+
"out_features": 2048,
|
| 2486 |
+
"td_x": 16,
|
| 2487 |
+
"td_y": 16,
|
| 2488 |
+
"tlut_bits": 9
|
| 2489 |
+
},
|
| 2490 |
+
"module_type": "IncoherentLinear",
|
| 2491 |
+
"out_features": 2048,
|
| 2492 |
+
"rot_info": "skip_r",
|
| 2493 |
+
"scale": 32.0
|
| 2494 |
+
},
|
| 2495 |
+
"model.layers.7.self_attn.v_proj": {
|
| 2496 |
+
"bias": false,
|
| 2497 |
+
"dtype": "float32",
|
| 2498 |
+
"hadU": 2048,
|
| 2499 |
+
"hadV": 512,
|
| 2500 |
+
"in_features": 2048,
|
| 2501 |
+
"linear": {
|
| 2502 |
+
"KV": 10,
|
| 2503 |
+
"L": 16,
|
| 2504 |
+
"V": 2,
|
| 2505 |
+
"bias": false,
|
| 2506 |
+
"in_features": 2048,
|
| 2507 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2508 |
+
"linear_dtype": "float32",
|
| 2509 |
+
"out_features": 512,
|
| 2510 |
+
"td_x": 16,
|
| 2511 |
+
"td_y": 16,
|
| 2512 |
+
"tlut_bits": 11
|
| 2513 |
+
},
|
| 2514 |
+
"module_type": "IncoherentLinear",
|
| 2515 |
+
"out_features": 512,
|
| 2516 |
+
"rot_info": "skip_r",
|
| 2517 |
+
"scale": 32.0
|
| 2518 |
+
},
|
| 2519 |
+
"model.layers.8.mlp.down_proj": {
|
| 2520 |
+
"bias": false,
|
| 2521 |
+
"dtype": "float32",
|
| 2522 |
+
"hadU": 8192,
|
| 2523 |
+
"hadV": 2048,
|
| 2524 |
+
"in_features": 8192,
|
| 2525 |
+
"linear": {
|
| 2526 |
+
"KV": 7,
|
| 2527 |
+
"L": 16,
|
| 2528 |
+
"V": 2,
|
| 2529 |
+
"bias": false,
|
| 2530 |
+
"in_features": 8192,
|
| 2531 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2532 |
+
"linear_dtype": "float32",
|
| 2533 |
+
"out_features": 2048,
|
| 2534 |
+
"td_x": 16,
|
| 2535 |
+
"td_y": 16,
|
| 2536 |
+
"tlut_bits": 9
|
| 2537 |
+
},
|
| 2538 |
+
"module_type": "IncoherentLinear",
|
| 2539 |
+
"out_features": 2048,
|
| 2540 |
+
"rot_info": "skip_r",
|
| 2541 |
+
"scale": 32.0
|
| 2542 |
+
},
|
| 2543 |
+
"model.layers.8.mlp.gate_proj": {
|
| 2544 |
+
"bias": false,
|
| 2545 |
+
"dtype": "float32",
|
| 2546 |
+
"hadU": 2048,
|
| 2547 |
+
"hadV": 8192,
|
| 2548 |
+
"in_features": 2048,
|
| 2549 |
+
"linear": {
|
| 2550 |
+
"KV": 6,
|
| 2551 |
+
"L": 16,
|
| 2552 |
+
"V": 2,
|
| 2553 |
+
"bias": false,
|
| 2554 |
+
"in_features": 2048,
|
| 2555 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2556 |
+
"linear_dtype": "float32",
|
| 2557 |
+
"out_features": 8192,
|
| 2558 |
+
"td_x": 16,
|
| 2559 |
+
"td_y": 16,
|
| 2560 |
+
"tlut_bits": 9
|
| 2561 |
+
},
|
| 2562 |
+
"module_type": "IncoherentLinear",
|
| 2563 |
+
"out_features": 8192,
|
| 2564 |
+
"rot_info": "skip_r",
|
| 2565 |
+
"scale": 32.0
|
| 2566 |
+
},
|
| 2567 |
+
"model.layers.8.mlp.up_proj": {
|
| 2568 |
+
"bias": false,
|
| 2569 |
+
"dtype": "float32",
|
| 2570 |
+
"hadU": 2048,
|
| 2571 |
+
"hadV": 8192,
|
| 2572 |
+
"in_features": 2048,
|
| 2573 |
+
"linear": {
|
| 2574 |
+
"KV": 7,
|
| 2575 |
+
"L": 16,
|
| 2576 |
+
"V": 2,
|
| 2577 |
+
"bias": false,
|
| 2578 |
+
"in_features": 2048,
|
| 2579 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2580 |
+
"linear_dtype": "float32",
|
| 2581 |
+
"out_features": 8192,
|
| 2582 |
+
"td_x": 16,
|
| 2583 |
+
"td_y": 16,
|
| 2584 |
+
"tlut_bits": 9
|
| 2585 |
+
},
|
| 2586 |
+
"module_type": "IncoherentLinear",
|
| 2587 |
+
"out_features": 8192,
|
| 2588 |
+
"rot_info": "skip_r",
|
| 2589 |
+
"scale": 32.0
|
| 2590 |
+
},
|
| 2591 |
+
"model.layers.8.self_attn.k_proj": {
|
| 2592 |
+
"bias": false,
|
| 2593 |
+
"dtype": "float32",
|
| 2594 |
+
"hadU": 2048,
|
| 2595 |
+
"hadV": 512,
|
| 2596 |
+
"in_features": 2048,
|
| 2597 |
+
"linear": {
|
| 2598 |
+
"KV": [
|
| 2599 |
+
8,
|
| 2600 |
+
9
|
| 2601 |
+
],
|
| 2602 |
+
"L": 16,
|
| 2603 |
+
"V": 2,
|
| 2604 |
+
"bias": false,
|
| 2605 |
+
"in_features": 2048,
|
| 2606 |
+
"in_part": [
|
| 2607 |
+
1024,
|
| 2608 |
+
1024
|
| 2609 |
+
],
|
| 2610 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2611 |
+
"linear_dtype": "float32",
|
| 2612 |
+
"out_features": 512,
|
| 2613 |
+
"td_x": 16,
|
| 2614 |
+
"td_y": 16,
|
| 2615 |
+
"tlut_bits": 10
|
| 2616 |
+
},
|
| 2617 |
+
"module_type": "IncoherentLinear",
|
| 2618 |
+
"out_features": 512,
|
| 2619 |
+
"rot_info": "skip_r",
|
| 2620 |
+
"scale": 32.0
|
| 2621 |
+
},
|
| 2622 |
+
"model.layers.8.self_attn.o_proj": {
|
| 2623 |
+
"bias": false,
|
| 2624 |
+
"dtype": "float32",
|
| 2625 |
+
"hadU": 2048,
|
| 2626 |
+
"hadV": 2048,
|
| 2627 |
+
"in_features": 2048,
|
| 2628 |
+
"linear": {
|
| 2629 |
+
"KV": 9,
|
| 2630 |
+
"L": 16,
|
| 2631 |
+
"V": 2,
|
| 2632 |
+
"bias": false,
|
| 2633 |
+
"in_features": 2048,
|
| 2634 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2635 |
+
"linear_dtype": "float32",
|
| 2636 |
+
"out_features": 2048,
|
| 2637 |
+
"td_x": 16,
|
| 2638 |
+
"td_y": 16,
|
| 2639 |
+
"tlut_bits": 10
|
| 2640 |
+
},
|
| 2641 |
+
"module_type": "IncoherentLinear",
|
| 2642 |
+
"out_features": 2048,
|
| 2643 |
+
"rot_info": "skip_r",
|
| 2644 |
+
"scale": 32.0
|
| 2645 |
+
},
|
| 2646 |
+
"model.layers.8.self_attn.q_proj": {
|
| 2647 |
+
"bias": false,
|
| 2648 |
+
"dtype": "float32",
|
| 2649 |
+
"hadU": 2048,
|
| 2650 |
+
"hadV": 2048,
|
| 2651 |
+
"in_features": 2048,
|
| 2652 |
+
"linear": {
|
| 2653 |
+
"KV": 7,
|
| 2654 |
+
"L": 16,
|
| 2655 |
+
"V": 2,
|
| 2656 |
+
"bias": false,
|
| 2657 |
+
"in_features": 2048,
|
| 2658 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2659 |
+
"linear_dtype": "float32",
|
| 2660 |
+
"out_features": 2048,
|
| 2661 |
+
"td_x": 16,
|
| 2662 |
+
"td_y": 16,
|
| 2663 |
+
"tlut_bits": 9
|
| 2664 |
+
},
|
| 2665 |
+
"module_type": "IncoherentLinear",
|
| 2666 |
+
"out_features": 2048,
|
| 2667 |
+
"rot_info": "skip_r",
|
| 2668 |
+
"scale": 32.0
|
| 2669 |
+
},
|
| 2670 |
+
"model.layers.8.self_attn.v_proj": {
|
| 2671 |
+
"bias": false,
|
| 2672 |
+
"dtype": "float32",
|
| 2673 |
+
"hadU": 2048,
|
| 2674 |
+
"hadV": 512,
|
| 2675 |
+
"in_features": 2048,
|
| 2676 |
+
"linear": {
|
| 2677 |
+
"KV": 10,
|
| 2678 |
+
"L": 16,
|
| 2679 |
+
"V": 2,
|
| 2680 |
+
"bias": false,
|
| 2681 |
+
"in_features": 2048,
|
| 2682 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2683 |
+
"linear_dtype": "float32",
|
| 2684 |
+
"out_features": 512,
|
| 2685 |
+
"td_x": 16,
|
| 2686 |
+
"td_y": 16,
|
| 2687 |
+
"tlut_bits": 11
|
| 2688 |
+
},
|
| 2689 |
+
"module_type": "IncoherentLinear",
|
| 2690 |
+
"out_features": 512,
|
| 2691 |
+
"rot_info": "skip_r",
|
| 2692 |
+
"scale": 32.0
|
| 2693 |
+
},
|
| 2694 |
+
"model.layers.9.mlp.down_proj": {
|
| 2695 |
+
"bias": false,
|
| 2696 |
+
"dtype": "float32",
|
| 2697 |
+
"hadU": 8192,
|
| 2698 |
+
"hadV": 2048,
|
| 2699 |
+
"in_features": 8192,
|
| 2700 |
+
"linear": {
|
| 2701 |
+
"KV": 7,
|
| 2702 |
+
"L": 16,
|
| 2703 |
+
"V": 2,
|
| 2704 |
+
"bias": false,
|
| 2705 |
+
"in_features": 8192,
|
| 2706 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2707 |
+
"linear_dtype": "float32",
|
| 2708 |
+
"out_features": 2048,
|
| 2709 |
+
"td_x": 16,
|
| 2710 |
+
"td_y": 16,
|
| 2711 |
+
"tlut_bits": 9
|
| 2712 |
+
},
|
| 2713 |
+
"module_type": "IncoherentLinear",
|
| 2714 |
+
"out_features": 2048,
|
| 2715 |
+
"rot_info": "skip_r",
|
| 2716 |
+
"scale": 32.0
|
| 2717 |
+
},
|
| 2718 |
+
"model.layers.9.mlp.gate_proj": {
|
| 2719 |
+
"bias": false,
|
| 2720 |
+
"dtype": "float32",
|
| 2721 |
+
"hadU": 2048,
|
| 2722 |
+
"hadV": 8192,
|
| 2723 |
+
"in_features": 2048,
|
| 2724 |
+
"linear": {
|
| 2725 |
+
"KV": 6,
|
| 2726 |
+
"L": 16,
|
| 2727 |
+
"V": 2,
|
| 2728 |
+
"bias": false,
|
| 2729 |
+
"in_features": 2048,
|
| 2730 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2731 |
+
"linear_dtype": "float32",
|
| 2732 |
+
"out_features": 8192,
|
| 2733 |
+
"td_x": 16,
|
| 2734 |
+
"td_y": 16,
|
| 2735 |
+
"tlut_bits": 9
|
| 2736 |
+
},
|
| 2737 |
+
"module_type": "IncoherentLinear",
|
| 2738 |
+
"out_features": 8192,
|
| 2739 |
+
"rot_info": "skip_r",
|
| 2740 |
+
"scale": 32.0
|
| 2741 |
+
},
|
| 2742 |
+
"model.layers.9.mlp.up_proj": {
|
| 2743 |
+
"bias": false,
|
| 2744 |
+
"dtype": "float32",
|
| 2745 |
+
"hadU": 2048,
|
| 2746 |
+
"hadV": 8192,
|
| 2747 |
+
"in_features": 2048,
|
| 2748 |
+
"linear": {
|
| 2749 |
+
"KV": 7,
|
| 2750 |
+
"L": 16,
|
| 2751 |
+
"V": 2,
|
| 2752 |
+
"bias": false,
|
| 2753 |
+
"in_features": 2048,
|
| 2754 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2755 |
+
"linear_dtype": "float32",
|
| 2756 |
+
"out_features": 8192,
|
| 2757 |
+
"td_x": 16,
|
| 2758 |
+
"td_y": 16,
|
| 2759 |
+
"tlut_bits": 9
|
| 2760 |
+
},
|
| 2761 |
+
"module_type": "IncoherentLinear",
|
| 2762 |
+
"out_features": 8192,
|
| 2763 |
+
"rot_info": "skip_r",
|
| 2764 |
+
"scale": 32.0
|
| 2765 |
+
},
|
| 2766 |
+
"model.layers.9.self_attn.k_proj": {
|
| 2767 |
+
"bias": false,
|
| 2768 |
+
"dtype": "float32",
|
| 2769 |
+
"hadU": 2048,
|
| 2770 |
+
"hadV": 512,
|
| 2771 |
+
"in_features": 2048,
|
| 2772 |
+
"linear": {
|
| 2773 |
+
"KV": 9,
|
| 2774 |
+
"L": 16,
|
| 2775 |
+
"V": 2,
|
| 2776 |
+
"bias": false,
|
| 2777 |
+
"in_features": 2048,
|
| 2778 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2779 |
+
"linear_dtype": "float32",
|
| 2780 |
+
"out_features": 512,
|
| 2781 |
+
"td_x": 16,
|
| 2782 |
+
"td_y": 16,
|
| 2783 |
+
"tlut_bits": 10
|
| 2784 |
+
},
|
| 2785 |
+
"module_type": "IncoherentLinear",
|
| 2786 |
+
"out_features": 512,
|
| 2787 |
+
"rot_info": "skip_r",
|
| 2788 |
+
"scale": 32.0
|
| 2789 |
+
},
|
| 2790 |
+
"model.layers.9.self_attn.o_proj": {
|
| 2791 |
+
"bias": false,
|
| 2792 |
+
"dtype": "float32",
|
| 2793 |
+
"hadU": 2048,
|
| 2794 |
+
"hadV": 2048,
|
| 2795 |
+
"in_features": 2048,
|
| 2796 |
+
"linear": {
|
| 2797 |
+
"KV": [
|
| 2798 |
+
8,
|
| 2799 |
+
9
|
| 2800 |
+
],
|
| 2801 |
+
"L": 16,
|
| 2802 |
+
"V": 2,
|
| 2803 |
+
"bias": false,
|
| 2804 |
+
"in_features": 2048,
|
| 2805 |
+
"in_part": [
|
| 2806 |
+
1024,
|
| 2807 |
+
1024
|
| 2808 |
+
],
|
| 2809 |
+
"linear_cls": "CombtLinearTCQ",
|
| 2810 |
+
"linear_dtype": "float32",
|
| 2811 |
+
"out_features": 2048,
|
| 2812 |
+
"td_x": 16,
|
| 2813 |
+
"td_y": 16,
|
| 2814 |
+
"tlut_bits": 10
|
| 2815 |
+
},
|
| 2816 |
+
"module_type": "IncoherentLinear",
|
| 2817 |
+
"out_features": 2048,
|
| 2818 |
+
"rot_info": "skip_r",
|
| 2819 |
+
"scale": 32.0
|
| 2820 |
+
},
|
| 2821 |
+
"model.layers.9.self_attn.q_proj": {
|
| 2822 |
+
"bias": false,
|
| 2823 |
+
"dtype": "float32",
|
| 2824 |
+
"hadU": 2048,
|
| 2825 |
+
"hadV": 2048,
|
| 2826 |
+
"in_features": 2048,
|
| 2827 |
+
"linear": {
|
| 2828 |
+
"KV": 7,
|
| 2829 |
+
"L": 16,
|
| 2830 |
+
"V": 2,
|
| 2831 |
+
"bias": false,
|
| 2832 |
+
"in_features": 2048,
|
| 2833 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2834 |
+
"linear_dtype": "float32",
|
| 2835 |
+
"out_features": 2048,
|
| 2836 |
+
"td_x": 16,
|
| 2837 |
+
"td_y": 16,
|
| 2838 |
+
"tlut_bits": 9
|
| 2839 |
+
},
|
| 2840 |
+
"module_type": "IncoherentLinear",
|
| 2841 |
+
"out_features": 2048,
|
| 2842 |
+
"rot_info": "skip_r",
|
| 2843 |
+
"scale": 32.0
|
| 2844 |
+
},
|
| 2845 |
+
"model.layers.9.self_attn.v_proj": {
|
| 2846 |
+
"bias": false,
|
| 2847 |
+
"dtype": "float32",
|
| 2848 |
+
"hadU": 2048,
|
| 2849 |
+
"hadV": 512,
|
| 2850 |
+
"in_features": 2048,
|
| 2851 |
+
"linear": {
|
| 2852 |
+
"KV": 10,
|
| 2853 |
+
"L": 16,
|
| 2854 |
+
"V": 2,
|
| 2855 |
+
"bias": false,
|
| 2856 |
+
"in_features": 2048,
|
| 2857 |
+
"linear_cls": "QTIPLinearTCQ",
|
| 2858 |
+
"linear_dtype": "float32",
|
| 2859 |
+
"out_features": 512,
|
| 2860 |
+
"td_x": 16,
|
| 2861 |
+
"td_y": 16,
|
| 2862 |
+
"tlut_bits": 11
|
| 2863 |
+
},
|
| 2864 |
+
"module_type": "IncoherentLinear",
|
| 2865 |
+
"out_features": 512,
|
| 2866 |
+
"rot_info": "skip_r",
|
| 2867 |
+
"scale": 32.0
|
| 2868 |
+
}
|
| 2869 |
+
}
|
| 2870 |
+
},
|
| 2871 |
+
"rms_norm_eps": 1e-05,
|
| 2872 |
+
"rope_scaling": {
|
| 2873 |
+
"factor": 32.0,
|
| 2874 |
+
"high_freq_factor": 4.0,
|
| 2875 |
+
"low_freq_factor": 1.0,
|
| 2876 |
+
"original_max_position_embeddings": 8192,
|
| 2877 |
+
"rope_type": "llama3"
|
| 2878 |
+
},
|
| 2879 |
+
"rope_theta": 500000.0,
|
| 2880 |
+
"tie_word_embeddings": true,
|
| 2881 |
+
"torch_dtype": "float16",
|
| 2882 |
+
"transformers_version": "4.45.2",
|
| 2883 |
+
"use_cache": true,
|
| 2884 |
+
"vocab_size": 128256
|
| 2885 |
+
}
|
generation_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 128000,
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": 128001,
|
| 6 |
+
"temperature": 0.6,
|
| 7 |
+
"top_p": 0.9,
|
| 8 |
+
"transformers_version": "4.45.2"
|
| 9 |
+
}
|
lib/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
lib/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (158 Bytes). View file
|
|
|
lib/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (318 Bytes). View file
|
|
|
lib/algo/__init__.py
ADDED
|
File without changes
|
lib/algo/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (163 Bytes). View file
|
|
|
lib/algo/__pycache__/ldlq.cpython-311.pyc
ADDED
|
Binary file (13.3 kB). View file
|
|
|
lib/algo/ldlq.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import glog
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import time
|
| 8 |
+
from lib import utils
|
| 9 |
+
|
| 10 |
+
# Interleaving order used to lay a 16x16 trellis tile out in the order the
# inference kernel consumes it: enumerate 0..255 as a (2, 8, 2, 4, 2) grid,
# transpose to (1, 3, 2, 0, 4), and read it off flat.
_PERMUTE = (
    torch.arange(256)
    .reshape(2, 8, 2, 4, 2)
    .permute(1, 3, 2, 0, 4)
    .reshape(-1)
)
# Inverse mapping: _INV_PERMUTE[_PERMUTE] == arange(256). argsort of a
# permutation of 0..N-1 is exactly its inverse (int64, like the original).
_INV_PERMUTE = torch.argsort(_PERMUTE)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def LDLQ_VQ(Wr, L, cb, buf_cols=128):
    """Vector-quantize the (rotated) weight matrix Wr with error feedback.

    Columns are processed right-to-left in buffers of `buf_cols` columns;
    before quantizing each vector-sized group of columns, the accumulated
    quantization error of the already-processed (righter) columns is fed
    back through the corresponding rows of `L` (LDLQ-style rounding).

    Args:
        Wr: (m, n) weight matrix to quantize.
        L: (n, n) feedback matrix; only its strictly-lower rows relative to
           the current column group are read. Also fixes output dtype/device.
        cb: vector codebook exposing `vec_sz`, `idx_dtype`, and
            `quantize(x) -> (dequantized, indices)`.
        buf_cols: number of columns buffered per outer step (clamped up to
            at least `cb.vec_sz`).

    Returns:
        (hatWr, Qidxs): the (m, n) dequantized approximation and the
        (m, n // cb.vec_sz) codebook indices.
    """
    buf_cols = max(buf_cols, cb.vec_sz)
    (m, n) = Wr.shape
    assert buf_cols % cb.vec_sz == 0
    assert n % buf_cols == 0
    buf_size = buf_cols // cb.vec_sz

    # Work on transposed (n, m) buffers so column groups are contiguous rows.
    hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
    Qidxs_T = torch.zeros(n // cb.vec_sz, m, dtype=cb.idx_dtype, device=L.device)

    # Free the device copy of Wr before allocating its transpose to keep
    # peak memory down.
    device = Wr.device
    Wr = Wr.cpu()
    utils.clean()
    Wr_T = Wr.T.contiguous().to(device)

    # prod_cache accumulates L^T @ (Wr - hatWr) contributions from buffers
    # already finished, so the inner loop only handles intra-buffer feedback.
    prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)
    for cur_col in tqdm(range(n // cb.vec_sz, 0, -buf_size)):
        # Views into the current buffer of columns; writes to b_hatWr_T
        # alias hatWr_T.
        b_Wr_T = Wr_T[cb.vec_sz * (cur_col - buf_size):cb.vec_sz * cur_col]
        b_hatWr_T = hatWr_T[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
                            cur_col]
        b_L = L[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
                cur_col].contiguous()
        b_prod = prod_cache[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
                            cur_col]
        b_Qidxs_T = Qidxs_T[(cur_col - buf_size):cur_col]
        L_offset = cb.vec_sz * (cur_col - buf_size)
        for i in reversed(range(buf_size)):
            # Target = original columns + in-buffer feedback of the error of
            # columns to the right + cached feedback from previous buffers.
            WXWX = b_Wr_T[cb.vec_sz * i : cb.vec_sz * (i + 1)] + \
                b_L[cb.vec_sz * (i + 1):, L_offset + cb.vec_sz * i : L_offset + cb.vec_sz * (i + 1)].T @ \
                (b_Wr_T[cb.vec_sz * (i + 1):] - b_hatWr_T[cb.vec_sz * (i + 1):]) + \
                b_prod[cb.vec_sz * i : cb.vec_sz * (i + 1)]

            # q_out = (dequantized values, codebook indices).
            q_out = cb.quantize(WXWX.T)
            b_hatWr_T[cb.vec_sz * i:cb.vec_sz * (i + 1)] = q_out[0].T
            b_Qidxs_T[i:(i + 1)] = q_out[1].T

        # Fold this buffer's residual into the cache for columns to the left.
        prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
        hatWr_T[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
                cur_col] = b_hatWr_T

    del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
    utils.clean()
    return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def LDLQ(Wr, L, cb, args, buf_cols=128, for_kernel=True):
    """LDLQ trellis quantization of Wr with error feedback through L.

    Like LDLQ_VQ but quantizes `td_y`-column groups reshaped into
    (td_x * td_y)-element trellis tiles, optionally permuted into the
    inference kernel's layout via _PERMUTE / _INV_PERMUTE.

    Args:
        Wr: (m, n) weight matrix to quantize.
        L: (n, n) feedback matrix (dtype/device also used for outputs).
        cb: trellis codebook exposing `idx_dtype` and
            `quantize(tiles) -> (dequantized, indices)`.
        args: namespace carrying `td_x`, `td_y` (tile dims) and `V`
            (values per code index).
        buf_cols: columns buffered per outer step (clamped up to td_y).
        for_kernel: if True, require 16x16 tiles and permute each tile into
            kernel order before quantizing (and back after).

    Returns:
        (hatWr, Qidxs): (m, n) dequantized weights and (m, n // V) indices.
    """
    if for_kernel:
        # _PERMUTE is hard-coded for 256-element (16x16) tiles.
        assert args.td_x == 16 and args.td_y == 16
    buf_cols = max(buf_cols, args.td_y)
    trellissz = args.td_x * args.td_y
    (m, n) = Wr.shape
    assert buf_cols % args.td_y == 0
    assert n % buf_cols == 0
    assert args.td_y % args.V == 0
    buf_size = buf_cols // args.td_y

    # Transposed (n, m) buffers so column groups are contiguous rows.
    hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
    Qidxs_T = torch.zeros(n // args.V, m, dtype=cb.idx_dtype, device=L.device)

    # Free the device copy of Wr before allocating its transpose.
    device = Wr.device
    Wr = Wr.cpu()
    utils.clean()
    Wr_T = Wr.T.contiguous().to(device)

    # Feedback cache of L^T @ (Wr - hatWr) from already-finished buffers
    # (QuIP/LDLQ-style rounding).
    prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)
    for cur_col in tqdm(range(n // args.td_y, 0, -buf_size)):
        b_Wr_T = Wr_T[args.td_y * (cur_col - buf_size):args.td_y * cur_col]
        b_hatWr_T = hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
                            cur_col]
        b_L = L[args.td_y * (cur_col - buf_size):args.td_y *
                cur_col].contiguous()
        b_prod = prod_cache[args.td_y * (cur_col - buf_size):args.td_y *
                            cur_col]
        b_Qidxs_T = Qidxs_T[args.td_y * (cur_col - buf_size) //
                            args.V:args.td_y * cur_col // args.V]
        L_offset = args.td_y * (cur_col - buf_size)
        for i in reversed(range(buf_size)):
            # Target = original columns + in-buffer feedback from columns to
            # the right + cached feedback from previous buffers.
            WXWX = b_Wr_T[args.td_y * i : args.td_y * (i + 1)] + \
                b_L[args.td_y * (i + 1):, L_offset + args.td_y * i : L_offset + args.td_y * (i + 1)].T @ \
                (b_Wr_T[args.td_y * (i + 1):] - b_hatWr_T[args.td_y * (i + 1):]) + \
                b_prod[args.td_y * i : args.td_y * (i + 1)]
            if trellissz > -1:  # always true here (trellissz = td_x * td_y > 0)
                WXWXshape = WXWX.shape
                # Reshape to one trellis tile per row.
                thing = WXWX.T.reshape(-1, trellissz)
                if for_kernel:
                    thing = thing[..., _PERMUTE]
                q_out = cb.quantize(thing)
                if for_kernel:
                    # Undo the kernel permutation on the dequantized values;
                    # indices stay in kernel order.
                    thing = q_out[0][..., _INV_PERMUTE].reshape(
                        WXWXshape[1], WXWXshape[0])
                else:
                    thing = q_out[0].reshape(WXWXshape[1], WXWXshape[0])
                idxs = q_out[1].reshape(WXWXshape[1], WXWXshape[0] // args.V)
                b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = thing.T
                b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
                          (i + 1)] = idxs.T
            else:
                # Untiled fallback: quantize the raw column group.
                q_out = cb.quantize(WXWX.T)
                b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = q_out[0].T
                b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
                          (i + 1)] = q_out[1].T

        # Fold this buffer's residual into the cache for columns to the left.
        prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
        hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
                cur_col] = b_hatWr_T

    del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
    utils.clean()
    return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def LDLQ_combt(Wr, L, cb1, cb2, args, buf_cols=128, for_kernel=True):
    """Two-codebook LDLQ: right half of the columns uses cb2, left half cb1.

    Identical to LDLQ except that column groups whose start index is in the
    upper half of [0, n) are quantized with `cb2` and the rest with `cb1`
    (this matches the CombtLinearTCQ split-bitrate layout). Index dtype is
    taken from cb1; both codebooks are assumed to share it — TODO confirm.

    Args/returns: see LDLQ; `cb1`/`cb2` are the low/high-column codebooks.
    """
    if for_kernel:
        # _PERMUTE is hard-coded for 256-element (16x16) tiles.
        assert args.td_x == 16 and args.td_y == 16
    buf_cols = max(buf_cols, args.td_y)
    trellissz = args.td_x * args.td_y
    (m, n) = Wr.shape
    assert buf_cols % args.td_y == 0
    assert n % buf_cols == 0
    assert args.td_y % args.V == 0
    buf_size = buf_cols // args.td_y

    # Transposed (n, m) buffers so column groups are contiguous rows.
    hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
    Qidxs_T = torch.zeros(n // args.V, m, dtype=cb1.idx_dtype, device=L.device)

    # Free the device copy of Wr before allocating its transpose.
    device = Wr.device
    Wr = Wr.cpu()
    utils.clean()
    Wr_T = Wr.T.contiguous().to(device)

    # Feedback cache of L^T @ (Wr - hatWr) from finished buffers (QuIP/LDLQ).
    prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)

    # Reset torch.compile caches once when switching from cb2 to cb1 —
    # presumably the two codebooks' compiled quantize graphs clash; verify.
    flag_for_cb1_compile = True
    for cur_col in tqdm(range(n // args.td_y, 0, -buf_size)):
        b_Wr_T = Wr_T[args.td_y * (cur_col - buf_size):args.td_y * cur_col]
        b_hatWr_T = hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
                            cur_col]
        b_L = L[args.td_y * (cur_col - buf_size):args.td_y *
                cur_col].contiguous()
        b_prod = prod_cache[args.td_y * (cur_col - buf_size):args.td_y *
                            cur_col]
        b_Qidxs_T = Qidxs_T[args.td_y * (cur_col - buf_size) //
                            args.V:args.td_y * cur_col // args.V]
        L_offset = args.td_y * (cur_col - buf_size)
        for i in reversed(range(buf_size)):
            # Target = original columns + in-buffer feedback + cached feedback.
            WXWX = b_Wr_T[args.td_y * i : args.td_y * (i + 1)] + \
                b_L[args.td_y * (i + 1):, L_offset + args.td_y * i : L_offset + args.td_y * (i + 1)].T @ \
                (b_Wr_T[args.td_y * (i + 1):] - b_hatWr_T[args.td_y * (i + 1):]) + \
                b_prod[args.td_y * i : args.td_y * (i + 1)]
            if trellissz > -1:  # always true here (trellissz = td_x * td_y > 0)
                WXWXshape = WXWX.shape
                # One trellis tile per row, optionally in kernel order.
                thing = WXWX.T.reshape(-1, trellissz)
                if for_kernel:
                    thing = thing[..., _PERMUTE]
                # Columns in the upper half of [0, n) -> cb2, else cb1.
                if args.td_y * (cur_col - buf_size) >= n // 2:
                    q_out = cb2.quantize(thing)
                else:
                    if flag_for_cb1_compile:
                        torch._dynamo.reset()
                        flag_for_cb1_compile = False
                    q_out = cb1.quantize(thing)
                if for_kernel:
                    # Undo kernel permutation on dequantized values only.
                    thing = q_out[0][..., _INV_PERMUTE].reshape(
                        WXWXshape[1], WXWXshape[0])
                else:
                    thing = q_out[0].reshape(WXWXshape[1], WXWXshape[0])
                idxs = q_out[1].reshape(WXWXshape[1], WXWXshape[0] // args.V)
                b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = thing.T
                b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
                          (i + 1)] = idxs.T
            else:
                # Untiled fallback with the same half-split codebook choice.
                if args.td_y * (cur_col - buf_size) >= n // 2:
                    q_out = cb2.quantize(WXWX.T)
                else:
                    q_out = cb1.quantize(WXWX.T)
                b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = q_out[0].T
                b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
                          (i + 1)] = q_out[1].T

        # Fold this buffer's residual into the cache for columns to the left.
        prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
        hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
                cur_col] = b_hatWr_T

    del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
    utils.clean()
    return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
|
lib/algo/ldlq_beam_cd.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import glog
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import time
|
| 8 |
+
from lib import utils
|
| 9 |
+
|
| 10 |
+
# Kernel-layout interleaving for a full 256-element (16x16) trellis tile:
# enumerate 0..255 as a (2, 8, 2, 4, 2) grid, transpose, and read it flat.
_PERMUTE = (
    torch.arange(256)
    .reshape(2, 8, 2, 4, 2)
    .permute(1, 3, 2, 0, 4)
    .reshape(-1)
)

# Same construction for a 128-element (half) tile.
_PERMUTE_HALF = (
    torch.arange(128)
    .reshape(2, 8, 2, 4, 1)
    .permute(1, 3, 2, 0, 4)
    .reshape(-1)
)
# Inverses: argsort of a permutation of 0..N-1 is its inverse (int64),
# so _INV_PERMUTE[_PERMUTE] == arange(256) and likewise for the half tile.
_INV_PERMUTE = torch.argsort(_PERMUTE)
_INV_PERMUTE_HALF = torch.argsort(_PERMUTE_HALF)
|
| 19 |
+
|
| 20 |
+
def LDLQ(Wr, L, cb, args, D=None, buf_cols=128, for_kernel=True, use_beam_search=False, use_diag=False):
    """LDLQ trellis quantization with optional beam search / diagonal weights.

    Same error-feedback structure as lib.algo.ldlq.LDLQ, with two extra
    per-tile quantization modes driven by `D`:
      * use_beam_search: quantize each tile with
        `cb.quantize_beam_search_with_hessian` against a kron-tiled copy of
        the tile's D block.
      * use_diag: pass diag(D block) as per-element weights `w2` to
        `cb.quantize`.
    Only the tiled path is implemented; the untiled branch raises.

    Args:
        Wr: (m, n) weight matrix to quantize.
        L: (n, n) feedback matrix (dtype/device also used for outputs).
        cb: trellis codebook (`idx_dtype`, `quantize`, and, when
            use_beam_search, `quantize_beam_search_with_hessian`).
        args: namespace carrying `td_x`, `td_y`, `V`.
        D: per-tile (td_y, td_y) blocks indexed by column group —
           expected shape (n // td_y, td_y, td_y) per the comment below;
           required when use_beam_search or use_diag. TODO confirm shape.
        buf_cols: columns buffered per outer step (clamped up to td_y).
        for_kernel: permute tiles into kernel order via _PERMUTE.
        use_beam_search / use_diag: select the quantization mode (mutually
            exclusive in effect; beam search takes precedence).

    Returns:
        (hatWr, Qidxs): (m, n) dequantized weights and (m, n // V) indices.
    """
    if for_kernel:
        # _PERMUTE is hard-coded for 256-element (16x16) tiles.
        assert args.td_x == 16 and args.td_y == 16
    buf_cols = max(buf_cols, args.td_y)
    trellissz = args.td_x * args.td_y
    (m, n) = Wr.shape
    assert buf_cols % args.td_y == 0
    assert n % buf_cols == 0
    assert args.td_y % args.V == 0
    buf_size = buf_cols // args.td_y

    # Transposed (n, m) buffers so column groups are contiguous rows.
    hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
    Qidxs_T = torch.zeros(n // args.V, m, dtype=cb.idx_dtype, device=L.device)

    # Free the device copy of Wr before allocating its transpose.
    device = Wr.device
    Wr = Wr.cpu()
    utils.clean()
    Wr_T = Wr.T.contiguous().to(device)

    # Feedback cache of L^T @ (Wr - hatWr) from finished buffers (QuIP/LDLQ).
    prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)
    for cur_col in tqdm(range(n // args.td_y, 0, -buf_size)):
        b_Wr_T = Wr_T[args.td_y * (cur_col - buf_size):args.td_y * cur_col]
        b_hatWr_T = hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
                            cur_col]
        b_L = L[args.td_y * (cur_col - buf_size):args.td_y *
                cur_col].contiguous()
        b_prod = prod_cache[args.td_y * (cur_col - buf_size):args.td_y *
                            cur_col]
        b_Qidxs_T = Qidxs_T[args.td_y * (cur_col - buf_size) //
                            args.V:args.td_y * cur_col // args.V]
        L_offset = args.td_y * (cur_col - buf_size)
        for i in reversed(range(buf_size)):
            # Target = original columns + in-buffer feedback + cached feedback.
            WXWX = b_Wr_T[args.td_y * i : args.td_y * (i + 1)] + \
                b_L[args.td_y * (i + 1):, L_offset + args.td_y * i : L_offset + args.td_y * (i + 1)].T @ \
                (b_Wr_T[args.td_y * (i + 1):] - b_hatWr_T[args.td_y * (i + 1):]) + \
                b_prod[args.td_y * i : args.td_y * (i + 1)]
            if trellissz > -1:  # always true here (trellissz = td_x * td_y > 0)
                WXWXshape = WXWX.shape
                # One trellis tile per row, optionally in kernel order.
                thing = WXWX.T.reshape(-1, trellissz)
                if for_kernel:
                    thing = thing[..., _PERMUTE]
                if use_beam_search:
                    # D: (n // td_y, td_y, td_y)
                    D_cur = D[cur_col - buf_size + i]  # (td_y, td_y)
                    # Tile the block-diagonal Hessian to tile size and apply
                    # the same permutation as the data.
                    D_tiled = torch.kron(torch.eye(args.td_y, device=D_cur.device, dtype=D_cur.dtype), D_cur)
                    if for_kernel:
                        D_tiled = D_tiled[:, _PERMUTE][_PERMUTE, :]
                    q_out = cb.quantize_beam_search_with_hessian(thing, D_tiled, beam_sz=1024)
                else:
                    if use_diag:
                        D_cur = D[cur_col - buf_size + i]  # (td_y, td_y)
                        # Per-element weights = repeated diagonal of D_cur,
                        # permuted to kernel order (note: applied regardless
                        # of for_kernel — presumably always kernel layout
                        # here; verify).
                        weight = torch.diag(D_cur).repeat(trellissz // args.td_y)[_PERMUTE]
                        q_out = cb.quantize(thing, w2=weight)
                    else:
                        q_out = cb.quantize(thing)
                if for_kernel:
                    # Undo kernel permutation on dequantized values only.
                    thing = q_out[0][..., _INV_PERMUTE].reshape(
                        WXWXshape[1], WXWXshape[0])
                else:
                    thing = q_out[0].reshape(WXWXshape[1], WXWXshape[0])
                idxs = q_out[1].reshape(WXWXshape[1], WXWXshape[0] // args.V)
                b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = thing.T
                b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
                          (i + 1)] = idxs.T
            else:
                # Untiled path is not supported in this variant.
                raise NotImplementedError
                # q_out = cb.quantize(WXWX.T)
                # b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = q_out[0].T
                # b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
                #           (i + 1)] = q_out[1].T

        # Fold this buffer's residual into the cache for columns to the left.
        prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
        hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
                cur_col] = b_hatWr_T

    del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
    utils.clean()
    return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
|
| 99 |
+
|
| 100 |
+
def calc_obj(hatWr_T, Wr_T, HRr):
    """Return the proxy loss tr((hatWr_T - Wr_T)^T @ HRr @ (hatWr_T - Wr_T)).

    Args:
        hatWr_T: (n, m) quantized weights, transposed.
        Wr_T: (n, m) original weights, transposed.
        HRr: (n, n) proxy Hessian.

    Returns:
        The scalar objective as a Python float.
    """
    if torch.cuda.is_available():
        # Original behavior: evaluate on GPU (HRr is assumed to already be
        # on the GPU when this path is taken).
        diff_T = hatWr_T.cuda() - Wr_T.cuda()
    else:
        # CPU fallback so the objective can be computed on GPU-less hosts.
        diff_T = hatWr_T - Wr_T
    # tr(D^T H D) == sum(D * (H D)); avoids materializing the full (m, m)
    # product just to read its diagonal.
    obj = torch.sum(diff_T * (HRr @ diff_T))
    return obj.cpu().item()
|
| 104 |
+
|
| 105 |
+
def CD(Wr, HRr, Qidxs, hatWr, cb, args, buf_cols=128, for_kernel=True, use_beam_search=False, use_diag=False) if False else None  # noqa
|
lib/codebook/__pycache__/bitshift.cpython-311.pyc
ADDED
|
Binary file (30.8 kB). View file
|
|
|
lib/codebook/__pycache__/vq_codebook.cpython-311.pyc
ADDED
|
Binary file (3.9 kB). View file
|
|
|
lib/codebook/bitshift.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from functools import cache
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from lib.utils.kernel_check import has_kernel
|
| 12 |
+
from lib.utils.kernel_decompress import decode_compressed, bitshift_linear_kernel
|
| 13 |
+
from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
def decode_1mad(x):
    """Map integer codebook states to pseudo-Gaussian floats.

    One multiply-add (LCG-style mix) scrambles the 32-bit state; the sum of
    the four result bytes is approximately normal by the CLT. The constants
    510 (mean of four uniform bytes) and 147.800537109375 (empirical std)
    center and scale the output toward a standard normal.
    """
    state = x.to(torch.int64) & ((1 << 32) - 1)
    state = (state * 34038481 + 76625530) & ((1 << 32) - 1)
    # Sum the four bytes of the mixed state.
    byte_sum = ((state & 255) + ((state >> 8) & 255) +
                ((state >> 16) & 255) + ((state >> 24) & 255))
    centered = (byte_sum - 510).to(torch.float32)
    return centered / 147.800537109375
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def decode_2mad(x):
    """Map integer codebook states to pseudo-Gaussian floats (two-step mix).

    Like decode_1mad but with a second multiply/shift/add mixing round for
    better scrambling before the byte-sum Gaussian approximation.
    """
    state = x.to(torch.int64) & ((1 << 32) - 1)
    # Round 1: LCG-style multiply-add.
    state = (state * 264435761 + 1013904223) & ((1 << 32) - 1)
    # Round 2: fold the high half of a second multiply back in.
    state = (((state * 1664525) >> 32) + state) & ((1 << 32) - 1)
    byte_sum = ((state & 255) + ((state >> 8) & 255) +
                ((state >> 16) & 255) + ((state >> 24) & 255))
    return (byte_sum - 510).to(torch.float32) / 147.800537109375
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def decode_3inst(x):
    """Map integer states to floats by synthesizing two fp16 bit patterns.

    Mixes the 32-bit state with one multiply-add, masks it down to
    sign + low-mantissa bits in each 16-bit half, XORs in a fixed
    exponent pattern (fpmask), reinterprets both halves as fp16, and
    returns their sum. "3inst" presumably refers to the three hardware
    instructions the matching CUDA decoder needs — TODO confirm.
    """

    def bfe16_to_fp16(x):
        # Convert an unsigned 16-bit field to a signed int16 and
        # reinterpret the bits as float16.
        # NOTE: mutates its argument in place; callers here always pass a
        # freshly computed tensor, so no outside aliasing in practice.
        x[torch.where(x >= 2**15)] -= 2**16
        return torch.tensor(x.to(torch.int16).numpy().view(np.float16))

    a = 89226354        # multiplier of the mixing step
    b = 64248484        # additive constant of the mixing step
    fpmask = 996162400  # fixed fp16 exponent bits for both 16-bit halves
    x = x.to(torch.int64)
    x = x & ((1 << 32) - 1)
    x = x * a + b
    # Keep the sign bit (bit 15) and the low 12 bits of each half.
    mask = (1 << 15) + ((1 << 12) - 1)
    mask = (mask << 16) + mask
    res = (mask & x) ^ fpmask
    top = bfe16_to_fp16(res >> 16)
    bottom = bfe16_to_fp16(res & ((1 << 16) - 1))
    return (top + bottom).float()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def quantlut(tlut, L, nbits):
    """Expand a small table `tlut` into a 2**L-entry lookup table.

    Each L-bit state is hashed with s*(s+1); bits [16-nbits, 16) of the hash
    select one of the 2**nbits rows of `tlut`.
    """
    with torch.no_grad():
        states = torch.arange(1 << L)
        hashed = states * (states + 1)
        rows = (hashed >> (16 - nbits)) & ((1 << nbits) - 1)
        return tlut[rows]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def quantlut_sym(tlut, L, nbits):
    """Expand `tlut` into a sign-symmetric 2**L-entry lookup table.

    Each state is hashed with s*(s+1); bit 15 of the hash supplies a sign
    that is applied to column 0 of the selected row, and nbits+1 bits
    starting at position 16-nbits-1 select the row.
    """
    with torch.no_grad():
        states = torch.arange(1 << L, device=tlut.device)
        hashed = states * (states + 1)
        # +1 for a +1/-1 sign from hash bit 15.
        sign = 1 - ((hashed >> 15) & 1) * 2
        rows = (hashed >> (16 - nbits - 1)) & ((1 << nbits) - 1)
        out = tlut[rows]  # advanced indexing copies, so in-place edit is safe
        out[:, 0] = out[:, 0] * sign
        return out
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class bitshift_codebook(nn.Module):
    """Trellis codebook for bitshift/trellis-coded quantization.

    States are L-bit integers; moving one step in the trellis shifts KV new
    bits in (so consecutive states overlap in their low L-KV bits). Each
    state decodes to a V-dimensional vector via `lut`, which is built from a
    small table `tlut` or from a compute-based decoder (1mad/2mad/3inst).
    `quantize` runs a Viterbi search over the trellis.
    """

    def __init__(self,
                 L=16,
                 KV=4,
                 V=2,
                 tlut_bits=16,
                 decode_mode='lut',
                 tlut=None):
        """Build the LUT and the static trellis-transition tensors.

        L          : bits per trellis state.
        KV         : bits shifted in per step.
        V          : vector dimension decoded per state.
        tlut_bits  : size (log2) of the small source table.
        decode_mode: 'lut', '1mad', '2mad', '3inst', 'quantlut',
                     or 'quantlut_sym'.
        tlut       : optional externally supplied source table.
        """
        super(bitshift_codebook, self).__init__()
        self.idx_dtype = torch.int32
        self.opt_scale = 1

        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.decode_mode = decode_mode

        if decode_mode == 'lut':
            if tlut is None:
                assert tlut_bits == L
                self.register_buffer('tlut', torch.randn(2**L, V))
                self.register_buffer('lut', self.tlut.T.contiguous())
            else:
                self.tlut = tlut
                self.recons_lut()

        elif decode_mode == '1mad':
            assert V == 1
            self.register_buffer('lut',
                                 decode_1mad(torch.arange(2**L)).unsqueeze(0))
        elif decode_mode == '2mad':
            assert V == 1
            self.register_buffer('lut',
                                 decode_2mad(torch.arange(2**L)).unsqueeze(0))
        elif decode_mode == '3inst':
            assert V == 1
            self.register_buffer('lut',
                                 decode_3inst(torch.arange(2**L)).unsqueeze(0))
        elif decode_mode == 'quantlut':
            if tlut is None:
                assert tlut_bits > 0
                if V == 1:
                    # Inverse-CDF grid: equispaced quantiles of N(0, 1).
                    tlut = torch.erfinv((torch.arange(1 << tlut_bits) + 0.5) /
                                        (1 << tlut_bits) * 2 -
                                        1) * torch.tensor(2.0).sqrt()
                elif V == 2:
                    # 2-D spiral covering of the Gaussian shell radii.
                    n = 2**tlut_bits
                    tlut = torch.zeros(n)
                    R = ((n / (n - torch.arange(n))).log() * 2).sqrt()
                    tlut = torch.stack(
                        [R * torch.arange(n).sin(), R * torch.arange(n).cos()],
                        dim=-1)
                else:
                    raise Exception
                # NOTE(review): for V == 2 this makes tlut 3-D
                # (n, 2, 1) before quantlut indexes it — confirm intended.
                self.register_buffer('tlut', tlut.unsqueeze(-1))
                self.register_buffer(
                    'lut',
                    quantlut(self.tlut, L, tlut_bits).T.contiguous())
            else:
                self.tlut = tlut
                self.recons_lut()
        elif decode_mode == 'quantlut_sym':
            if tlut is None:
                assert tlut_bits > 0
                if V == 2:
                    # Cache the k-means table on disk; build it once if absent.
                    fname = f'assets/lut_cache/kmeans_{tlut_bits}_{V}.pt'
                    if not os.path.exists(fname):
                        tlut = torch.randn(2**tlut_bits, V)
                        import scipy
                        data = torch.randn(1 << 20, 2)
                        clusters = scipy.cluster.vq.kmeans(data, tlut)
                        tlut = torch.tensor(clusters[0])
                        # Normalize to unit-ish variance; 0.968... is a fixed
                        # calibration constant.
                        tlut = (tlut /
                                tlut.std(unbiased=False)) * 0.9682458365518543
                        torch.save(tlut, fname)
                    else:
                        tlut = torch.load(fname)
                else:
                    raise Exception
                self.register_buffer('tlut', tlut)
                self.register_buffer(
                    'lut',
                    quantlut_sym(self.tlut, L, tlut_bits).T.contiguous())
            else:
                self.tlut = tlut
                self.recons_lut()
        else:
            raise Exception

        # Plain attribute (not a buffer): used as an "infinity" sentinel.
        # NOTE(review): will not follow .to(device)/.cuda() moves — confirm
        # callers keep it on the right device.
        self.fakeinf = torch.tensor(torch.inf)

        # All 2**KV possible new high-bit packets, pre-shifted into position.
        self.register_buffer('sumdelta',
                             torch.arange(2**(KV)) << (L - KV))
        self.sumdelta = self.sumdelta.view(1, 1, -1)

        self.register_buffer('state', torch.arange(2**L).unsqueeze(0))
        # For each reachable next-state prefix, the 2**KV predecessor states.
        self.register_buffer('state_cand',
                             (self.state >>
                              (KV))[0, ::2**(KV)].unsqueeze(-1) +
                             self.sumdelta)
        # Decoded value for every state; shape follows recons() output.
        self.register_buffer('recons_state', self.recons(self.state))

        self.version = 0

    def recons_lut(self):
        """Rebuild `lut` from `tlut` (e.g. after `tlut` was trained)."""
        if self.decode_mode == 'lut':
            self.lut = self.tlut.T.contiguous()
        elif self.decode_mode == 'quantlut':
            self.lut = quantlut(self.tlut, self.L,
                                self.tlut_bits).T.contiguous()
        elif self.decode_mode == 'quantlut_sym':
            self.lut = quantlut_sym(self.tlut, self.L,
                                    self.tlut_bits).T.contiguous()

    def recons(self, encoded, **kwargs):
        """Decode integer states to values by LUT lookup (on lut's device)."""
        return self.lut[:,
                        encoded.int().to(self.lut.device)].to(encoded.device)

    @torch.compile
    def update(self, cost, thing):
        """One Viterbi step: fold the next V samples into the path costs.

        cost  : accumulated cost per (batch, state).
        thing : the next V samples, shape (V, B) — presumably; TODO confirm.
        Returns (prev_state, new_cost) where prev_state records the best
        predecessor for backtracking.
        """
        # Local squared error of every state against the new samples.
        state_err = (self.recons_state -
                     thing.unsqueeze(-1)).square().sum(dim=0)
        # Cost of each candidate predecessor for every next-state prefix.
        cand_cost = torch.gather(
            cost.unsqueeze(-2).expand(-1, self.state_cand.shape[1], -1), -1,
            self.state_cand.expand(len(cost), -1, 2**(self.KV)))
        best = torch.min(cand_cost, dim=-1)
        # New accumulated cost = local error + best predecessor cost.
        cost = state_err + best.values.unsqueeze(-1).expand(
            -1, -1, 2**(self.KV)).reshape(state_err.shape)
        prev_state = torch.gather(
            self.state_cand.expand(thing.shape[1], -1, -1), -1,
            best.indices.unsqueeze(-1))[..., 0]
        return prev_state, cost

    def viterbi(self, X, overlap=None):
        """
        X (T, B)

        Run the Viterbi algorithm over T samples (T multiple of V) for B
        parallel sequences; `overlap` optionally constrains the first/last
        states so independently quantized halves agree on their seam.
        Returns the chosen state per step, shape (T // V, B).
        """
        T, B = X.shape
        assert T % self.V == 0
        # cost is (B, 2**L)
        cost = (self.recons_state -
                X[:self.V].unsqueeze(-1)).square().sum(dim=0)

        if overlap is not None:
            # Only states whose high bits extend `overlap` are allowed.
            mask = torch.ones(B, 2**self.L, device=X.device) * self.fakeinf
            allow = (overlap <<
                     (self.KV)).unsqueeze(-1) + torch.arange(
                         2**(self.KV)).to(X.device).view(1, 1, -1)
            mask.scatter_(1, allow[0], 0)
            cost = torch.min(cost + mask, self.fakeinf)

        from_state = torch.zeros(T // self.V,
                                 B,
                                 2**(self.L - self.KV),
                                 dtype=self.state.dtype,
                                 device=self.state.device)

        for i in range(1, T // self.V):
            from_state[i], cost = self.update(cost,
                                              X[i * self.V:(i + 1) * self.V])

        if overlap is not None:
            # Constrain the final state to be compatible with `overlap`.
            mask = torch.ones(B, 2**self.L, device=X.device) * self.fakeinf
            allow = (overlap.unsqueeze(-1) + self.sumdelta.unsqueeze(0))
            mask.scatter_(1, allow[0, 0], 0)
            cost = torch.min(cost + mask, self.fakeinf)

        # Backtrack from the cheapest final state.
        final_state = torch.zeros(T // self.V,
                                  B,
                                  dtype=self.idx_dtype,
                                  device=X.device)
        final_state[T // self.V - 1] = torch.argmin(cost, dim=-1)
        for i in range(T // self.V - 1, 0, -1):
            final_state[i - 1] = torch.gather(
                from_state[i], -1,
                (final_state[i].to(torch.int64).unsqueeze(-1)) >>
                (self.KV))[..., 0]
        return final_state

    def quantize_seq(self, X, overlap=None, **kwargs):
        """Quantize (T, NO) by batching columns into chunks for viterbi().

        Chunk size is capped at 2**(24 - L) columns to bound the
        (batch x 2**L) cost tensors; columns are zero-padded to a whole
        number of chunks and the pad is dropped from the result.
        """
        T, NO = X.shape
        bs = min(2**(24 - self.L), NO)
        pad_amt = math.ceil(NO / bs) * bs - NO
        X = torch.nn.functional.pad(X, (0, pad_amt))
        T, N = X.shape
        X = X.reshape(T, N // bs, bs).transpose(0, 1).contiguous()
        if overlap is not None:
            overlap = torch.nn.functional.pad(overlap, (0, pad_amt))
            overlap = overlap.reshape(N // bs, bs)

        Qidxs = torch.zeros(N // bs,
                            T // self.V,
                            bs,
                            dtype=self.idx_dtype,
                            device=X.device)
        for i in range(len(X)):
            b_overlap = None if overlap is None else overlap[i]
            Qidxs[i] = self.viterbi(X[i], overlap=b_overlap)
        Qidxs = Qidxs.transpose(0, 1).reshape(T // self.V, N)[:, :NO]
        return Qidxs

    def quantize(self, X, **kwargs):
        """Quantize X (rows = sequences) with a two-pass "tail-biting" trick.

        First pass quantizes a rolled copy to discover a good mid-sequence
        state; the second pass is constrained (via `overlap`) to pass
        through it, which removes the free-boundary artifact at position 0.
        Returns (reconstruction, state indices), both shaped like X / its
        state grid.
        """
        X = X.T.contiguous().to(torch.float16)
        T = X.shape[0]
        roll_X = torch.roll(X, T // (2 * self.V) * self.V, 0)
        state = self.quantize_seq(roll_X, overlap=None)
        overlap = state[T // (2 * self.V)] >> self.KV
        state = self.quantize_seq(X, overlap=overlap)
        hatX = self.recons(state).transpose(0, 1).reshape(X.shape)
        return hatX.T.contiguous().to(X.device), state.T.contiguous().to(
            X.device)

    def pack_trellis(self, trellis):
        """Pack trellis states into uint16 words, KV fresh bits per step.

        Because consecutive states share L-KV bits, only the first state's
        full L bits plus KV new bits per subsequent step are stored. The
        assert verifies exactly that overlap invariant.
        """
        # T is really T // self.V here
        B, T = trellis.shape
        bf = torch.zeros(B,
                         T * self.KV + self.L - self.KV,
                         dtype=bool,
                         device=trellis.device)
        # Full L bits of the first state, MSB first.
        bf[:, :self.L] = (trellis[:, 0].unsqueeze(-1) & (2**torch.arange(
            self.L, device=trellis.device).flip(dims=(-1, ))).unsqueeze(0)) > 0
        K_mask = 2**torch.arange(
            self.KV,
            device=trellis.device).flip(dims=(-1, )).unsqueeze(0)
        for i in range(1, T):
            assert ((trellis[:, i - 1] &
                     ((1 << (self.L - self.KV)) - 1)) == (
                         trellis[:, i] >> (self.KV))).all()
            bf[:,
               (self.L +
                (i - 1) * self.KV):(self.L + i * self.KV)] = (
                    (trellis[:, i] &
                     ((1 <<
                       (self.KV)) - 1)).unsqueeze(-1) & K_mask) > 0

        # Drop the trailing L-KV bits (they duplicate the next block's head).
        bf = bf[:, :-(self.L - self.KV)]
        pad_amt = math.ceil(
            T * self.KV / 16) * 16 - T * self.KV
        bf = torch.nn.functional.pad(bf, (0, pad_amt)).reshape(
            -1, (T * self.KV + pad_amt) // 16, 16)

        # Collapse groups of 16 bits into uint16 words, MSB first.
        uint_mask = (2**torch.arange(
            16, dtype=torch.int32,
            device=bf.device)).flip(dims=(-1, )).unsqueeze(0).unsqueeze(0)
        bf_sum = (bf.to(torch.int32) * uint_mask).sum(dim=-1)
        return bf_sum.to(torch.uint16)
|
| 330 |
+
|
| 331 |
+
class BitshiftLinear(nn.Module):
    """Linear layer whose weight is stored as a packed bitshift trellis.

    The weight is reconstructed (or consumed directly by a fused CUDA
    kernel) from trellis codes; incoherence processing via Hadamard
    transforms and the SU/SV sign/scale vectors wraps the matmul.
    """

    def __init__(self,
                 td_x,
                 td_y,
                 L,
                 K,
                 V,
                 tlut_bits,
                 decode_mode,
                 dtype=torch.float16,
                 tlut=None,
                 has_kernel=False):
        super().__init__()
        # Tile dimensions of the trellis-coded weight blocks.
        self.td_x = td_x
        self.td_y = td_y
        self.V = V
        # Positional K binds to the codebook's KV parameter.
        self.cb = bitshift_codebook(L, K, V, tlut_bits, decode_mode, tlut=tlut)
        self.internal_dtype = dtype
        self.has_kernel = has_kernel
        # Fixed pre/post scaling applied around the Hadamard transforms.
        self.scale = 32

    def get_hatW(self, unpacked_trellis, m, n):
        """Decode trellis states and reassemble the (m, n) weight from
        (td_x, td_y) tiles."""
        return self.cb.recons(unpacked_trellis).transpose(0, 1).transpose(
            1, 2).reshape(m // self.td_x, n // self.td_y, self.td_x,
                          self.td_y).transpose(1, 2).reshape(m, n)

    def get_hatW_kernel(self, trellis, m, n):
        """Decode the packed trellis with the fused CUDA decompressor.

        NOTE(review): bitshift_codebook stores the shift width as `.KV`,
        not `.K` — `self.cb.K` looks like an AttributeError on this path;
        confirm against the deployed codebook class.
        """
        out = decode_compressed(self.cb.L, self.cb.tlut_bits, self.cb.K,
                                int(math.log2(self.V)), m, n, trellis.view(-1),
                                self.cb.lut.T)
        return out

    def cache_hatW(self, packed_trellis, had_left, had_right, K_left, K_right,
                   m, n, rcp, tp_rank):
        """Materialize and cache the dequantized, Hadamard-transformed weight
        (used by the 'train-fixW' forward mode).

        rcp selects how tensor-parallel sharding interacts with the
        transforms: 1 = shard the input dim, 2 = shard the output dim,
        otherwise no sharding.
        """
        if self.has_kernel:
            hatW = self.get_hatW_kernel(packed_trellis, m, n)
        else:
            hatW = self.get_hatW(
                self.cb.unpack_trellis(packed_trellis, self.td_x * self.td_y),
                m, n)
        hatW = hatW.float() / self.scale

        if rcp == 1:
            self.hatW = matmul_hadU_cuda(
                matmul_hadU_cuda(hatW.reshape(tp_rank * m, n // tp_rank),
                                 had_left, K_left).reshape(m, n).T, had_right,
                K_right).T.contiguous().to(self.internal_dtype)
        elif rcp == 2:
            self.hatW = matmul_hadU_cuda(
                matmul_hadU_cuda(hatW, had_left,
                                 K_left).T.reshape(tp_rank * n,
                                                   m // tp_rank), had_right,
                K_right).reshape(n, m).T.contiguous().to(self.internal_dtype)
        else:
            self.hatW = matmul_hadU_cuda(
                matmul_hadU_cuda(hatW, had_left, K_left).T, had_right,
                K_right).T.contiguous().to(self.internal_dtype)

    def forward(self,
                input,
                trellis,
                SU,
                SV,
                had_left,
                had_right,
                K_left,
                K_right,
                rcp,
                tp_rank,
                mode='eval',
                use_prev_kernel=False,
                **kwargs):
        """Compute input @ hatW.T with incoherence processing.

        Pipeline: x * SU -> inverse Hadamard (left) / scale -> matmul
        against the decoded weight (fused kernel, autograd kernel, or
        explicit decode depending on mode/batch) -> Hadamard (right)
        -> * SV * scale. `mode` is one of 'eval', 'train-recons',
        'train-fixW'.
        """
        n, m = len(SU), len(SV)
        x = input.view(-1, n).to(torch.float32)
        x = x * SU

        if mode == 'train-fixW':
            # Use the weight cached by cache_hatW (transforms pre-applied).
            x = (x.to(self.internal_dtype) @ self.hatW.T).float()
        else:
            bs = x.shape[0]

            if rcp == 1:
                x = matmul_hadUt_cuda(x.reshape(-1, n // tp_rank), had_left,
                                      K_left).reshape(x.shape) / self.scale
            else:
                x = matmul_hadUt_cuda(x, had_left, K_left) / self.scale

            if bs == 1 and self.has_kernel:
                # Batch-1 fast path: fused decompress+GEMV op.
                # NOTE(review): op name uses x.numel() as the k dim and
                # self.cb.K (codebook stores KV) — confirm both against the
                # registered quip_lib op names.
                wrapper = getattr(
                    torch.ops.quip_lib,
                    f"decompress_gemm_tcq_{m}_1_{x.numel()}_{self.cb.K}")

                x = wrapper(trellis, x, self.cb.tlut)

            else:
                if mode == 'train-recons':
                    # tlut may have been updated by the optimizer.
                    self.cb.recons_lut()

                if self.has_kernel:
                    if use_prev_kernel:
                        x = BitshiftLinearKernelAG.apply(
                            x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
                            self.V, self.cb.lut).float()
                    else:
                        x = bitshift_linear_kernel(
                            x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
                            self.V, self.cb.lut).float()
                else:
                    if mode == 'eval':
                        trellis = self.cb.unpack_trellis(
                            trellis, self.td_x * self.td_y)
                    hatW = self.get_hatW(trellis, m, n)
                    x = (x.to(hatW.dtype) @ hatW.T).float()

            if rcp == 2:
                x = matmul_hadU_cuda(x.reshape(-1, m // tp_rank), had_right,
                                     K_right).reshape(x.shape)
            else:
                x = matmul_hadU_cuda(x, had_right, K_right)

        x = x.to(SV.device) * (SV * self.scale)
        return x.view(*input.shape[:-1], m).to(input.dtype)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class BitshiftLinearKernelAG(torch.autograd.Function):
    """Autograd wrapper for the trellis-decompress matmul.

    Only the trellis codes and LUT are saved; the dense weight is
    re-decoded in backward instead of being kept alive, trading compute
    for activation memory. Gradients flow to `input` only.
    """

    @staticmethod
    def forward(ctx, input, trellis, m, n, L, tlut_bits, K, V, lut):
        ctx.save_for_backward(trellis, lut)
        ctx.L = L
        ctx.tlut_bits = tlut_bits
        ctx.K = K
        ctx.V = V
        ctx.m = m
        ctx.n = n

        # Decode the (m, n) weight from the packed trellis, then matmul.
        hatW = decode_compressed(L, tlut_bits, K, int(math.log2(V)),
                                 m, n, trellis.view(-1), lut.T)
        return input.to(hatW.dtype) @ hatW.T

    @staticmethod
    def backward(ctx, grad_output):
        trellis, lut = ctx.saved_tensors
        L = ctx.L
        tlut_bits = ctx.tlut_bits
        K = ctx.K
        V = ctx.V
        m = ctx.m
        n = ctx.n

        # Re-decode the weight rather than caching it from forward.
        hatW = decode_compressed(L, tlut_bits, K, int(math.log2(V)),
                                 m, n, trellis.view(-1), lut.T)

        grad_input = grad_output.to(hatW.dtype) @ hatW
        # Non-tensor args and the codes/LUT get no gradient.
        return grad_input, None, None, None, None, None, None, None, None
|
lib/codebook/vq_codebook.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from lib.utils.kmeans import kmeans_flash1d, kmeans_sklearn
|
| 7 |
+
|
| 8 |
+
class vq_codebook(nn.Module):
    """Vector-quantization codebook with a disk-cached k-means LUT.

    The 2**lut_bits centroids over N(0, I) vectors of size `vec_sz` are
    computed once and cached under assets/lut_cache/.
    """

    def __init__(self,
                 vec_sz=2,
                 lut_bits=8):
        super(vq_codebook, self).__init__()
        self.idx_dtype = torch.int32
        self.vec_sz = vec_sz
        self.lut_bits = lut_bits

        fname = f'assets/lut_cache/vq_kmeans_{lut_bits}_{vec_sz}.pt'
        if not os.path.exists(fname):
            if vec_sz == 1:
                data = torch.randn(int(1e8), vec_sz)
                tlut = kmeans_flash1d(data, 2**lut_bits)
            elif vec_sz in [2,4]:
                data = torch.randn(int(1e8), vec_sz)
                # Larger codebooks subsample the data to keep k-means cheap.
                if lut_bits <= 5:
                    tlut = kmeans_sklearn(data, 2**lut_bits, max_data=int(1e8))
                else:
                    tlut = kmeans_sklearn(data, 2**lut_bits, max_data=int(1e7))
            # NOTE(review): vec_sz outside {1, 2, 4} falls through with tlut
            # unbound and raises NameError here — confirm that is acceptable.
            torch.save(tlut, fname)
        else:
            tlut = torch.load(fname)
        self.register_buffer("tlut", tlut)
        self.register_buffer("lut", tlut.T.contiguous())

    def recons(self, encoded, **kwargs):
        """Look up centroid vectors for integer codes."""
        return self.tlut[encoded].contiguous()

    def quantize(self, X, **kwargs):
        """
        X : [B, vec_sz]

        Nearest-centroid quantization; returns (reconstruction, codes).
        """
        dist = torch.cdist(X, self.tlut.to(X.device, dtype=X.dtype)) # [B, 2**lut_bits]
        state = torch.argmin(dist, dim=-1) # [B,] each entry is in [0, 2**lut_bits)
        hatX = self.recons(state)
        return hatX.to(X.device), state.to(X.device)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if __name__ == "__main__":
    # Smoke test: report quantization MSE of random Gaussian vectors for a
    # sweep of codebook sizes (also builds/warms the on-disk LUT cache).
    for vec_sz in [4]:
        for lut_bits in [6,7,8,9,10,11,12]:
        # for lut_bits in [1,2,3,4,5,6,7,8,9,10,11,12]:
            if vec_sz == 1 and lut_bits > 8:
                continue
            vq = vq_codebook(vec_sz=vec_sz, lut_bits=lut_bits)
            X = torch.randn(int(1e5), vec_sz)
            hatX, state = vq.quantize(X)
            print(f"vec_sz: {vec_sz}, lut_bits: {lut_bits}, mse: {(hatX-X).pow(2).mean()}")
|
lib/config.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Map from HuggingFace model identifier to the short key used elsewhere in
# this repo (e.g. asset/cache file naming).
MODEL_KEYS = {
    "meta-llama/Llama-3.1-8B": "3_8b",
    "meta-llama/Llama-3.2-1B": "3_1b",
    "meta-llama/Llama-3.2-3B": "3_3b",
    "Qwen/Qwen2.5-7B": "qwen_7b",
}
|
lib/linear/__init__.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .quantized_linear import QuantizedLinear
|
| 2 |
+
from .incoherent_linear import IncoherentLinear
|
| 3 |
+
from .vq_linear import VQLinearPackSIMT, VQLinearPackTensorCore
|
| 4 |
+
from .tcq_linear import QTIPLinearTCQ
|
| 5 |
+
from .comb_linear import CombLinearTCQ, CombtLinearTCQ
|
| 6 |
+
import vq_tensor_kernels
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
# (m, k) weight shapes for which VQ decompression/GEMM kernels were compiled.
kernels = [
    (53248, 16384),
    (16384, 53248),
    (1024, 16384),
    (16384, 16384),
    (4096, 14336),
    (14336, 4096),
    (28672, 4096),
    (2048, 4096),
    (5120, 4096),
    (6144, 4096),
    (1024, 4096),
    (4096, 4096),
    (4096, 11008),
    (11008, 4096),
    (22016, 4096),
    (12288, 4096),
    (8192, 4096),
    (8192, 8192),
    (10240, 8192),
    (57344, 8192),
    (8192, 1024),
    (8192, 28672),
    (28672, 8192),
    (1024, 8192),
    (5120, 5120),
    (10240, 5120),
    (15360, 5120),
    (13568, 5120),
    (27136, 5120),
    (5120, 13568),
]

kdict = {}
# Register a torch.library op (fake + CUDA impl) for every compiled
# (vtype, shape, batch, bitrate) combination. The op names must be unique,
# which fixes the loop nesting: gemm ops vary with n, gemv ops only with
# (m, k), and plain decompress ops only with (bitrate, vtype).
for vtype, max_bits, bit_stride in [
        ('sq_dup', 4, 1),
        ('sq', 8, 1),
        ('vq2', 12, 1),
]:
    for m, k in kernels:
        for n in [1, 2, 4, 8]:
            for bitrate in range(2, max_bits + 1, bit_stride):

                torch.library.define(
                    f"ours_lib::decompress_gemm_{m}_{n}_{k}_{bitrate}_{vtype}",
                    "(Tensor compressed, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_{m}_{n}_{k}_{bitrate}_{vtype}"
                kernel_name = f"vq_tensor_kernels.decompress_gemm_{bitrate}_{m}_{n}_{k}_{vtype}"
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
        compressed: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
        compressed: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")

        # GEMV variant: writes into a caller-supplied `out` buffer.
        for bitrate in range(2, max_bits + 1, bit_stride):
            name = f"decompress_gemv_{m}_{k}_{bitrate}_{vtype}"
            kernel_name = f"vq_tensor_kernels.decompress_gemm_{bitrate}_{m}_{1}_{k}_{vtype}"
            exec(f"""\
@torch.library.custom_op("ours_lib::{name}", mutates_args={{"out"}})
def {name}(
        compressed: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor,
        out: torch.Tensor) -> torch.Tensor:
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
@{name}.register_fake
def {name}_fake(compressed, x, codebook, out):
    return None
""")

    # Shape-agnostic decompress-only ops (m, k passed at call time).
    for bitrate in range(2, max_bits + 1, bit_stride):
        torch.library.define(
            f"ours_lib::decompress_{bitrate}_{vtype}",
            "(Tensor compressed, Tensor codebook, int m, int k) -> Tensor")

        name = f"decompress_{bitrate}_{vtype}"
        kernel_name = f"vq_tensor_kernels.decompress_{bitrate}_{vtype}"
        exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
        compressed: torch.Tensor,
        codebook: torch.Tensor,
        m: int,
        k: int) -> torch.Tensor:
    return torch.zeros(m, k, dtype=torch.float16, device=compressed.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
        compressed: torch.Tensor,
        codebook: torch.Tensor,
        m: int,
        k: int) -> torch.Tensor:
    out = torch.zeros((m, k), dtype=torch.float16, device=compressed.device)
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), codebook.reshape(-1))
    return out
""")
|
| 118 |
+
|
| 119 |
+
import tcq_kernels
|
| 120 |
+
import torch
|
| 121 |
+
MKSHAPE = [
|
| 122 |
+
(53248, 16384),
|
| 123 |
+
(16384, 53248),
|
| 124 |
+
(1024, 16384),
|
| 125 |
+
(16384, 16384),
|
| 126 |
+
(4096, 14336),
|
| 127 |
+
(14336, 4096),
|
| 128 |
+
(28672, 4096),
|
| 129 |
+
(5120, 4096),
|
| 130 |
+
(6144, 4096),
|
| 131 |
+
(512, 4096),
|
| 132 |
+
(1024, 4096),
|
| 133 |
+
(2048, 4096),
|
| 134 |
+
(4096, 4096),
|
| 135 |
+
(2048, 11008),
|
| 136 |
+
(4096, 11008),
|
| 137 |
+
(5504, 4096),
|
| 138 |
+
(11008, 4096),
|
| 139 |
+
(22016, 4096),
|
| 140 |
+
(12288, 4096),
|
| 141 |
+
(8192, 4096),
|
| 142 |
+
(8192, 8192),
|
| 143 |
+
(10240, 8192),
|
| 144 |
+
(57344, 8192),
|
| 145 |
+
(8192, 1024),
|
| 146 |
+
(8192, 28672),
|
| 147 |
+
(28672, 8192),
|
| 148 |
+
(1024, 8192),
|
| 149 |
+
(5120, 5120),
|
| 150 |
+
(10240, 5120),
|
| 151 |
+
(15360, 5120),
|
| 152 |
+
(13568, 5120),
|
| 153 |
+
(27136, 5120),
|
| 154 |
+
(5120, 13568),
|
| 155 |
+
(3072, 3072),
|
| 156 |
+
(1024, 3072),
|
| 157 |
+
(4096, 3072),
|
| 158 |
+
(2048, 3072),
|
| 159 |
+
(5120, 3072),
|
| 160 |
+
(8192, 3072),
|
| 161 |
+
(16384, 3072),
|
| 162 |
+
(3072, 8192),
|
| 163 |
+
]
|
| 164 |
+
|
| 165 |
+
kdict = {}
|
| 166 |
+
# Register fused decompress+GEMM ops, one per (m, n, k, S, bitrate) combo.
# `S` is the trellis lookup-table size in bits; each S supports a different
# set of per-weight bitrates.
for S in [9, 10, 11]:
    if S == 9:
        bitrate_list = [2,3,4,5,6,7,8,9,10]
    elif S == 10:
        bitrate_list = [8, 9, 10]
    elif S == 11:
        bitrate_list = [9, 10]
    for m, k in MKSHAPE:
        for n in [1,2,4,8]:  # supported GEMM batch sizes
            for bitrate in bitrate_list:
                # Single-trellis op: out(n, m) = x(n, k) @ W(m, k)^T with the
                # weight decoded on the fly from `compressed` + `codebook`.
                torch.library.define(
                    f"ours_lib::decompress_gemm_tcq_{m}_{n}_{k}_{S}_{bitrate}",
                    "(Tensor compressed, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_tcq_{m}_{n}_{k}_{S}_{bitrate}"
                kernel_name = f"tcq_kernels.decompress_gemm_16_{S}_{bitrate}_1_{m}_{n}_{k}"
                # exec() is needed because torch.library wants a uniquely named
                # fake/impl function pair for every generated op.
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
        compressed: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
        compressed: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")
                # comb/combt variants pair `bitrate` with `bitrate + 1`, so the
                # largest bitrate has no partner.
                if bitrate == bitrate_list[-1]:
                    continue
                # "comb": two trellises at adjacent bitrates split along the
                # output dimension (see CombLinearTCQ).
                torch.library.define(
                    f"ours_lib::decompress_gemm_tcq_comb_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}",
                    "(Tensor compressed1, Tensor compressed2, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_tcq_comb_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}"
                kernel_name = f"tcq_kernels.decompress_gemm_comb_16_{S}_{bitrate}_{int(bitrate+1)}_1_{m}_{n}_{k}"
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
        compressed1: torch.Tensor,
        compressed2: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
        compressed1: torch.Tensor,
        compressed2: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")
                # "combt": two trellises at adjacent bitrates split along the
                # input dimension (see CombtLinearTCQ).
                torch.library.define(
                    f"ours_lib::decompress_gemm_tcq_combt_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}",
                    "(Tensor compressed1, Tensor compressed2, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_tcq_combt_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}"
                kernel_name = f"tcq_kernels.decompress_gemm_combt_16_{S}_{bitrate}_{int(bitrate+1)}_1_{m}_{n}_{k}"
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
        compressed1: torch.Tensor,
        compressed2: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
        compressed1: torch.Tensor,
        compressed2: torch.Tensor,
        x: torch.Tensor,
        codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")
|
| 251 |
+
|
| 252 |
+
for S in [9, 10, 11]:
|
| 253 |
+
if S == 9:
|
| 254 |
+
bitrate_list = [2,3,4,5,6,7,8,9,10]
|
| 255 |
+
elif S == 10:
|
| 256 |
+
bitrate_list = [8, 9, 10]
|
| 257 |
+
elif S == 11:
|
| 258 |
+
bitrate_list = [9, 10]
|
| 259 |
+
for bitrate in bitrate_list:
|
| 260 |
+
torch.library.define(
|
| 261 |
+
f"ours_lib::decompress_tcq_{S}_{bitrate}",
|
| 262 |
+
"(Tensor compressed, Tensor codebook, int m, int k) -> Tensor")
|
| 263 |
+
|
| 264 |
+
name = f"decompress_tcq_{S}_{bitrate}"
|
| 265 |
+
kernel_name = f"tcq_kernels.decompress_16_{S}_{bitrate}"
|
| 266 |
+
exec(f"""\
|
| 267 |
+
@torch.library.register_fake("ours_lib::{name}")
|
| 268 |
+
def {name}_abstract(
|
| 269 |
+
compressed: torch.Tensor,
|
| 270 |
+
codebook: torch.Tensor,
|
| 271 |
+
m: int,
|
| 272 |
+
k: int) -> torch.Tensor:
|
| 273 |
+
return torch.zeros(m, k, dtype=torch.float16, device=compresed.device)
|
| 274 |
+
|
| 275 |
+
@torch.library.impl("ours_lib::{name}", "cuda")
|
| 276 |
+
def {name}_cuda(
|
| 277 |
+
compressed: torch.Tensor,
|
| 278 |
+
codebook: torch.Tensor,
|
| 279 |
+
m: int,
|
| 280 |
+
k: int) -> torch.Tensor:
|
| 281 |
+
out = torch.zeros((m, k), dtype=torch.float16, device=compressed.device)
|
| 282 |
+
{kernel_name}(out, compressed.reshape(-1).view(torch.int32), codebook.reshape(-1))
|
| 283 |
+
return out
|
| 284 |
+
""")
|
| 285 |
+
if bitrate == bitrate_list[-1]:
|
| 286 |
+
break
|
| 287 |
+
|
| 288 |
+
torch.library.define(
|
| 289 |
+
f"ours_lib::decompress_tcq_comb_{S}_{bitrate}_{int(bitrate+1)}",
|
| 290 |
+
"(Tensor compressed1, Tensor compressed2, Tensor codebook, int m, int k) -> Tensor")
|
| 291 |
+
|
| 292 |
+
name = f"decompress_tcq_comb_{S}_{bitrate}_{int(bitrate+1)}"
|
| 293 |
+
kernel_name = f"tcq_kernels.decompress_comb_16_{S}_{bitrate}_{int(bitrate+1)}"
|
| 294 |
+
exec(f"""\
|
| 295 |
+
@torch.library.register_fake("ours_lib::{name}")
|
| 296 |
+
def {name}_abstract(
|
| 297 |
+
compressed1: torch.Tensor,
|
| 298 |
+
compressed2: torch.Tensor,
|
| 299 |
+
codebook: torch.Tensor,
|
| 300 |
+
m: int, k: int) -> torch.Tensor:
|
| 301 |
+
return torch.zeros(m, k, dtype=torch.float16, device=compressed1.device)
|
| 302 |
+
|
| 303 |
+
@torch.library.impl("ours_lib::{name}", "cuda")
|
| 304 |
+
def {name}_cuda(
|
| 305 |
+
compressed1: torch.Tensor,
|
| 306 |
+
compressed2: torch.Tensor,
|
| 307 |
+
codebook: torch.Tensor,
|
| 308 |
+
m: int, k: int) -> torch.Tensor:
|
| 309 |
+
out = torch.zeros((m, k), dtype=torch.float16, device=compressed1.device)
|
| 310 |
+
{kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), codebook.reshape(-1))
|
| 311 |
+
return out
|
| 312 |
+
""")
|
| 313 |
+
torch.library.define(
|
| 314 |
+
f"ours_lib::decompress_tcq_combt_{S}_{bitrate}_{int(bitrate+1)}",
|
| 315 |
+
"(Tensor compressed1, Tensor compressed2, Tensor codebook, int m, int k) -> Tensor")
|
| 316 |
+
|
| 317 |
+
name = f"decompress_tcq_combt_{S}_{bitrate}_{int(bitrate+1)}"
|
| 318 |
+
kernel_name = f"tcq_kernels.decompress_combt_16_{S}_{bitrate}_{int(bitrate+1)}"
|
| 319 |
+
exec(f"""\
|
| 320 |
+
@torch.library.register_fake("ours_lib::{name}")
|
| 321 |
+
def {name}_abstract(
|
| 322 |
+
compressed1: torch.Tensor,
|
| 323 |
+
compressed2: torch.Tensor,
|
| 324 |
+
codebook: torch.Tensor,
|
| 325 |
+
m: int, k: int) -> torch.Tensor:
|
| 326 |
+
return torch.zeros(m, k, dtype=torch.float16, device=compressed1.device)
|
| 327 |
+
|
| 328 |
+
@torch.library.impl("ours_lib::{name}", "cuda")
|
| 329 |
+
def {name}_cuda(
|
| 330 |
+
compressed1: torch.Tensor,
|
| 331 |
+
compressed2: torch.Tensor,
|
| 332 |
+
codebook: torch.Tensor,
|
| 333 |
+
m: int, k: int) -> torch.Tensor:
|
| 334 |
+
out = torch.zeros((m, k), dtype=torch.float16, device=compressed1.device)
|
| 335 |
+
{kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), codebook.reshape(-1))
|
| 336 |
+
return out
|
| 337 |
+
""")
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
import sq_pack_gemm
import vq_pack_gemm
"""
SQ Pack SIMT
"""
# Scalar-quantized packed GEMM: out(bs, 1, m) = x @ W^T with W decoded from
# `q_weight` (packed codes) + `lut` (per-code fp16 values) at `bitwidth` bits.
torch.library.define("ours_lib::sq_pack_gemm_simt", "(Tensor x, Tensor q_weight, Tensor lut, int bitwidth) -> Tensor")
@torch.library.register_fake("ours_lib::sq_pack_gemm_simt")
def sq_pack_gemm_simt_abstract(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, bitwidth:int) -> torch.Tensor:
    # Output rows come from q_weight.shape[0] (one packed row per out feature).
    return torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)

@torch.library.impl("ours_lib::sq_pack_gemm_simt", "cuda")
def sq_pack_gemm_simt_cuda(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, bitwidth:int) -> torch.Tensor:
    output = torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)
    # Kernel writes into `output` in place; lut is flattened for the C++ side.
    sq_pack_gemm.pack_gemm(x, output, q_weight, lut.view(-1), bitwidth)
    return output

# Dequantize-only variant: reconstruct the (m, k) fp16 weight matrix.
torch.library.define("ours_lib::sq_pack_dequant_simt", "(Tensor q_weight, Tensor lut, int bitwidth, int m, int k) -> Tensor")
@torch.library.register_fake("ours_lib::sq_pack_dequant_simt")
def sq_pack_dequant_simt_abstract(q_weight: torch.Tensor, lut: torch.Tensor, bitwidth:int, m:int, k:int) -> torch.Tensor:
    return torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)

@torch.library.impl("ours_lib::sq_pack_dequant_simt", "cuda")
def sq_pack_dequant_simt_cuda(q_weight: torch.Tensor, lut: torch.Tensor, bitwidth:int, m:int, k:int) -> torch.Tensor:
    output = torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)
    sq_pack_gemm.pack_dequant(output, q_weight, lut.view(-1), bitwidth)
    return output


# In-place GEMM variant: caller supplies `output`, which is mutated.
# NOTE(review): unlike sq_pack_gemm_simt_cuda, `lut` is passed unflattened
# here — confirm the kernel accepts both layouts.
@torch.library.custom_op("ours_lib::sq_pack_gemm_inplace_simt", mutates_args={"output"})
def sq_pack_gemm_inplace_simt(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output:torch.Tensor, bitwidth:int) -> None:
    sq_pack_gemm.pack_gemm(x, output, q_weight, lut, bitwidth)

@sq_pack_gemm_inplace_simt.register_fake
def _(x, q_weight, lut, output, bitwidth):
    return None
|
| 379 |
+
|
| 380 |
+
"""
|
| 381 |
+
VQ Pack SIMT
|
| 382 |
+
"""
|
| 383 |
+
codeT_sz = 32
|
| 384 |
+
for vec_sz in [2,4]:
|
| 385 |
+
if vec_sz == 2:
|
| 386 |
+
lut_bits_list = [3,4,5,6,7,8,9,10,11,12]
|
| 387 |
+
elif vec_sz == 4:
|
| 388 |
+
lut_bits_list = [6,7,8,9,10,11,12]
|
| 389 |
+
for lut_bits in lut_bits_list:
|
| 390 |
+
code_n = lut_bits
|
| 391 |
+
recons_n = int(vec_sz * 16)
|
| 392 |
+
for maxm in [1,2,4,8]:
|
| 393 |
+
name = f"vq_pack_gemm_simt_{maxm}_{vec_sz}_{lut_bits}"
|
| 394 |
+
kernel_name = f"vq_pack_gemm.vq_pack_gemm_{maxm}_{lut_bits}_{vec_sz}_{code_n}_{codeT_sz}_{recons_n}"
|
| 395 |
+
torch.library.define(f"ours_lib::{name}", "(Tensor x, Tensor q_weight, Tensor lut) -> Tensor")
|
| 396 |
+
exec(f"""\
|
| 397 |
+
@torch.library.register_fake("ours_lib::{name}")
|
| 398 |
+
def {name}_abstract(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
|
| 399 |
+
return torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)
|
| 400 |
+
|
| 401 |
+
@torch.library.impl("ours_lib::{name}", "cuda")
|
| 402 |
+
def {name}_cuda(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
|
| 403 |
+
output = torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)
|
| 404 |
+
{kernel_name}(x, output, q_weight.view(torch.uint32), lut)
|
| 405 |
+
return output
|
| 406 |
+
""")
|
| 407 |
+
name = f"vq_pack_dequant_simt_{vec_sz}_{lut_bits}"
|
| 408 |
+
kernel_name = f"vq_pack_gemm.vq_pack_dequant_{lut_bits}_{vec_sz}_{code_n}_{codeT_sz}_{recons_n}"
|
| 409 |
+
torch.library.define(f"ours_lib::{name}", "(Tensor q_weight, Tensor lut, int m, int k) -> Tensor")
|
| 410 |
+
exec(f"""\
|
| 411 |
+
@torch.library.register_fake("ours_lib::{name}")
|
| 412 |
+
def {name}_abstract(q_weight: torch.Tensor, lut: torch.Tensor, m: int, k: int) -> torch.Tensor:
|
| 413 |
+
return torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)
|
| 414 |
+
|
| 415 |
+
@torch.library.impl("ours_lib::{name}", "cuda")
|
| 416 |
+
def {name}_cuda(q_weight: torch.Tensor, lut: torch.Tensor, m: int, k: int) -> torch.Tensor:
|
| 417 |
+
output = torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)
|
| 418 |
+
{kernel_name}(output, q_weight.view(torch.uint32), lut)
|
| 419 |
+
return output
|
| 420 |
+
""")
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
if __name__ == "__main__":
    # Smoke test: build a comb layer (two 2048-row halves at 3 and 4 bpw,
    # vq dim 2, 9-bit LUT) and run one forward pass.
    # NOTE(review): the layer is never moved to CUDA while the input is —
    # confirm this is intentional / works with the registered cuda impls.
    # layer = QTIPLinearTCQ(4096, 4096, 16, 16, 16, 8, 1, 9, False, torch.float16)
    # print(layer._info())
    # x = torch.randn(4, 4096)
    # print(layer(x).shape)
    layer = CombLinearTCQ(4096, 4096, 16, 16, (2048, 2048), 16, (3, 4), 2, 9, False)
    print(layer._info())
    layer.forward(torch.randn(1, 4096).cuda())
|
lib/linear/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
lib/linear/__pycache__/comb_linear.cpython-311.pyc
ADDED
|
Binary file (19.5 kB). View file
|
|
|
lib/linear/__pycache__/incoherent_linear.cpython-311.pyc
ADDED
|
Binary file (42.7 kB). View file
|
|
|
lib/linear/__pycache__/quantized_linear.cpython-311.pyc
ADDED
|
Binary file (6.42 kB). View file
|
|
|
lib/linear/__pycache__/tcq_linear.cpython-311.pyc
ADDED
|
Binary file (6.88 kB). View file
|
|
|
lib/linear/__pycache__/vq_linear.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
lib/linear/comb_linear.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
class CombLinearTCQ(nn.Module):
    """Linear layer whose weight is stored as two trellis-coded (TCQ)
    compressed chunks split along the *output* dimension.

    The two chunks may use different bitrates (``KV[0]`` / ``KV[1]``).  When
    the chunks are equal-sized a fused "comb" kernel decodes both in one
    call; otherwise each chunk is decoded separately and the results are
    concatenated.
    """

    def __init__(
        self,
        in_features,
        out_features,
        td_x,   # trellis tile height (output dim)
        td_y,   # trellis tile width (input dim)
        out_part,  # 2-tuple of chunk output sizes; must sum to out_features
        L,  # trellis window
        KV,  # bpw
        V,  # vq dim
        tlut_bits,  # log2 of the trellis lookup-table size
        bias=False,
        dtype=torch.float16,
    ):
        super().__init__()
        assert len(out_part) == 2 and len(KV) == 2
        assert out_part[0] + out_part[1] == out_features

        self.in_features = in_features
        self.out_features = out_features
        self.out_part = out_part
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.dtype = dtype
        # packed into int16: one row per (td_x x td_y) weight tile,
        # ceil(tile_elems * bpw / 16 / V) int16 words per row.
        self.register_buffer(
            'trellis1',
            torch.zeros((out_part[0] // td_x) * (in_features // td_y),
                        math.ceil((td_x * td_y) * KV[0] / 16 / V),
                        dtype=torch.int16))
        self.register_buffer(
            'trellis2',
            torch.zeros((out_part[1] // td_x) * (in_features // td_y),
                        math.ceil((td_x * td_y) * KV[1] / 16 / V),
                        dtype=torch.int16))
        # Shared lookup table: 2**tlut_bits entries of V fp16 values each.
        self.tlut = nn.Parameter(torch.zeros(2**tlut_bits,
                                             V,
                                             dtype=torch.float16),
                                 requires_grad=False)

        if bias:
            # NOTE(review): initialized to ones (not zeros) and never applied
            # in forward() — confirm callers load and apply it themselves.
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # The fused comb kernels are only registered for equal-sized chunks.
        if out_part[0] == out_part[1]:
            self.use_comb_kernel = True
        else:
            self.use_comb_kernel = False


    def _info(self):
        """Serialize the layer config plus CPU copies of the quantized
        tensors (used by gen_layer_from_info to reconstruct the layer)."""
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "td_x": self.td_x,
            "td_y": self.td_y,
            "out_part": self.out_part,
            "L": self.L,
            "KV": self.KV,
            "V": self.V,
            'tlut_bits': self.tlut_bits,
            "dtype": self.dtype,
            "trellis1": self.trellis1.detach().cpu(),
            "trellis2": self.trellis2.detach().cpu(),
            "tlut": self.tlut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }
        return info

    def forward(self, inp, **kwargs):
        """Compute inp @ W^T where W is decoded from the two trellises.

        Small batches (<= 8 rows) use fused decompress+GEMM ops; larger
        batches dequantize the full weight once and use a dense matmul.
        NOTE(review): the fused ops are only registered for batch sizes
        {1, 2, 4, 8} — confirm other small sizes cannot occur here.
        """
        x = inp.view(-1, self.in_features)
        bs = x.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            if self.use_comb_kernel:
                # Single fused op handles both chunks.
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"decompress_gemm_tcq_comb_{self.out_features}_{bs}_{k}_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
                )
                x = wrapper(self.trellis1, self.trellis2, x, self.tlut)
            else:
                # Unequal chunks: run one fused op per chunk and concatenate
                # along the output dimension.
                wrapper1 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_gemm_tcq_{self.out_part[0]}_{bs}_{k}_{self.tlut_bits}_{self.KV[0]}"
                )
                wrapper2 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_gemm_tcq_{self.out_part[1]}_{bs}_{k}_{self.tlut_bits}_{self.KV[1]}"
                )
                x1 = wrapper1(self.trellis1, x, self.tlut)
                x2 = wrapper2(self.trellis2, x, self.tlut)
                x = torch.cat([x1, x2], dim=1)
        else:
            if self.use_comb_kernel:
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"decompress_tcq_comb_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
                )
                with torch.no_grad():
                    dq = wrapper(self.trellis1, self.trellis2, self.tlut, self.out_features, k)
                x = x.to(dq.dtype) @ dq.T
            else:
                wrapper1 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_tcq_{self.tlut_bits}_{self.KV[0]}"
                )
                wrapper2 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_tcq_{self.tlut_bits}_{self.KV[1]}"
                )
                with torch.no_grad():
                    dq1 = wrapper1(self.trellis1, self.tlut, self.out_part[0], k)
                    dq2 = wrapper2(self.trellis2, self.tlut, self.out_part[1], k)
                x1 = x.to(dq1.dtype) @ dq1.T
                x2 = x.to(dq2.dtype) @ dq2.T
                x = torch.cat([x1, x2], dim=1)
        return x.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a CombLinearTCQ from a dict produced by _info()."""
        layer = CombLinearTCQ(info["in_features"], info["out_features"], info["td_x"], info["td_y"], info["out_part"], info["L"], info["KV"], info["V"], info["tlut_bits"], info["bias"] is not None, info["dtype"])
        layer.trellis1.data.copy_(info["trellis1"])
        layer.trellis2.data.copy_(info["trellis2"])
        layer.tlut.data.copy_(info["tlut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    def get_weight(self):
        """Dequantize and return the full (out_features, in_features) weight.

        NOTE(review): always uses the fused comb op, which elsewhere is only
        used when out_part[0] == out_part[1] — confirm this also holds for
        unequal chunks.
        """
        wrapper = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_comb_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
        )
        dq = wrapper(self.trellis1, self.trellis2, self.tlut, self.out_features, self.in_features)
        return dq
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class CombtLinearTCQ(nn.Module):
    """Linear layer whose weight is stored as two trellis-coded (TCQ)
    compressed chunks split along the *input* dimension (the transposed
    counterpart of CombLinearTCQ).

    Each chunk may use a different bitrate (``KV[0]`` / ``KV[1]``).  The
    partial products of the two chunks are summed to form the output.
    """

    def __init__(
        self,
        in_features,
        out_features,
        td_x,   # trellis tile height (output dim)
        td_y,   # trellis tile width (input dim)
        in_part,  # 2-tuple of chunk input sizes; must sum to in_features
        L,  # trellis window
        KV,  # bpw
        V,  # vq dim
        tlut_bits,  # log2 of the trellis lookup-table size
        bias=False,
        dtype=torch.float16,
    ):
        super().__init__()
        assert len(in_part) == 2 and len(KV) == 2
        assert in_part[0] + in_part[1] == in_features

        self.in_features = in_features
        self.out_features = out_features
        self.in_part = in_part
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.dtype = dtype
        # packed into int16: one row per (td_x x td_y) weight tile,
        # ceil(tile_elems * bpw / 16 / V) int16 words per row.
        self.register_buffer(
            'trellis1',
            torch.zeros((out_features // td_x) * (in_part[0] // td_y),
                        math.ceil((td_x * td_y) * KV[0] / 16 / V),
                        dtype=torch.int16))
        self.register_buffer(
            'trellis2',
            torch.zeros((out_features // td_x) * (in_part[1] // td_y),
                        math.ceil((td_x * td_y) * KV[1] / 16 / V),
                        dtype=torch.int16))
        # Shared lookup table: 2**tlut_bits entries of V fp16 values each.
        self.tlut = nn.Parameter(torch.zeros(2**tlut_bits,
                                             V,
                                             dtype=torch.float16),
                                 requires_grad=False)

        if bias:
            # NOTE(review): initialized to ones (not zeros) and never applied
            # in forward() — confirm callers load and apply it themselves.
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # The fused combt kernels are only registered for equal-sized chunks.
        if in_part[0] == in_part[1]:
            self.use_comb_kernel = True
        else:
            self.use_comb_kernel = False


    def _info(self):
        """Serialize the layer config plus CPU copies of the quantized
        tensors (used by gen_layer_from_info to reconstruct the layer)."""
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "td_x": self.td_x,
            "td_y": self.td_y,
            "in_part": self.in_part,
            "L": self.L,
            "KV": self.KV,
            "V": self.V,
            'tlut_bits': self.tlut_bits,
            "dtype": self.dtype,
            "trellis1": self.trellis1.detach().cpu(),
            "trellis2": self.trellis2.detach().cpu(),
            "tlut": self.tlut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }
        return info

    def forward(self, inp, **kwargs):
        """Compute inp @ W^T where W is decoded from the two trellises.

        Small batches (<= 8 rows) use fused decompress+GEMM ops; larger
        batches dequantize each chunk once and use dense matmuls.  Partial
        results of the input-split chunks are summed.
        """
        x = inp.view(-1, self.in_features)
        bs = x.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            if self.use_comb_kernel:
                # Single fused op handles both input chunks.
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"decompress_gemm_tcq_combt_{self.out_features}_{bs}_{k}_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
                )
                x = wrapper(self.trellis1, self.trellis2, x, self.tlut)
            else:
                # Unequal chunks: slice the input and run one fused op per
                # chunk, then sum the partial products.
                wrapper1 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_gemm_tcq_{m}_{bs}_{self.in_part[0]}_{self.tlut_bits}_{self.KV[0]}"
                )
                wrapper2 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_gemm_tcq_{m}_{bs}_{self.in_part[1]}_{self.tlut_bits}_{self.KV[1]}"
                )
                x1 = wrapper1(self.trellis1, x[:, :self.in_part[0]], self.tlut)
                x2 = wrapper2(self.trellis2, x[:, self.in_part[0]:], self.tlut)
                x = x1 + x2
        else:
            if self.use_comb_kernel:
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"decompress_tcq_combt_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
                )
                with torch.no_grad():
                    dq = wrapper(self.trellis1, self.trellis2, self.tlut, self.out_features, k)
                x = x.to(dq.dtype) @ dq.T
            else:
                wrapper1 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_tcq_{self.tlut_bits}_{self.KV[0]}"
                )
                wrapper2 = getattr(
                    torch.ops.ours_lib,
                    f"decompress_tcq_{self.tlut_bits}_{self.KV[1]}"
                )
                with torch.no_grad():
                    dq1 = wrapper1(self.trellis1, self.tlut, m, self.in_part[0])
                    dq2 = wrapper2(self.trellis2, self.tlut, m, self.in_part[1])
                x1 = x[:, :self.in_part[0]].to(dq1.dtype) @ dq1.T
                x2 = x[:, self.in_part[0]:].to(dq2.dtype) @ dq2.T
                x = x1 + x2
        return x.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a CombtLinearTCQ from a dict produced by _info()."""
        layer = CombtLinearTCQ(info["in_features"], info["out_features"], info["td_x"], info["td_y"], info["in_part"], info["L"], info["KV"], info["V"], info["tlut_bits"], info["bias"] is not None, info["dtype"])
        layer.trellis1.data.copy_(info["trellis1"])
        layer.trellis2.data.copy_(info["trellis2"])
        layer.tlut.data.copy_(info["tlut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    def get_weight(self):
        """Dequantize and return the full (out_features, in_features) weight."""
        wrapper = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_combt_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
        )
        dq = wrapper(self.trellis1, self.trellis2, self.tlut, self.out_features, self.in_features)
        return dq

    @staticmethod
    def merge_infos(info1, info2):
        """Merge two compatible layer infos into one stacked along the
        output dimension (out_features add; trellis tile rows concatenate).
        Both layers must share every other hyperparameter and have no bias.
        NOTE(review): concatenating trellis rows along dim 0 assumes tiles
        are laid out output-major — confirm against the packing code.
        """
        assert info1["in_features"] == info2["in_features"]
        assert info1["td_x"] == info2["td_x"]
        assert info1["td_y"] == info2["td_y"]
        assert info1["L"] == info2["L"]
        assert info1["KV"] == info2["KV"]
        assert info1["V"] == info2["V"]
        assert info1["tlut_bits"] == info2["tlut_bits"]
        if not torch.allclose(info1["tlut"], info2["tlut"], atol=1e-4):
            print("warning: tlut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        info = {}
        info["in_features"] = info1["in_features"]
        info["out_features"] = info1["out_features"] + info2["out_features"]
        info["td_x"] = info1["td_x"]
        info["td_y"] = info1["td_y"]
        info["L"] = info1["L"]
        info["KV"] = info1["KV"]
        info["V"] = info1["V"]
        info["tlut_bits"] = info1["tlut_bits"]
        info["bias"] = None
        info["dtype"] = info1["dtype"]
        info["trellis1"] = torch.cat([info1["trellis1"], info2["trellis1"]], dim=0)
        info["trellis2"] = torch.cat([info1["trellis2"], info2["trellis2"]], dim=0)
        info["tlut"] = info1["tlut"]
        info["in_part"] = info1["in_part"]


        return info
|
| 321 |
+
|
| 322 |
+
if __name__ == "__main__":
    # Smoke test: build a comb layer (two 2048-row halves at 3 and 4 bpw,
    # vq dim 2, 9-bit LUT) and run one forward pass.
    # NOTE(review): the layer is never moved to CUDA while the input is —
    # confirm this works with the registered cuda kernels.
    layer = CombLinearTCQ(4096, 4096, 16, 16, (2048, 2048), 16, (3, 4), 2, 9, False)
    print(layer._info())
    layer.forward(torch.randn(1, 4096).cuda())
|
lib/linear/incoherent_linear.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from lib.utils import (get_hadK, matmul_hadU_cuda, matmul_hadUt_cuda, matmul_hadUt, matmul_hadU, matmul_hadUt_head, matmul_hadU_head, matmul_hadU_head_cuda, matmul_hadUt_head_cuda)
|
| 5 |
+
from lib.linear.tcq_linear import QTIPLinearTCQ
|
| 6 |
+
from lib.linear.vq_linear import VQLinearPackTensorCore, VQLinearPackSIMT
|
| 7 |
+
from lib.linear.comb_linear import CombLinearTCQ, CombtLinearTCQ
|
| 8 |
+
from transformers.activations import ACT2FN
|
| 9 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
from model.llama import LlamaRotaryEmbedding, repeat_kv, apply_rotary_pos_emb, Cache
|
| 12 |
+
|
| 13 |
+
def make_linear(info, use_simt=False):
    """Instantiate the packed linear layer described by a quantization info dict.

    Dispatches on substrings of ``info["quant_info"]["quantizer_str"]``:
    "tcq" -> QTIPLinearTCQ, sq/vq/ldlq -> VQ pack kernels (SIMT or tensor-core
    depending on ``use_simt``), "tcomb"/"comb" -> the comb TCQ variants, and
    anything else falls back to a plain fp-precision ``nn.Linear`` (no bias).
    Note "tcomb" is tested before "comb" because the latter is a substring of
    the former.
    """
    quantizer_str = info["quant_info"]["quantizer_str"]
    vq_like = ("sq" in quantizer_str or "vq" in quantizer_str
               or "ldlq" in quantizer_str)
    if "tcq" in quantizer_str:
        return QTIPLinearTCQ.gen_layer_from_info(info["linear_info"])
    if vq_like:
        pack_cls = VQLinearPackSIMT if use_simt else VQLinearPackTensorCore
        return pack_cls.gen_layer_from_info(info["linear_info"])
    if "tcomb" in quantizer_str:
        return CombtLinearTCQ.gen_layer_from_info(info["linear_info"])
    if "comb" in quantizer_str:
        return CombLinearTCQ.gen_layer_from_info(info["linear_info"])
    return nn.Linear(info["in_features"], info["out_features"], bias=False)
|
| 27 |
+
|
| 28 |
+
class IncoherentSdpaAttention(nn.Module):
    """SDPA attention whose q/k/v/o projections run in an "incoherent" basis.

    Activations are sign-flipped (``SU_*`` buffers), Hadamard-rotated
    (``matmul_hadU_cuda``), and divided by ``self.scale`` before the
    (quantized) projection; projection outputs are rescaled by the
    per-output-channel ``Wscale_*`` buffers and ``self.scale``.  At most one
    of ``merge_qk`` / ``merge_kv`` / ``merge_qv`` / ``merge_qkv`` may be set,
    in which case the corresponding projections are fused into one packed
    linear (``qk_proj`` / ``kv_proj`` / ``qv_proj`` / ``qkv_proj``).
    """

    def __init__(self, config, merge_qk=False, merge_kv=False, merge_qv=False, merge_qkv=False, layer_idx=None, dtype=torch.float16):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        # Output width of the k/v projections (smaller than hidden_size under
        # grouped-query attention).
        self.kv_out = self.hidden_size * self.num_key_value_heads // self.num_heads
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # Projection layers are attached afterwards by gen_layer_from_info();
        # which of these stay None depends on the merge_* configuration.
        self.q_proj = None
        self.k_proj = None
        self.v_proj = None
        self.o_proj = None
        self.qk_proj = None
        self.qkv_proj = None
        self.kv_proj = None
        self.dtype=dtype
        self.layer_idx = layer_idx
        # Random-sign vectors applied to the inputs of the qkv and o paths
        # (initialized to +1; real values copied in by gen_layer_from_info).
        self.register_buffer("SU_qkv", torch.ones(config.hidden_size, dtype=self.dtype))
        self.register_buffer("SU_o", torch.ones(config.hidden_size, dtype=self.dtype))

        hidden_had, hidden_K = get_hadK(config.hidden_size)

        hidden_had_T = hidden_had.T.contiguous().cuda() if hidden_had is not None else None

        # Per-output-channel weight scales for the concatenated q/k/v outputs.
        # Layout is [q | k | v] except when merge_qv, where it is [q | v | k]
        # (see gen_layer_from_info and the matching slices in compute_qkv).
        self.register_buffer('Wscale_qkv', torch.ones(config.hidden_size + 2 * self.kv_out, dtype=self.dtype), persistent=False)
        self.register_buffer('Wscale_o', torch.ones(config.hidden_size, dtype=self.dtype), persistent=False)
        self.register_buffer('had_left_qkv_T', hidden_had_T, persistent=False)
        self.register_buffer('had_left_o_T', hidden_had_T, persistent=False)

        self.hidden_K = hidden_K
        # Fixed pre/post scaling applied around the quantized matmul
        # (presumably an fp16 dynamic-range guard — TODO confirm).
        self.scale = 64.0
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

        self.merge_qk = merge_qk
        self.merge_kv = merge_kv
        self.merge_qv = merge_qv
        self.merge_qkv = merge_qkv

        assert sum([self.merge_qk, self.merge_kv, self.merge_qv, self.merge_qkv]) <= 1, "Only one of merge_qk, merge_kv, merge_qv, merge_qkv can be True"

    def compute_qkv(self, input):
        """Sign-flip + Hadamard-rotate the input once, then run the configured
        (possibly merged) q/k/v projections.

        Returns ``(q, k, v)`` with ``q`` of width ``hidden_size`` and ``k``,
        ``v`` of width ``kv_out``, reshaped back to the input's leading dims.
        """

        n = len(self.SU_qkv)
        x = input.view(-1, n).half()

        x = matmul_hadU_cuda(x * self.SU_qkv, self.had_left_qkv_T, self.hidden_K) / self.scale
        if self.merge_qkv:
            qkv = self.qkv_proj(x.half()) * self.Wscale_qkv * self.scale
            q, k, v = qkv.split([self.hidden_size, self.kv_out, self.kv_out], dim=-1)
        elif self.merge_qk:
            qk = self.qk_proj(x.half()) * self.Wscale_qkv[:self.hidden_size + self.kv_out] * self.scale
            q, k = qk.split([self.hidden_size, self.kv_out], dim=-1)
            v = self.v_proj(x.half()) * self.Wscale_qkv[self.hidden_size + self.kv_out:] * self.scale
        elif self.merge_kv:
            kv = self.kv_proj(x.half()) * self.Wscale_qkv[self.hidden_size:] * self.scale
            k, v = kv.split([self.kv_out, self.kv_out], dim=-1)
            q = self.q_proj(x.half()) * self.Wscale_qkv[:self.hidden_size] * self.scale
        elif self.merge_qv:
            # Wscale_qkv is packed [q | v | k] in the merge_qv case, so the
            # leading slice covers q+v and the tail slice is k.
            qv = self.qv_proj(x.half()) * self.Wscale_qkv[:self.hidden_size + self.kv_out] * self.scale
            q, v = qv.split([self.hidden_size, self.kv_out], dim=-1)
            k = self.k_proj(x.half()) * self.Wscale_qkv[self.hidden_size + self.kv_out:] * self.scale
        else:
            q = self.q_proj(x.half()) * self.Wscale_qkv[:self.hidden_size] * self.scale
            k = self.k_proj(x.half()) * self.Wscale_qkv[self.hidden_size:self.hidden_size + self.kv_out] * self.scale
            v = self.v_proj(x.half()) * self.Wscale_qkv[self.hidden_size + self.kv_out:] * self.scale
        return q.view(*input.shape[:-1], n), k.view(*input.shape[:-1], self.kv_out), v.view(*input.shape[:-1], self.kv_out)

    def compute_o(self, input):
        """Apply the incoherent transform to the attention output and run o_proj."""
        n = len(self.SU_o)
        x = input.view(-1, n).half()
        x = matmul_hadU_cuda(x * self.SU_o, self.had_left_o_T, self.hidden_K) / self.scale
        x = self.o_proj(x.half()) * self.Wscale_o * self.scale
        return x.view(*input.shape[:-1], n)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor,
                  torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Standard Llama-style SDPA forward with the quantized projections.

        Returns ``(attn_output, None, past_key_value)``; attention weights
        are never materialized on this path.
        """
        if output_attentions:
            # NOTE(review): nn.Module defines no such forward; this fallback
            # would raise unless a different base class is intended — confirm.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # query_states = self.q_proj(hidden_states)
        # key_states = self.k_proj(hidden_states)
        # value_states = self.v_proj(hidden_states)
        query_states, key_states, value_states = self.compute_qkv(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads,
                                         self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                         self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin)


        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states.to(query_states.device),
            value_states.to(query_states.device),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        # attn_output = self.o_proj(attn_output)
        attn_output = self.compute_o(attn_output)
        return attn_output, None, past_key_value

    @staticmethod
    def gen_layer_from_info(config, layer_idx, info_q, info_k, info_v, info_o, merge_qk=False, merge_qv=False, merge_kv=False, merge_qkv=False, dummy=False, use_simt=False, use_simt_q=None, use_simt_k=None, use_simt_v=None, use_simt_o=None):
        """Build an IncoherentSdpaAttention from per-projection quant info dicts.

        ``info_*`` carry "SU", "Wscale", "quant_info" and "linear_info"
        entries (see IncoherentLinear.save_info).  When ``dummy`` is set the
        SU/Wscale buffers keep their all-ones initialization.  ``use_simt_*``
        default to the blanket ``use_simt`` flag per projection.
        NOTE(review): the ``dtype`` constructor argument is not forwarded
        here, so the fp16 default is always used — confirm intended.
        """
        attn = IncoherentSdpaAttention(config, merge_qk=merge_qk, merge_qv=merge_qv, merge_kv=merge_kv, merge_qkv=merge_qkv, layer_idx=layer_idx)
        if not dummy:
            # A single SU (from info_q) is shared by q/k/v; they consume the
            # same rotated input.
            attn.SU_qkv.data.copy_(info_q["SU"])
            attn.SU_o.data.copy_(info_o["SU"])
            if not merge_qv:
                attn.Wscale_qkv.data.copy_(torch.cat([info_q["Wscale"], info_k["Wscale"], info_v["Wscale"]], dim=-1))
            else:
                # merge_qv packs [q | v | k] so the fused qv output is one
                # contiguous slice (matches compute_qkv's merge_qv slicing).
                attn.Wscale_qkv.data.copy_(torch.cat([info_q["Wscale"], info_v["Wscale"], info_k["Wscale"]], dim=-1))
            attn.Wscale_o.data.copy_(info_o["Wscale"])

        use_simt_q = use_simt if use_simt_q is None else use_simt_q
        use_simt_k = use_simt if use_simt_k is None else use_simt_k
        use_simt_v = use_simt if use_simt_v is None else use_simt_v
        use_simt_o = use_simt if use_simt_o is None else use_simt_o

        # Decide which infos get fused into one packed linear and which are
        # instantiated individually.
        if merge_qk:
            to_merged, rest, target_proj, rest_proj = [info_q, info_k], [info_v], "qk_proj", ["v_proj"]
        elif merge_kv:
            to_merged, rest, target_proj, rest_proj = [info_k, info_v], [info_q], "kv_proj", ["q_proj"]
        elif merge_qv:
            to_merged, rest, target_proj, rest_proj = [info_q, info_v], [info_k], "qv_proj", ["k_proj"]
        elif merge_qkv:
            to_merged, rest, target_proj, rest_proj = [info_q, info_k, info_v], [], "qkv_proj", []
        else:
            to_merged, rest, target_proj, rest_proj = [], [info_q, info_k, info_v], "", ["q_proj", "k_proj", "v_proj"]

        if merge_qk or merge_kv or merge_qv or merge_qkv:
            if merge_kv: use_simt_merge = use_simt_k
            elif merge_qk or merge_qv or merge_qkv: use_simt_merge = use_simt_q
            else: raise ValueError

            # NOTE(review): if the quantizer_str matches none of these
            # families, merged_linear is unbound and the code below raises
            # NameError — confirm all expected quantizer strings are covered.
            if "tcq" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = QTIPLinearTCQ
            elif "sq" in to_merged[0]["quant_info"]["quantizer_str"] or "vq" in to_merged[0]["quant_info"]["quantizer_str"] or "ldlq" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = VQLinearPackTensorCore if not use_simt_merge else VQLinearPackSIMT
            elif "tcomb" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = CombtLinearTCQ
            elif "comb" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = CombLinearTCQ
            merged = to_merged[0]['linear_info']
            for info in to_merged[1:]:
                merged = merged_linear.merge_infos(merged, info['linear_info'])
            setattr(attn, target_proj, merged_linear.gen_layer_from_info(merged))
        for info, proj in zip(rest, rest_proj):
            if proj == "q_proj": cur_use_simt = use_simt_q
            elif proj == "k_proj": cur_use_simt = use_simt_k
            elif proj == "v_proj": cur_use_simt = use_simt_v
            else: raise ValueError
            setattr(attn, proj, make_linear(info, use_simt=cur_use_simt))
        attn.o_proj = make_linear(info_o, use_simt=use_simt_o)
        return attn

    @staticmethod
    def gen_layer_from_quantizer_str_and_key(config, layer_idx, quant_dir, quantizer_str_q, quantizer_str_k, quantizer_str_v, quantizer_str_o, key_q, key_k, key_v, key_o, merge_qk=False, merge_qv=False, merge_kv=False, merge_qkv=False, dummy=False, use_simt=False, use_simt_q=None, use_simt_k=None, use_simt_v=None, use_simt_o=None):
        """Load the four per-projection info dicts (from disk, or synthesized
        when ``dummy``) and delegate to gen_layer_from_info()."""
        if not dummy:
            info_q = torch.load(f"{quant_dir}/{quantizer_str_q}/{key_q}.pt")
            info_k = torch.load(f"{quant_dir}/{quantizer_str_k}/{key_k}.pt")
            info_v = torch.load(f"{quant_dir}/{quantizer_str_v}/{key_v}.pt")
            info_o = torch.load(f"{quant_dir}/{quantizer_str_o}/{key_o}.pt")
        else:
            from lib.utils.mem_op import get_dummy_quant_results
            from lib.config import MODEL_KEYS
            model_key = MODEL_KEYS[config._name_or_path]
            info_q = get_dummy_quant_results(model_key, f"self_attn.q_proj", quantizer_str_q)
            info_k = get_dummy_quant_results(model_key, f"self_attn.k_proj", quantizer_str_k)
            info_v = get_dummy_quant_results(model_key, f"self_attn.v_proj", quantizer_str_v)
            info_o = get_dummy_quant_results(model_key, f"self_attn.o_proj", quantizer_str_o)

        return IncoherentSdpaAttention.gen_layer_from_info(config, layer_idx, info_q, info_k, info_v, info_o, merge_qk=merge_qk, merge_qv=merge_qv, merge_kv=merge_kv, merge_qkv=merge_qkv, dummy=dummy, use_simt=use_simt, use_simt_q=use_simt_q, use_simt_k=use_simt_k, use_simt_v=use_simt_v, use_simt_o=use_simt_o)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class IncoherentMLP(nn.Module):
    """
    only support left only and unified SU for upgates.

    Gated MLP (act(gate) * up -> down) whose projections run in an incoherent
    basis: inputs are sign-flipped (SU_* buffers), Hadamard-rotated and
    divided by ``self.scale``; projection outputs are rescaled by the
    per-output-channel ``Wscale_*`` buffers and ``self.scale``.  ``up_proj``
    and ``gate_proj`` may be fused into a single ``ug_proj`` (``merge_ug``).
    """
    def __init__(self, hidden_size, intermediate_size, hidden_act, merge_ug=False, bias=False, dtype=torch.float16):
        super().__init__()
        assert bias is False, "bias is not supported"
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dtype = dtype

        # Attached later by gen_layer_from_info(); ug_proj is used when
        # merge_ug, otherwise up_proj/gate_proj.
        self.up_proj = None
        self.gate_proj = None
        self.ug_proj = None
        self.down_proj = None

        # Sign vectors: one shared by up/gate input, one for the down input.
        self.register_buffer("SU_ug", torch.ones(hidden_size, dtype=self.dtype))
        self.register_buffer("SU_dp", torch.ones(intermediate_size, dtype=self.dtype))

        hidden_had, hidden_K = get_hadK(hidden_size)
        inter_had, inter_K = get_hadK(intermediate_size)

        inter_had_T = inter_had.T.contiguous().cuda() if inter_had is not None else None
        hidden_had_T = hidden_had.T.contiguous().cuda() if hidden_had is not None else None

        # Per-output-channel weight scales, laid out [up | gate]
        # (matches the torch.cat order in gen_layer_from_info).
        self.register_buffer('Wscale_ug', torch.ones(intermediate_size * 2, dtype=self.dtype), persistent=False)
        self.register_buffer('Wscale_dp', torch.ones(hidden_size, dtype=self.dtype), persistent=False)
        self.register_buffer('had_left_ug_T', hidden_had_T, persistent=False)
        self.register_buffer('had_left_dp_T', inter_had_T, persistent=False)

        self.hidden_K = hidden_K
        self.inter_K = inter_K

        # Fixed pre/post scaling applied around the quantized matmuls.
        self.scale = 64.0

        self.act_fn = ACT2FN[hidden_act]
        self.merge_ug = merge_ug

    def forward(self, input):
        """Run up/gate then down, returning a tensor with the input's leading
        dims and ``hidden_size`` as the last dim, cast back to input dtype."""
        n = len(self.SU_ug)
        x = input.view(-1, n).half()
        x = self.compute_ug(x)
        x = self.compute_dp(x)
        return x.view(*input.shape[:-1], n).to(input.dtype)

    def compute_ug(self, x):
        """Incoherent transform + up/gate projections; returns act(gate)*up."""
        x = matmul_hadU_cuda(x * self.SU_ug, self.had_left_ug_T, self.hidden_K) / self.scale
        if self.merge_ug:
            x = self.ug_proj(x.half()) * self.Wscale_ug * self.scale
            # Fused output is [up | gate], same packing as Wscale_ug.
            x_up, x_gate = x.split(self.intermediate_size, dim=-1)
        else:
            x_up = self.up_proj(x.half()) * self.Wscale_ug[:self.intermediate_size] * self.scale
            x_gate = self.gate_proj(x.half()) * self.Wscale_ug[self.intermediate_size:] * self.scale
        x = self.act_fn(x_gate) * x_up
        return x

    def compute_dp(self, x):
        """Incoherent transform + down projection back to hidden_size."""
        x = matmul_hadU_cuda(x * self.SU_dp, self.had_left_dp_T, self.inter_K) / self.scale
        x = self.down_proj(x.half()) * self.Wscale_dp * self.scale
        return x

    @staticmethod
    def gen_layer_from_info(config, info_up, info_gate, info_down, merge_ug=False, dummy=False, use_simt=False, use_simt_u=None, use_simt_g=None, use_simt_d=None):
        """Build an IncoherentMLP from per-projection quant info dicts.

        The SU for the up/gate path is taken from ``info_up`` only (unified
        SU assumption noted in the class docstring).  When ``dummy`` the
        SU/Wscale buffers keep their all-ones initialization.
        """
        mlp = IncoherentMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            merge_ug=merge_ug
        )
        if not dummy:
            mlp.SU_ug.data.copy_(info_up["SU"])
            mlp.SU_dp.data.copy_(info_down["SU"])
            mlp.Wscale_ug.data.copy_(torch.cat([info_up["Wscale"], info_gate["Wscale"]], dim=-1))
            mlp.Wscale_dp.data.copy_(info_down["Wscale"])

        use_simt_u = use_simt if use_simt_u is None else use_simt_u
        use_simt_g = use_simt if use_simt_g is None else use_simt_g
        use_simt_d = use_simt if use_simt_d is None else use_simt_d

        if merge_ug:
            # Same quantizer-family dispatch as make_linear, but fusing
            # up+gate into one packed linear via merge_infos.
            if "tcq" in info_up["quant_info"]["quantizer_str"]:
                linear_info_ug = QTIPLinearTCQ.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                mlp.ug_proj = QTIPLinearTCQ.gen_layer_from_info(linear_info_ug)
            elif "vq" in info_up["quant_info"]["quantizer_str"] or "sq" in info_up["quant_info"]["quantizer_str"] or "ldlq" in info_up["quant_info"]["quantizer_str"]:
                if use_simt_u:
                    linear_info_ug = VQLinearPackSIMT.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                    mlp.ug_proj = VQLinearPackSIMT.gen_layer_from_info(linear_info_ug)
                else:
                    linear_info_ug = VQLinearPackTensorCore.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                    mlp.ug_proj = VQLinearPackTensorCore.gen_layer_from_info(linear_info_ug)
            elif "tcomb" in info_up["quant_info"]["quantizer_str"]:
                linear_info_ug = CombtLinearTCQ.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                mlp.ug_proj = CombtLinearTCQ.gen_layer_from_info(linear_info_ug)
            elif "comb" in info_up["quant_info"]["quantizer_str"]:
                linear_info_ug = CombLinearTCQ.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                mlp.ug_proj = CombLinearTCQ.gen_layer_from_info(linear_info_ug)
        else:
            mlp.up_proj = make_linear(info_up, use_simt=use_simt_u)
            mlp.gate_proj = make_linear(info_gate, use_simt=use_simt_g)
        mlp.down_proj = make_linear(info_down, use_simt=use_simt_d)
        return mlp

    @staticmethod
    def gen_layer_from_quantizer_str_and_key(config, quant_dir, quantizer_str_up, quantizer_str_gate, quantizer_str_down, key_up, key_gate, key_down, merge_ug=False, dummy=False, use_simt=False, use_simt_u=None, use_simt_g=None, use_simt_d=None):
        """Load the three per-projection info dicts (from disk, or synthesized
        when ``dummy``) and delegate to gen_layer_from_info()."""
        if not dummy:
            info_up = torch.load(f"{quant_dir}/{quantizer_str_up}/{key_up}.pt")
            info_gate = torch.load(f"{quant_dir}/{quantizer_str_gate}/{key_gate}.pt")
            info_down = torch.load(f"{quant_dir}/{quantizer_str_down}/{key_down}.pt")
        else:
            from lib.utils.mem_op import get_dummy_quant_results
            from lib.config import MODEL_KEYS
            model_key = MODEL_KEYS[config._name_or_path]
            info_up = get_dummy_quant_results(model_key, f"mlp.up_proj", quantizer_str_up)
            info_gate = get_dummy_quant_results(model_key, f"mlp.gate_proj", quantizer_str_gate)
            info_down = get_dummy_quant_results(model_key, f"mlp.down_proj", quantizer_str_down)
        return IncoherentMLP.gen_layer_from_info(config, info_up, info_gate, info_down, merge_ug, dummy=dummy, use_simt=use_simt, use_simt_u=use_simt_u, use_simt_g=use_simt_g, use_simt_d=use_simt_d)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
class IncoherentLinear(nn.Module):
    """A single linear layer wrapped in the incoherence transform.

    Forward path: ``x -> x*SU -> Hadamard(hadU) -> /scale -> linear ->
    *Wscale -> Hadamard(hadV) -> *SV*scale -> (+bias)``.  The left/right
    rotations can be individually disabled via ``rot_info`` ("all",
    "skip_l", "skip_r", "skip_lr") once apply_rot_info() is called.
    ``self.linear`` may be a plain nn.Linear (use_linear=True) or a packed
    quantized linear attached later by gen_layer_from_info().
    """
    def __init__(
        self,
        in_features,
        out_features,
        hadU,
        hadV,
        bias=False,
        dtype=torch.float16,
        use_linear=True,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype

        if use_linear:
            self.linear = nn.Linear(in_features, out_features, bias=False, dtype=dtype)
        else:
            self.linear = None

        if bias:
            # Placeholder values; real bias is copied in by gen_layer_from_info
            # or linear_to_incoherent.
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # Random-sign vectors for input (SU) and output (SV) channels.
        self.register_buffer("SU", torch.ones(in_features, dtype=self.dtype))
        self.register_buffer("SV", torch.ones(out_features, dtype=self.dtype))

        # Hadamard transform sizes; may differ from in/out features when a
        # "head" variant of the transform is used.
        self.hadU = hadU
        self.hadV = hadV
        had_left, K_left = get_hadK(hadU)
        had_right, K_right = get_hadK(hadV)
        if had_left is not None:
            had_left_T = had_left.T.contiguous().cuda()
        else:
            had_left_T = None
        if had_right is not None:
            had_right = had_right.cuda()
        # Per-output-channel (row) scale of the normalized weight.
        self.register_buffer('Wscale', torch.ones(out_features, dtype=self.dtype), persistent=True)
        self.register_buffer('had_right', had_right, persistent=False)
        self.register_buffer('had_left_T', had_left_T, persistent=False)
        self.K_left = K_left
        self.K_right = K_right
        # NOTE: 32.0 here, unlike the 64.0 used by the attention/MLP wrappers.
        self.scale = 32.0

        self.rot_info = "all"

        # Effective skip flags; only updated from rot_info by apply_rot_info().
        self.skip_l = False
        self.skip_r = False

    def apply_rot_info(self):
        """Translate the rot_info string into the skip_l/skip_r flags."""
        if self.rot_info == "all":
            self.skip_l = False
            self.skip_r = False
        elif self.rot_info == "skip_l":
            self.skip_l = True
            self.skip_r = False
        elif self.rot_info == "skip_r":
            self.skip_l = False
            self.skip_r = True
        elif self.rot_info == "skip_lr":
            self.skip_l = True
            self.skip_r = True
        else:
            raise ValueError(f"Invalid rot_info: {self.rot_info}")


    def save_info(self, path, quant_info=None):
        """Serialize everything needed to rebuild this layer to ``path``.

        Assumes ``self.linear`` implements ``_info()`` (the packed quantized
        linears do; a plain nn.Linear does not).
        """
        linear_info = self.linear._info()
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "hadU": self.hadU,
            "hadV": self.hadV,
            "dtype": self.dtype,
            "scale": self.scale,
            "Wscale": self.Wscale.detach().cpu(),
            "rot_info": self.rot_info,
            "linear_info": linear_info,
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
            "SU": self.SU.detach().cpu(),
            "SV": self.SV.detach().cpu(),
            "quant_info": quant_info,
        }
        torch.save(info, path)

    def forward(self, input):
        """Apply the (optionally skipped) left rotation, the inner linear,
        and the (optionally skipped) right rotation; see class docstring."""
        n, m = len(self.SU), len(self.SV)
        x = input.view(-1, n).half()#.to(torch.float32)
        if not self.skip_l:
            x = x * self.SU
            x = matmul_hadU_head_cuda(x, self.had_left_T, self.K_left, self.hadU) / self.scale
        else:
            # x = x * self.SU
            # Still divide by scale so the inner linear sees the same range.
            x = x / self.scale
        x = self.linear(x.half()) * self.Wscale#.float()
        if not self.skip_r:
            x = matmul_hadU_head_cuda(x, self.had_right, self.K_right, self.hadV)
            x = x.to(self.SV.device) * (self.SV * self.scale)
        else:
            # x = x.to(self.SV.device) * (self.SV * self.scale)
            x = x * self.scale

        x = x.view(*input.shape[:-1], m).to(input.dtype)
        if self.bias is not None:
            x = x + self.bias
        return x

    @staticmethod
    def gen_layer_from_info(info, merge_layers=False, dummy=False, use_simt=False):
        """Rebuild an IncoherentLinear from a dict produced by save_info().

        Falls back to in/out features for hadU/hadV when older infos lack
        those keys.  When ``dummy`` the buffers keep their initialization and
        no inner quantized linear is attached.
        NOTE(review): apply_rot_info() is only called when merge_layers is
        True, so otherwise skip_l/skip_r stay False regardless of the loaded
        rot_info — confirm intended.
        """
        layer = IncoherentLinear(
            in_features=info["in_features"],
            out_features=info["out_features"],
            hadU=info["hadU"] if "hadU" in info else info["in_features"],
            hadV=info["hadV"] if "hadV" in info else info["out_features"],
            bias=info["bias"] is not None,
            dtype=info["dtype"],
            use_linear=False,
        )
        if not dummy:
            if info["bias"] is not None:
                layer.bias.data.copy_(info["bias"])
            layer.SU.data.copy_(info["SU"])
            layer.SV.data.copy_(info["SV"])
            layer.Wscale.data.copy_(info["Wscale"])
            if info["quant_info"] is not None:
                # Same quantizer-family dispatch as make_linear.
                if "tcq" in info["quant_info"]["quantizer_str"]:
                    layer.linear = QTIPLinearTCQ.gen_layer_from_info(info["linear_info"])
                elif "sq" in info["quant_info"]["quantizer_str"] or "vq" in info["quant_info"]["quantizer_str"] or "ldlq" in info["quant_info"]["quantizer_str"]:
                    if use_simt:
                        layer.linear = VQLinearPackSIMT.gen_layer_from_info(info["linear_info"])
                    else:
                        layer.linear = VQLinearPackTensorCore.gen_layer_from_info(info["linear_info"])
                elif "tcomb" in info["quant_info"]["quantizer_str"]:
                    layer.linear = CombtLinearTCQ.gen_layer_from_info(info["linear_info"])
                elif "comb" in info["quant_info"]["quantizer_str"]:
                    layer.linear = CombLinearTCQ.gen_layer_from_info(info["linear_info"])
        # rot_info may live in quant_info (newer) or at top level (older).
        if "rot_info" in info["quant_info"]:
            layer.rot_info = info["quant_info"]["rot_info"]
        elif "rot_info" in info:
            layer.rot_info = info["rot_info"]
        else:
            layer.rot_info = "all"
        if merge_layers:
            layer.apply_rot_info()
        return layer

    @staticmethod
    def gen_layer_from_quantizer_str_and_key(config, quant_dir, quantizer_str, key, merge_layers=False, dummy=False, use_simt=False):
        """Load a saved info dict (from disk, or synthesized when ``dummy``)
        and delegate to gen_layer_from_info()."""
        if not dummy:
            info = torch.load(f"{quant_dir}/{quantizer_str}/{key}.pt")
        else:
            from lib.utils.mem_op import get_dummy_quant_results
            from lib.config import MODEL_KEYS
            model_key = MODEL_KEYS[config._name_or_path]
            # Keys look like "<layer_id>_<module path>"; strip the layer id.
            layer_id = key.split("_")[0]
            layer_key = key.replace(f"{layer_id}_", "")
            info = get_dummy_quant_results(model_key, f"{layer_key}", quantizer_str)
        return IncoherentLinear.gen_layer_from_info(info, merge_layers=merge_layers, dummy=dummy, use_simt=use_simt)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def calc_kurtosis(W):
    """Per-row excess kurtosis of W, computed in float64.

    Rows are assumed unit-normalized (||W[i]|| = 1 per the original note),
    so the fourth moment alone gives kurtosis; 3.0 is subtracted to report
    the excess relative to a Gaussian.
    """
    rows = W.to(torch.float64)
    fourth_moment = (rows ** 4).mean(dim=-1)
    return fourth_moment - 3.0
|
| 565 |
+
|
| 566 |
+
def calc_skewness(W):
    """Per-row skewness (third moment) of W, computed in float64.

    Rows are assumed unit-normalized (||W[i]|| = 1 per the original note),
    so the raw third moment is used directly.
    """
    rows = W.to(torch.float64)
    return (rows ** 3).mean(dim=-1)
|
| 570 |
+
|
| 571 |
+
def linear_to_incoherent(linear, hadU, hadV, SU=None, SV=None, lnorm=None, rot_info="all"):
    """Convert a dense nn.Linear into an IncoherentLinear.

    The weight is conjugated with random sign vectors and Hadamard
    transforms (W' = H_V diag(SV) W diag(SU) H_U, applied via the
    matmul_hadUt helpers), each row is normalized by its RMS (stored as
    Wscale), and the result is written into the new layer's inner linear.
    ``lnorm``, if given, is a per-input-channel scale (e.g. a fused
    LayerNorm/RMSNorm weight) folded into W first.  Prints and returns
    kurtosis/skewness statistics of the rotated, normalized weight.

    Returns:
        (inc_linear, kurt_stats) where kurt_stats is a dict of scalar
        tensors summarizing per-row kurtosis and skewness.
    """
    dtype_ = torch.float32
    dtype = linear.weight.data.dtype
    device = linear.weight.device
    inc_linear = IncoherentLinear(linear.in_features, linear.out_features, hadU, hadV, linear.bias is not None, dtype)
    if SU is None:
        # Random +/-1 sign vector over input channels.
        SU = ((torch.randn(linear.in_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device)
    if SV is None:
        # Random +/-1 sign vector over output channels.
        SV = ((torch.randn(linear.out_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device)
    if lnorm is not None:
        lnorm = lnorm.to(device).to(dtype_)

    if linear.bias is not None:
        inc_linear.bias.data.copy_(linear.bias)

    W = linear.weight.data.to(dtype_)
    # Fold the norm scale into the weight (float64 for the matmul accuracy).
    Wr = (W.to(torch.float64).to(device) @ torch.diag(lnorm).to(torch.float64)).to(dtype_).to(device) if lnorm is not None else W
    # Use the "head" Hadamard variant when transform sizes differ from the
    # actual feature sizes.
    if hadU != linear.in_features or hadV != linear.out_features:
        Wr = matmul_hadUt_head(matmul_hadUt_head(Wr.T.to(device) * SV, hadV).T * SU, hadU)
    else:
        Wr = matmul_hadUt(matmul_hadUt(Wr.T.to(device) * SV).T * SU)
    # Wscale = Wr.square().mean().sqrt()
    # Per-row RMS normalization (float64 for accuracy), kept as Wscale.
    Wscale = Wr.to(torch.float64).square().mean(-1).sqrt().view(-1, 1).to(dtype_)

    Wr = Wr / Wscale

    inc_linear.SU.data.copy_(SU.to(inc_linear.SU.dtype))
    # inc_linear.SV.data.copy_((SV * Wscale).to(inc_linear.SV.dtype))
    inc_linear.SV.data.copy_((SV).to(inc_linear.SV.dtype))
    inc_linear.Wscale.data.copy_(Wscale.view(-1))
    inc_linear.linear.weight.data.copy_(Wr.to(inc_linear.linear.weight.dtype))
    inc_linear = inc_linear.to(dtype).to(device)
    inc_linear.rot_info = rot_info
    inc_linear.apply_rot_info()


    # anal weight
    kurt = calc_kurtosis(inc_linear.linear.weight.data)
    skew = calc_skewness(inc_linear.linear.weight.data)
    # print(kurt.pow(2).mean(), kurt.mean(), kurt.std(), kurt.max(), kurt.min())
    # print pretty
    print(f"E[kurt^2]: {kurt.pow(2).mean():.4f}, E[kurt]: {kurt.mean():.4f}, std[kurt]: {kurt.std():.4f}, max[kurt]: {kurt.max():.4f}, min[kurt]: {kurt.min():.4f}")
    print(f"E[skew^2]: {skew.pow(2).mean():.4f}, E[skew]: {skew.mean():.4f}, std[skew]: {skew.std():.4f}, max[skew]: {skew.max():.4f}, min[skew]: {skew.min():.4f}")
    kurt_stats = {
        "kurt_pow2_mean": kurt.pow(2).mean(),
        "kurt_mean": kurt.mean(),
        "kurt_std": kurt.std(),
        "kurt_max": kurt.max(),
        "kurt_min": kurt.min(),
        "skew_pow2_mean": skew.pow(2).mean(),
        "skew_mean": skew.mean(),
        "skew_std": skew.std(),
        "skew_max": skew.max(),
        "skew_min": skew.min(),
    }
    return inc_linear, kurt_stats
|
| 627 |
+
|
| 628 |
+
if __name__ == "__main__":
    # Smoke test: round-trip a random fp16 linear through the incoherent
    # transform and report the relative reconstruction error (requires CUDA).
    linear = nn.Linear(4096, 4096, bias=True, dtype=torch.float16).cuda()
    # linear.weight.data = linear.weight.data * 5 + 4

    # Fix: linear_to_incoherent requires the Hadamard sizes (hadU, hadV) and
    # returns a (layer, stats) tuple; the previous call passed neither and
    # bound the tuple to a single name, making inc_linear(ran) a TypeError.
    inc_linear, kurt_stats = linear_to_incoherent(linear, 4096, 4096)

    ran = torch.randn(4096, 4096, dtype=torch.float16).cuda()
    orig = linear(ran)
    inc = inc_linear(ran)

    # Relative MSE between the original linear and its incoherent wrapper.
    print((orig - inc).pow(2).mean() / orig.pow(2).mean())
|
| 638 |
+
|
| 639 |
+
|
lib/linear/quantized_linear.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from lib.codebook import bitshift
|
| 8 |
+
from lib.utils import (clean, dtype_from_str, get_hadK, has_kernel,
|
| 9 |
+
matmul_hadU_cuda)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class QuantizedLinear(nn.Module):
    """Trellis-quantized linear layer.

    Weights are stored as bit-packed trellis codes (``trellis``, int16)
    together with sign vectors ``SU``/``SV`` and Hadamard factors
    (``had_left``/``had_right``).  The actual decode + matmul is delegated
    to ``bitshift.BitshiftLinear``, which is built lazily on the first
    forward pass (see ``no_ckpt_forward``).
    """

    def __init__(
        self,
        in_features,
        out_features,
        td_x,  # trellis tile height (rows per tile)
        td_y,  # trellis tile width (cols per tile)
        L,  # trellis window
        K,  # bpw
        V,  # vq dim
        tlut_bits,  # tunable LUT bits
        decode_mode,
        bias=False,
        dtype=torch.float16,
        mode='eval',
        use_prev_kernel=True,
        grad_ckpt=False,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.K = K
        self.V = V
        self.tlut_bits = tlut_bits
        self.decode_mode = decode_mode
        # Stored as a tensor buffer so it serializes with the state dict;
        # converted to a plain Python int on first forward.
        self.register_buffer('rcp', torch.tensor(0))
        # TP rank, not used unless rcp != 0
        self.register_buffer('tp_rank', torch.tensor(8))
        self.dtype = dtype
        # packed into int16: one row per (td_x, td_y) tile, K bits per weight.
        self.register_buffer(
            'trellis',
            torch.zeros((out_features // td_x) * (in_features // td_y),
                        math.ceil((td_x * td_y) * K / 16),
                        dtype=torch.int16))

        # Lookup-table decode modes carry a frozen, tunable LUT parameter.
        if decode_mode in ['lut', 'quantlut', 'quantlut_sym']:
            self.tlut = nn.Parameter(torch.zeros(2**tlut_bits,
                                                 V,
                                                 dtype=torch.float16),
                                     requires_grad=False)
        else:
            self.tlut = None

        if bias:
            # NOTE(review): bias is initialized to ones (not zeros) and is a
            # buffer, not a Parameter — presumably overwritten by a real bias
            # during conversion; confirm against the conversion code.
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # Random sign vectors for the incoherence transform (identity here;
        # real values are copied in by the quantizer).
        self.register_buffer("SU", torch.ones(in_features, dtype=self.dtype))
        self.register_buffer("SV", torch.ones(out_features,
                                              dtype=torch.float32))

        self.built_codebook_class = False
        self.built_graph = False

        # Hadamard factors for input/output dimensions (not persisted; they
        # are deterministic functions of the feature sizes).
        had_left, K_left = get_hadK(in_features)
        had_right, K_right = get_hadK(out_features)
        self.register_buffer('had_left', had_left, persistent=False)
        self.register_buffer('had_right', had_right, persistent=False)
        self.K_left = K_left
        self.K_right = K_right
        self.mode = mode
        self.use_prev_kernel = use_prev_kernel
        self.grad_ckpt = grad_ckpt
        # Whether a fused CUDA decode kernel exists for this configuration.
        self.has_kernel = has_kernel(decode_mode, L, K, V, tlut_bits, td_x,
                                     td_y)

    def forward(self, input):
        # Optionally trade compute for memory via gradient checkpointing.
        if self.grad_ckpt:
            return self.ckpt_forward(input)
        return self.no_ckpt_forward(input)

    def ckpt_forward(self, input):
        """Gradient-checkpointed variant of ``no_ckpt_forward``."""
        return torch.utils.checkpoint.checkpoint(self.no_ckpt_forward,
                                                 input,
                                                 use_reentrant=True)

    def no_ckpt_forward(self, input):
        """Decode the trellis weights and apply the linear transform.

        The first call performs one-time, mode-dependent setup that mutates
        module state (order matters):
          - 'eval': nothing extra.
          - 'train-recons': without a fused kernel, unpack the trellis once
            so reconstruction is cheap on later calls.
          - 'train-fixW': cache the dequantized weight matrix and drop the
            Hadamard factors to free memory.
        """
        if not self.built_codebook_class:
            self.codebook_class = bitshift.BitshiftLinear(
                self.td_x,
                self.td_y,
                self.L,
                self.K,
                self.V,
                self.tlut_bits,
                self.decode_mode,
                dtype=self.dtype,
                tlut=self.tlut,
                has_kernel=self.has_kernel)

            # Replace the 'rcp' buffer with a plain int attribute so later
            # uses don't pay for a tensor .item() call; the del is required
            # because nn.Module.__setattr__ refuses to shadow a buffer.
            rcp = self.rcp.item()
            del self.rcp
            self.rcp = rcp

            if self.mode == 'eval':
                pass
            elif self.mode == 'train-recons':
                if not self.has_kernel:
                    # Keep the packed form on CPU for later re-saving.
                    self.packed_trellis = self.trellis.cpu()
                    unpacked_trellis = self.codebook_class.cb.unpack_trellis(
                        self.trellis, self.td_x * self.td_y)
                    self.trellis = unpacked_trellis
                    clean()
            elif self.mode == 'train-fixW':
                # Cache hatW once; the Hadamard factors are no longer needed
                # afterwards, so free them (attributes are re-set to None so
                # the later codebook_class call signature still works).
                self.codebook_class.cache_hatW(self.trellis, self.had_left,
                                               self.had_right, self.K_left,
                                               self.K_right, len(self.SV),
                                               len(self.SU), self.rcp,
                                               self.tp_rank)
                self.trellis = self.trellis.cpu()
                del self.had_left, self.had_right, self.K_left, self.K_right
                clean()
                self.had_left = None
                self.had_right = None
                self.K_left = None
                self.K_right = None
            else:
                # NOTE(review): bare Exception with no message — an invalid
                # mode gives callers nothing to go on.
                raise Exception

            self.built_codebook_class = True

        # '+ 0' forces a fresh tensor (avoids aliasing the kernel output).
        result = self.codebook_class(input,
                                     self.trellis,
                                     self.SU,
                                     self.SV,
                                     self.had_left,
                                     self.had_right,
                                     self.K_left,
                                     self.K_right,
                                     self.rcp,
                                     self.tp_rank,
                                     mode=self.mode,
                                     use_prev_kernel=self.use_prev_kernel) + 0
        if self.bias is not None:
            return result + self.bias
        return result
lib/linear/rotation.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib.utils import matmul_hadU_cuda, matmul_hadUt_cuda, matmul_hadUt_head, matmul_hadU_head, get_hadK
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class RotateWeights(nn.Module):
    """Applies (inverse) Hadamard rotations to a weight matrix.

    Holds the Hadamard factors for the two dimensions of a weight matrix
    (``had_dim_U`` for one side, ``had_dim_V`` for the other) and exposes
    ``apply_weights`` to rotate a matrix with both of them.

    NOTE(review): ``SU``/``SV`` are stored but never used in this class —
    presumably consumed by callers; confirm before removing.
    """

    def __init__(self, had_dim_U, had_dim_V, SU=None, SV=None):
        super().__init__()
        self.had_dim_U = had_dim_U
        self.had_dim_V = had_dim_V
        self.SU = SU
        self.SV = SV

        # Hadamard factor + residual size for each dimension.
        self.had_left_U, self.K_left_U = get_hadK(had_dim_U)
        self.had_left_V, self.K_left_V = get_hadK(had_dim_V)

    def apply_weights(self, weights):
        """Rotate ``weights`` by the U-side transform (on weights.T), then the V-side."""
        return matmul_hadUt_head(matmul_hadUt_head(weights.T, self.had_left_U, self.K_left_U), self.had_left_V, self.K_left_V)
lib/linear/tcq_linear.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
class QTIPLinearTCQ(nn.Module):
    """Trellis-coded (TCQ) quantized linear layer.

    Weights are stored as packed int16 trellis code words plus a small
    lookup table (``tlut``); dequantization and the matmul are delegated
    to the pre-generated ``torch.ops.ours_lib`` kernels at forward time.
    """

    def __init__(
        self,
        in_features,
        out_features,
        td_x,
        td_y,
        L,  # trellis window
        KV,  # bpw
        V,  # vq dim
        tlut_bits,
        bias=False,
        dtype=torch.float16,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.dtype = dtype

        # Packed codes: one row per (td_x, td_y) tile, KV bits per group of
        # V weights, packed into int16 words.
        n_tiles = (out_features // td_x) * (in_features // td_y)
        words_per_tile = math.ceil((td_x * td_y) * KV / 16 / V)
        self.register_buffer(
            'trellis',
            torch.zeros(n_tiles, words_per_tile, dtype=torch.int16))

        # Dequantization lookup table: one V-dim centroid per code (frozen).
        self.tlut = nn.Parameter(
            torch.zeros(2**tlut_bits, V, dtype=torch.float16),
            requires_grad=False)

        if bias:
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

    def _info(self):
        """Snapshot configuration plus CPU copies of tensors into a dict."""
        bias_copy = None if self.bias is None else self.bias.detach().cpu()
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "td_x": self.td_x,
            "td_y": self.td_y,
            "L": self.L,
            "KV": self.KV,
            "V": self.V,
            "tlut_bits": self.tlut_bits,
            "dtype": self.dtype,
            "trellis": self.trellis.detach().cpu(),
            "tlut": self.tlut.detach().cpu().half(),
            "bias": bias_copy,
        }

    def forward(self, inp, **kwargs):
        x = inp.view(-1, self.in_features)  # .to(torch.float32)
        bs = x.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            # Small batch: fused decompress + GEMM kernel specialized on shape.
            kernel = getattr(
                torch.ops.ours_lib,
                f"decompress_gemm_tcq_{m}_{bs}_{k}_{self.tlut_bits}_{self.KV}")
            x = kernel(self.trellis, x, self.tlut)
        else:
            # Large batch: dequantize the whole matrix once, then a plain
            # matmul (done under no_grad — weights are frozen anyway).
            kernel = getattr(
                torch.ops.ours_lib,
                f"decompress_tcq_{self.tlut_bits}_{self.KV}")
            # dq = wrapper(self.trellis, self.tlut).to(x.dtype)
            # x = x @ dq.T
            with torch.no_grad():
                dq = kernel(self.trellis, self.tlut, m, k)  # .to(x.dtype)
                x = x.to(dq.dtype) @ dq.T  # .to(x.dtype)
        return x.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a QTIPLinearTCQ from a dict produced by ``_info``."""
        layer = QTIPLinearTCQ(
            info["in_features"], info["out_features"], info["td_x"],
            info["td_y"], info["L"], info["KV"], info["V"],
            info["tlut_bits"], info["bias"] is not None, info["dtype"])
        layer.trellis.data.copy_(info["trellis"])
        layer.tlut.data.copy_(info["tlut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    @staticmethod
    def merge_infos(info1, info2):
        """Stack two compatible layer infos along the output dimension.

        Every configuration field must match; biases are not supported.
        """
        for key in ("in_features", "td_x", "td_y", "L", "KV", "V",
                    "tlut_bits"):
            assert info1[key] == info2[key]
        if not torch.allclose(info1["tlut"], info2["tlut"], atol=1e-4):
            print("warning: tlut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        merged = dict(info1)
        merged["out_features"] = info1["out_features"] + info2["out_features"]
        merged["trellis"] = torch.cat([info1["trellis"], info2["trellis"]],
                                      dim=0)
        return merged
lib/linear/vq_linear.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class VQLinearPackTensorCore(nn.Module):
    """Vector/scalar-quantized linear layer backed by tensor-core kernels.

    Weights live in ``qweight`` as bit-packed int32 code words indexing the
    ``lut`` codebook; forward dispatches to pre-generated
    ``torch.ops.ours_lib`` kernels.
    """

    def __init__(self, in_features, out_features, lut_bits, vec_sz=2, bias=False, dtype=torch.half):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.lut_bits = lut_bits
        self.dtype = dtype
        self.vec_sz = vec_sz

        # Placeholder (random) packed codes; real codes are copied in later
        # (see gen_layer_from_info).
        packed_cols = lut_bits * in_features // 32 // vec_sz
        self.register_buffer(
            'qweight',
            torch.randint(0, 4, (out_features, packed_cols),
                          dtype=torch.int32, device='cuda'))

        # Codebook: one vec_sz-dim centroid per code.
        self.register_buffer(
            'lut',
            torch.randn((2 ** lut_bits, vec_sz), dtype=self.dtype,
                        device='cuda'))

        if bias:
            self.register_buffer(
                "bias",
                torch.randn((out_features,), dtype=self.dtype,
                            device='cuda'))
        else:
            self.bias = None

        # Kernel-name suffix: vector quantization, or scalar quantization
        # (with LUT duplication when codes are narrow enough).
        if self.vec_sz > 1:
            self.vq_type = f"vq{self.vec_sz}"
        elif lut_bits <= 4:
            self.vq_type = "sq_dup"
        else:
            self.vq_type = "sq"

    def _info(self):
        """Snapshot configuration plus CPU copies of tensors into a dict."""
        bias_copy = None if self.bias is None else self.bias.detach().cpu()
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "lut_bits": self.lut_bits,
            "dtype": self.dtype,
            "vec_sz": self.vec_sz,
            "qweight": self.qweight.detach().cpu(),
            "lut": self.lut.detach().cpu().half(),
            "bias": bias_copy,
        }

    def forward(self, inp, **kwargs):
        flat = inp.view(-1, self.in_features)
        bs = flat.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            # Small batch: fused decompress + GEMM kernel specialized on shape.
            kernel = getattr(
                torch.ops.ours_lib,
                f"decompress_gemm_{m}_{bs}_{k}_{self.lut_bits}_{self.vq_type}")
            flat = kernel(self.qweight, flat, self.lut)
        else:
            # Large batch: dequantize the whole matrix once, then a plain
            # matmul.
            kernel = getattr(
                torch.ops.ours_lib,
                f"decompress_{self.lut_bits}_{self.vq_type}")
            with torch.no_grad():
                dq = kernel(self.qweight, self.lut, m, k)
            flat = flat.to(dq.dtype) @ dq.T

        return flat.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a layer from a dict produced by ``_info``."""
        layer = VQLinearPackTensorCore(
            info["in_features"], info["out_features"], info["lut_bits"],
            info["vec_sz"], info["bias"] is not None, info["dtype"])
        layer.qweight.data.copy_(info["qweight"])
        layer.lut.data.copy_(info["lut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    @staticmethod
    def merge_infos(info1, info2):
        """Concatenate two compatible layer infos along output features."""
        assert info1["in_features"] == info2["in_features"]
        assert info1["lut_bits"] == info2["lut_bits"]
        assert info1["vec_sz"] == info2["vec_sz"]
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        if not torch.allclose(info1["lut"], info2["lut"], atol=1e-4):
            print("warning: lut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        merged = dict(info1)
        merged["out_features"] = info1["out_features"] + info2["out_features"]
        merged["qweight"] = torch.cat([info1["qweight"], info2["qweight"]],
                                      dim=0)
        return merged
| 98 |
+
|
| 99 |
+
class VQLinearPackSIMT(nn.Module):
    """Vector/scalar-quantized linear layer backed by SIMT kernels.

    Same storage scheme as ``VQLinearPackTensorCore`` (packed int32 codes +
    ``lut`` codebook) but dispatches to the SIMT variants of the
    ``torch.ops.ours_lib`` kernels, which select the kernel at runtime by
    ``vec_sz`` (scalar vs. vector quantization).
    """

    def __init__(self, in_features, out_features, lut_bits, vec_sz=1, bias=False, dtype=torch.half):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.lut_bits = lut_bits
        self.dtype = dtype
        self.vec_sz = vec_sz

        # Placeholder (random) packed codes; real codes are copied in later
        # (see gen_layer_from_info, which may also convert the layout).
        self.register_buffer(
            'qweight',
            torch.randint(0, 4, (out_features, lut_bits*in_features // 32 // vec_sz), dtype=torch.int32, device='cuda')
        )

        # Codebook: one vec_sz-dim centroid per code.
        self.register_buffer(
            'lut',
            torch.randn((2 ** lut_bits, vec_sz), dtype=self.dtype, device='cuda')
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.randn((out_features,), dtype=self.dtype, device='cuda')
            )
        else:
            self.bias = None

    def _info(self):
        """Snapshot configuration plus CPU copies of tensors into a dict."""
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "lut_bits": self.lut_bits,
            "dtype": self.dtype,
            "vec_sz": self.vec_sz,
            "qweight": self.qweight.detach().cpu(),
            "lut": self.lut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }
        return info

    def forward(self, inp, **kwargs):
        # NOTE: unlike the tensor-core layer, the SIMT GEMM kernels expect a
        # 3-D (batch, 1, k) input.
        x = inp.view(-1, 1, self.in_features)
        bs = x.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            # Small batch: fused decompress + GEMM.
            if self.vec_sz == 1:
                # Scalar quantization: one generic kernel, lut_bits passed
                # as a runtime argument.
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"sq_pack_gemm_simt"
                )
                x = wrapper(x, self.qweight, self.lut, self.lut_bits)
            else:
                # Vector quantization: kernel specialized on (bs, vec_sz, bits).
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"vq_pack_gemm_simt_{bs}_{self.vec_sz}_{self.lut_bits}"
                )
                x = wrapper(x, self.qweight, self.lut)
        else:
            # Large batch: dequantize the whole matrix once, then a plain
            # matmul.
            if self.vec_sz == 1:
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"sq_pack_dequant_simt"
                )
                with torch.no_grad():
                    dq = wrapper(self.qweight, self.lut, self.lut_bits, m, k)
            else:
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"vq_pack_dequant_simt_{self.vec_sz}_{self.lut_bits}"
                )
                with torch.no_grad():
                    dq = wrapper(self.qweight, self.lut, m, k)
            x = (x.to(dq.dtype) @ dq.T)
        return x.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a layer from an ``_info`` dict (possibly from the
        tensor-core layer), converting the code layout when needed."""
        layer = VQLinearPackSIMT(info["in_features"], info["out_features"], info["lut_bits"], info["vec_sz"], info["bias"] is not None, info["dtype"])
        if info["vec_sz"] <= 2:
            from lib.quantizer.quant_op import convert_tensor_core_to_simt
            # qweight is stored in tensor core format in default.
            # we should convert it to simt format.
            converted_qweight = convert_tensor_core_to_simt(info["qweight"], info["out_features"], info["in_features"], info["vec_sz"], info["lut_bits"], code_n=info["lut_bits"])
            layer.qweight.data.copy_(converted_qweight)
        else:
            layer.qweight.data.copy_(info["qweight"])
        layer.lut.data.copy_(info["lut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    @staticmethod
    def merge_infos(info1, info2):
        """Concatenate two compatible layer infos along output features."""
        assert info1["in_features"] == info2["in_features"]
        assert info1["lut_bits"] == info2["lut_bits"]
        assert info1["vec_sz"] == info2["vec_sz"]
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        # Merging assumes a shared codebook; only warn (dummy quantizers may
        # legitimately differ slightly).
        if not torch.allclose(info1["lut"], info2["lut"], atol=1e-4):
            print("warning: lut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        info = {}
        info["in_features"] = info1["in_features"]
        info["out_features"] = info1["out_features"] + info2["out_features"]
        info["lut_bits"] = info1["lut_bits"]
        info["vec_sz"] = info1["vec_sz"]
        info["bias"] = None
        info["dtype"] = info1["dtype"]
        info["qweight"] = torch.cat([info1["qweight"], info2["qweight"]], dim=0)
        info["lut"] = info1["lut"]
        return info
lib/quantizer/__pycache__/comb_quant.cpython-311.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
lib/quantizer/__pycache__/nuq_op.cpython-311.pyc
ADDED
|
Binary file (23.7 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.1.nbc
ADDED
|
Binary file (43.8 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.nbi
ADDED
|
Binary file (1.59 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.pack_32-242.py311.1.nbc
ADDED
|
Binary file (39.4 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.pack_32-242.py311.nbi
ADDED
|
Binary file (1.62 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.1.nbc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b3053fa289922289be50c3c5f21b9101f8426050f437286e34141275fa858e8
|
| 3 |
+
size 116854
|
lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.nbi
ADDED
|
Binary file (1.59 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.1.nbc
ADDED
|
Binary file (67.1 kB). View file
|
|
|
lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.nbi
ADDED
|
Binary file (1.71 kB). View file
|
|
|
lib/quantizer/__pycache__/quant_op.cpython-311.pyc
ADDED
|
Binary file (23.4 kB). View file
|
|
|
lib/quantizer/__pycache__/tcq_quant.cpython-311.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
lib/quantizer/__pycache__/vq_quant.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
lib/quantizer/__pycache__/vq_quant_ldlq.cpython-311.pyc
ADDED
|
Binary file (6.71 kB). View file
|
|
|
lib/quantizer/comb_quant.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from lib.quantizer.quant_op import load_hessian, load_group_hessian
|
| 3 |
+
from lib.quantizer.tcq_quant import qtip_quantize_mat, linear_to_incoherent_for_tcq, Args
|
| 4 |
+
from lib.linear import CombLinearTCQ, CombtLinearTCQ
|
| 5 |
+
from lib.codebook.bitshift import bitshift_codebook
|
| 6 |
+
from lib.utils import clean
|
| 7 |
+
import time
|
| 8 |
+
from lib.utils import block_LDL
|
| 9 |
+
from lib.algo.ldlq import LDLQ_combt
|
| 10 |
+
|
| 11 |
+
def pack_trellis(Qidxs, td_x, td_y, cb, m, n, KV, V):
    """Pack trellis code indices into the kernel's int16 storage layout.

    Tiles ``Qidxs`` into (td_x, td_y) blocks, packs them via the codebook's
    ``pack_trellis``, then reorders the packed bits at nibble granularity
    into the interleaved layout the CUDA kernel expects.

    NOTE(review): the nibble permutation below hard-codes 16x16 tiles and a
    2x2 super-tile layout; correctness depends on the exact op order — do
    not reorder.
    """
    Qidxs = Qidxs.cpu()
    # Reshape (m, n/V) index grid into per-tile rows and pack with the codebook.
    packed = cb.pack_trellis(
        Qidxs.reshape(m // td_x, td_x, n // td_y,
                      td_y // V).transpose(1, 2).reshape(
                          -1, td_x * td_y // V))

    # Split each byte into two nibbles (low, high), reversed within each
    # group of four.
    packed_8 = packed.view(torch.uint8).view(-1, 2)
    packed_4 = torch.cat([packed_8.unsqueeze(-1) & (2 ** 4 - 1), (packed_8.unsqueeze(-1) & (2 ** 8 - 2 ** 4)) >> 4], dim=-1).view(-1, 4).flip(
        (-1, ))

    # Interleave nibbles across the 2x2 arrangement of 16x16 tiles.
    packed_4 = packed_4.reshape(m // 16 // 2, 2, n // 16 // 2, 2, 16 * 16 // 8,
                                KV).permute(0, 2, 4, 3, 1, 5).flip(
                                    (-1, )).contiguous().flatten()
    # Re-fuse nibble pairs into bytes, then reinterpret as int16 words.
    packed_8 = torch.sum(packed_4.view(-1, 2) * torch.Tensor([[1, 2 ** 4]]).to(torch.uint8), dim=-1).to(torch.uint8).contiguous()
    packed = packed_8.view(torch.int16).reshape(packed.shape).cuda()
    return packed
| 28 |
+
|
| 29 |
+
def combt_quantize_mat(Wr, HRr, Wscale, cb1, cb2, td_x=16, td_y=16, KV=(4,5), V=2, use_hess=True):
    """LDLQ-quantize ``Wr`` with a column-split pair of trellis codebooks.

    The left half of the input dimension is quantized with ``cb1`` at
    ``KV[0]`` bits and the right half with ``cb2`` at ``KV[1]`` bits
    (see ``LDLQ_combt``).  Returns the two packed code tensors, the
    dequantized weights, and a quant-info dict with error statistics.
    """
    (m, n) = Wr.shape
    Wr = Wr.to(torch.float64)
    # NOTE(review): HRr_orig is never used below — dead clone?
    HRr_orig = HRr.clone()
    # Number of Hessian groups; rows of Wr are split evenly across groups.
    gs = HRr.shape[-1]
    LRrs = []
    diag = torch.arange(n, device=HRr.device)
    if not use_hess:
        # Identity Hessian: a single LDL factor shared by all groups.
        # NOTE(review): only one factor is appended here, but the loop below
        # indexes LRrs[i] for i in range(gs) — presumably gs == 1 whenever
        # use_hess is False; confirm with callers.
        eye = torch.eye(n, device=Wr.device, dtype=torch.float64)
        LRr, D = block_LDL(eye, td_y)
        LRr[diag, diag] = 0
        LRrs.append(LRr)
    else:
        for i in range(gs):
            LRr, D = block_LDL(HRr[:,:,i], td_y)
            # Zero the diagonal so LDLQ applies only off-diagonal feedback.
            LRr[diag, diag] = 0
            LRrs.append(LRr)

    args = Args(td_x, td_y, V)

    # Quantize each row-group against its own LDL factor.
    Qidxs_list = []
    hatWr_list = []
    for i in range(gs):
        cur_Wr = Wr[m // gs * i:m // gs * (i+1)]
        hatWr, Qidxs = LDLQ_combt(cur_Wr, LRrs[i], cb1.cuda(), cb2.cuda(), args, for_kernel=True)
        hatWr_list.append(hatWr)
        Qidxs_list.append(Qidxs)
        # Reset dynamo between groups: LDLQ_combt compiles shape-specialized code.
        torch._dynamo.reset()

    hatWr = torch.cat(hatWr_list, dim=0)
    Qidxs = torch.cat(Qidxs_list, dim=0)
    assert hatWr.shape == Wr.shape, f"hatWr.shape {hatWr.shape} != Wr.shape {Wr.shape}"

    # Pack the two halves of the input dimension with their own codebooks/bitrates.
    packed1 = pack_trellis(Qidxs[:, :n//2//V].contiguous(), td_x, td_y, cb1, m, n//2, KV[0], V)
    packed2 = pack_trellis(Qidxs[:, n//2//V:].contiguous(), td_x, td_y, cb2, m, n//2, KV[1], V)

    # Undo the per-row normalization for error reporting.
    Wr *= Wscale.reshape(-1, 1)
    hatWr *= Wscale.reshape(-1, 1)

    orig_err = (Wr - hatWr).pow(2).mean()
    err = orig_err / Wr.pow(2).mean()
    print(
        f'err {err.item()} orig_err {orig_err.item()}'
    )
    quant_info = {
        "quantizer": "combt_ldlq",
        "td_x": td_x,
        "td_y": td_y,
        "KV": KV,
        "V": V,
        "tlut_bits": cb1.tlut_bits,
        "use_hess": use_hess,
        "orig_err": orig_err.item(),
        "err": err.item(),
    }
    return packed1, packed2, hatWr, quant_info
| 85 |
+
|
| 86 |
+
def inc_linear_to_inc_combt_linear(inc_linear, HRr, cb1, cb2, td_x=16, td_y=16, in_part=(2048, 2048), KV=(3, 4), V=2, scale_override=0.9, use_hess=True):
    """Quantize an incoherent linear layer in place with an input-split
    (column-wise) pair of trellis codebooks.

    Replaces ``inc_linear.linear`` with a ``CombtLinearTCQ`` whose two
    trellis buffers hold the ``in_part``-split halves of the input
    dimension.  Returns the mutated layer and the quant-info dict.
    """
    # Pre-scale weights and fold the inverse into Wscale so the product is
    # unchanged; scale_override < 1 shrinks weights into the codebook range.
    Wr = (inc_linear.linear.weight.data * scale_override).to(HRr.dtype)
    Wscale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(Wscale)
    assert in_part[0] + in_part[1] == Wr.shape[1], "in_part is not correct"
    assert torch.allclose(cb1.tlut, cb2.tlut), "cb1 and cb2 must have the same tlut"

    packed1, packed2, hatWr, quant_info = combt_quantize_mat(Wr, HRr, Wscale, cb1, cb2, td_x=td_x, td_y=td_y, KV=KV, V=V, use_hess=use_hess)
    # Drop dynamo caches from the quantization pass.
    torch._dynamo.reset()
    out_features, in_features = Wr.shape
    comb_linear = CombtLinearTCQ(
        in_features,
        out_features,
        td_x=td_x,
        td_y=td_y,
        in_part=in_part,
        L=16,
        KV=KV,
        V=V,
        tlut_bits=cb1.tlut_bits,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )

    comb_linear.trellis1.data.copy_(packed1)
    comb_linear.trellis2.data.copy_(packed2)
    # tluts are asserted identical above, so cb1's suffices.
    comb_linear.tlut.data.copy_(cb1.tlut)
    inc_linear.linear = comb_linear
    return inc_linear, quant_info
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def inc_linear_to_inc_comb_linear(inc_linear, HRr, cb1, cb2, td_x=16, td_y=16, out_part=(2048, 2048), KV=(3, 4), V=2, scale_override=0.9, use_hess=True):
    """Quantize an incoherent linear layer in place with an output-split
    (row-wise) pair of trellis codebooks.

    The first ``out_part[0]`` output rows are quantized with ``cb1`` at
    ``KV[0]`` bits and the remainder with ``cb2`` at ``KV[1]`` bits.
    Replaces ``inc_linear.linear`` with a ``CombLinearTCQ`` and returns the
    mutated layer plus a quant-info dict.
    """
    # Pre-scale weights and fold the inverse into Wscale (product unchanged).
    Wr = (inc_linear.linear.weight.data * scale_override).to(HRr.dtype)
    Wscale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(Wscale)
    assert out_part[0] + out_part[1] == Wr.shape[0], "out_part is not correct"
    assert len(HRr.shape) == 3 and HRr.shape[0] == HRr.shape[1] and HRr.shape[-1] == 1, f"support only none-grouped hessian but shape: {HRr.shape}"

    # Each row block shares the same (ungrouped) Hessian; reset dynamo after
    # each pass since qtip_quantize_mat compiles shape-specialized code.
    packed1, hatWr1, quant_info1 = qtip_quantize_mat(Wr[:out_part[0]], HRr, Wscale[:out_part[0]], cb1, td_x=td_x, td_y=td_y, KV=KV[0], V=V, use_hess=use_hess)
    torch._dynamo.reset()
    packed2, hatWr2, quant_info2 = qtip_quantize_mat(Wr[out_part[0]:], HRr, Wscale[out_part[0]:], cb2, td_x=td_x, td_y=td_y, KV=KV[1], V=V, use_hess=use_hess)
    torch._dynamo.reset()
    out_features, in_features = Wr.shape
    comb_linear = CombLinearTCQ(
        in_features,
        out_features,
        td_x=td_x,
        td_y=td_y,
        out_part=out_part,
        L=16,
        KV=KV,
        V=V,
        tlut_bits=cb1.tlut_bits,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )

    comb_linear.trellis1.data.copy_(packed1)
    comb_linear.trellis2.data.copy_(packed2)
    # The caller guarantees cb1.tlut == cb2.tlut, so cb1's suffices.
    comb_linear.tlut.data.copy_(cb1.tlut)

    # Combined relative error across both row blocks.
    hatWr = torch.cat([hatWr1, hatWr2], dim=0).to(HRr.dtype)
    orig_err = (Wr - hatWr).pow(2).mean()
    err = orig_err / Wr.pow(2).mean()

    quant_info = {
        "quantizer": "comb_ldlq",
        "td_x": td_x,
        "td_y": td_y,
        "KV": KV,
        "V": V,
        "use_hess": use_hess,
        "orig_err": orig_err.item(),
        "err": err.item(),
        "quant_info1": quant_info1,
        "quant_info2": quant_info2,
    }

    inc_linear.linear = comb_linear
    return inc_linear, quant_info
| 166 |
+
|
| 167 |
+
def linear_to_comb_linear(target_layer, hess_path, cb1, cb2, scale_override=0.9, out_part=(2048, 2048), KV=None, V=2, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    """Quantize a dense linear layer into an incoherent Comb-TCQ linear layer.

    Loads (or synthesizes) the Hessian, makes the layer incoherent, then
    quantizes the two output partitions with codebooks ``cb1``/``cb2``.

    Args:
        target_layer: the nn.Linear-like layer to quantize (must expose ``.weight``).
        hess_path: path to a saved Hessian; ``None`` falls back to identity.
        cb1, cb2: codebooks for the two output partitions; must share a tlut.
        scale_override: incoherence scaling applied before quantization.
        out_part: sizes of the two output-row partitions.
        KV: per-partition bit budgets; defaults to ``[3, 4]`` when ``None``.
        V: vector size of the trellis codes.
        use_hess: whether quantization is Hessian-weighted.
        SU, SV, lnorm, hadU, hadV, rot_info, left_only: rotation/incoherence options,
            forwarded to ``linear_to_incoherent_for_tcq``.
        ghess_key: when non-empty, load a grouped Hessian under this layer key.

    Returns:
        (layer, quant_info): the fp16 quantized layer and a metadata dict.
    """
    # Avoid a shared mutable default argument; None stands in for the
    # historical default of [3, 4].
    if KV is None:
        KV = [3, 4]
    assert torch.allclose(cb1.tlut, cb2.tlut), "cb1 and cb2 must have the same tlut"
    t0 = time.time()
    out_features, in_features = target_layer.weight.shape
    if ghess_key == "":
        # No Hessian file -> identity proxy Hessian (shape (in, in, 1)).
        HR = load_hessian(hess_path).cuda() if hess_path is not None else torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)
    else:
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()
    layer, HRr = linear_to_incoherent_for_tcq(target_layer, cb1, HR, scale_override, SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV, rot_info=rot_info, left_only=left_only)
    HRr = HRr.cuda()
    layer = layer.cuda()
    # scale_override was already applied by the incoherence step, so the
    # downstream quantizer receives 1.0.
    layer, quant_info = inc_linear_to_inc_comb_linear(layer, HRr, cb1, cb2, scale_override=1.0, td_x=16, td_y=16, out_part=out_part, KV=KV, V=V, use_hess=use_hess)
    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    quant_info["time"] = time.time() - t0

    return layer.to(torch.float16), quant_info
|
| 184 |
+
|
| 185 |
+
def linear_to_combt_linear(target_layer, hess_path, cb1, cb2, scale_override=0.9, in_part=(2048, 2048), KV=None, V=2, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    """Quantize a dense linear layer into an incoherent CombT-TCQ linear layer.

    Transposed-partition variant of ``linear_to_comb_linear``: the layer is
    split along the *input* dimension (``in_part``) rather than the output.

    Args:
        target_layer: the nn.Linear-like layer to quantize (must expose ``.weight``).
        hess_path: path to a saved Hessian; ``None`` falls back to identity.
        cb1, cb2: codebooks for the two input partitions; must share a tlut.
        scale_override: incoherence scaling applied before quantization.
        in_part: sizes of the two input-column partitions.
        KV: per-partition bit budgets; defaults to ``[3, 4]`` when ``None``.
        V: vector size of the trellis codes.
        use_hess: whether quantization is Hessian-weighted.
        SU, SV, lnorm, hadU, hadV, rot_info, left_only: rotation/incoherence options,
            forwarded to ``linear_to_incoherent_for_tcq``.
        ghess_key: when non-empty, load a grouped Hessian under this layer key.

    Returns:
        (layer, quant_info): the fp16 quantized layer and a metadata dict.
    """
    # Avoid a shared mutable default argument; None stands in for the
    # historical default of [3, 4].
    if KV is None:
        KV = [3, 4]
    assert torch.allclose(cb1.tlut, cb2.tlut), "cb1 and cb2 must have the same tlut"
    t0 = time.time()
    out_features, in_features = target_layer.weight.shape
    if ghess_key == "":
        # No Hessian file -> identity proxy Hessian (shape (in, in, 1)).
        HR = load_hessian(hess_path).cuda() if hess_path is not None else torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)
    else:
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()
    layer, HRr = linear_to_incoherent_for_tcq(target_layer, cb1, HR, scale_override, SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV, rot_info=rot_info, left_only=left_only)
    HRr = HRr.cuda()
    layer = layer.cuda()
    # scale_override was already applied by the incoherence step, so the
    # downstream quantizer receives 1.0.
    layer, quant_info = inc_linear_to_inc_combt_linear(layer, HRr, cb1, cb2, scale_override=1.0, td_x=16, td_y=16, in_part=in_part, KV=KV, V=V, use_hess=use_hess)
    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    quant_info["time"] = time.time() - t0

    return layer.to(torch.float16), quant_info
|
lib/quantizer/nuq_op.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import time
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
|
| 8 |
+
def get_progress_bar(total: int, desc: str):
    """Create a tqdm progress bar in the project's compact bar format."""
    bar_fmt = '{l_bar}{bar:10}{r_bar}{bar:-10b}'
    return tqdm(total=total, desc=desc, bar_format=bar_fmt)
|
| 14 |
+
|
| 15 |
+
@torch.no_grad()
def objective_function(
    W: torch.Tensor,
    H: torch.Tensor,
    P: torch.Tensor,
    C: torch.Tensor,
) -> torch.Tensor:
    """Hessian-weighted quantization error of assignments P and centroids C.

    Args:
        W: weights, (row_count * group_count, group_size)
        H: Hessian blocks, (blk_num, group_size, group_size)
        P: one-hot assignments, (row_count * group_count, group_size//vec_sz, n_cluster)
        C: centroids, (n_cluster, vec_sz)

    Returns:
        Scalar objective value (mean of per-row quadratic forms).
    """
    gpu = torch.device("cuda")
    P, C = P.to(gpu), C.to(gpu)

    # Decode quantized weights: one-hot rows of P select rows of C.
    recon = torch.einsum('ijc,ck->ijk', P, C)  # (rows, group_size//vec_sz, vec_sz)
    recon = recon.reshape(recon.shape[0], -1)  # (rows, group_size)
    residual = recon - W

    blk_num = H.shape[0]
    rows_per_blk = W.shape[0] // blk_num
    residual = residual.reshape(blk_num, rows_per_blk, residual.shape[-1])

    # Per-row quadratic form r^T H r, accumulated over Hessian blocks.
    per_row = torch.einsum('nij,njk,nik->i', residual, H, residual)
    return per_row.mean()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@torch.no_grad()
def parallel_objective_function_sub(
    W: torch.Tensor,  # Shape: (b, g_cd)
    quadratic: torch.Tensor,  # Shape: (g_cd, g_cd)
    linear: torch.Tensor,  # Shape: (b, g_cd)
    W_hat_options: torch.Tensor,  # Shape: (b, g_cd, num_options)
) -> torch.Tensor:
    """Score every candidate reconstruction against the current weights.

    For each option p the score is d_p^T Q d_p + l^T d_p, where
    d_p = W_hat_options[..., p] - W; smaller is better.

    Returns:
        (b, num_options) tensor of objective values.
    """
    gpu = torch.device("cuda")
    W_hat_options = W_hat_options.to(gpu)
    n_opts = W_hat_options.shape[-1]

    # Residual of every option against the current weights.
    residual = W_hat_options - W.unsqueeze(2).expand(-1, -1, n_opts)

    quad_term = torch.einsum('jk,ijp,ikp->ip', quadratic, residual, residual)
    lin_term = torch.einsum('ij,ijp->ip', linear, residual)
    return quad_term + lin_term
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def update_batch_P(
    W: torch.Tensor,  # Shape: (b, group_size)
    H: torch.Tensor,  # Shape: (blk_num, group_size, group_size); only blk_num == 1 supported
    P: torch.Tensor,  # Shape: (b, group_size // vec_sz, n_cluster)
    C: torch.Tensor,  # Shape: (n_cluster, vec_sz)
    iteration: int,
    g_cd: int,  # Number of vec_sz-wide weight groups updated jointly per step
    cd_cycles: int,  # Number of full coordinate-descent sweeps over the row
    verbose: bool = False,
):
    """Coordinate-descent update of the one-hot assignment matrix P.

    For each group of g_cd adjacent vec_sz-wide slots, enumerates all
    n_cluster ** g_cd joint assignments, scores them with the Hessian-weighted
    quadratic objective (holding the rest of the row fixed), and keeps the
    argmin. Returns a freshly built one-hot P of shape (b, d, n_cluster).
    """
    device = torch.device("cuda")
    C = C.to(device)
    C_ = C.unsqueeze(0).expand(P.shape[0], -1, -1)
    assignments_prev = P.argmax(dim=-1).to(device)  # Shape: (b, group_size // vec_sz)
    b, d = assignments_prev.shape
    n_cluster, vec_sz = C_.size(1), C_.size(2)
    # Only a single (non-grouped) Hessian block is supported here.
    assert H.shape[0] == 1
    H_ = H[0]

    assignments = assignments_prev.clone()
    update_size = cd_cycles * d

    for update_start_idx in range(0, update_size, g_cd):
        # Wrap around the row so cd_cycles sweeps are performed back-to-back.
        start_idx = update_start_idx % d
        end_idx = min(start_idx + g_cd, d)
        indices = torch.arange(start_idx * vec_sz, end_idx * vec_sz, device=device)
        indices_assignments = torch.arange(start_idx, end_idx, device=device)
        # Generate all possible assignments for the group
        num_options = n_cluster ** g_cd
        if num_options > 1e6:
            # Exhaustive enumeration would be too large; leave this group as-is.
            print(f"Skipping group starting at index {start_idx} due to large number of assignments ({num_options}).")
            continue

        # Create all possible assignments for the group
        # NOTE(review): the option table always uses repeat=g_cd even when the
        # tail group is narrower (end_idx - start_idx < g_cd) — this assumes
        # g_cd divides d; confirm at call sites.
        from itertools import product
        assignments_list = list(product(range(n_cluster), repeat=g_cd))
        assignments_array = torch.tensor(assignments_list, device=device).T  # Shape: (g_cd, num_options)
        assignments_array = assignments_array.unsqueeze(0).expand(b, -1, -1)  # Shape: (b, g_cd, num_options)

        # Creating options for g_cd weights
        C_expanded = C_.unsqueeze(1).expand(-1, g_cd, -1, -1)  # Shape: (b, g_cd, n_cluster, vec_sz)
        W_g_hat_options = torch.gather(C_expanded, dim=2, index=assignments_array.unsqueeze(-1).expand(-1, -1, -1, vec_sz))  # Shape: (b, g_cd, num_options, vec_sz)

        # Gathering original quantized weights and compute linear & quadratic terms
        # Expand C and gather original weights
        C_expanded_org = C_.unsqueeze(1).expand(-1, d, -1, -1)  # Shape: (b, d, n_cluster, vec_sz)
        W_hat_org = torch.gather(C_expanded_org, dim=2, index=assignments.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, vec_sz)).squeeze(2)  # Shape: (b, d, vec_sz)

        # Compute deltas
        delta_w_org = W_hat_org.view(b, -1) - W  # Shape: (b, group_size)

        # Get indices and slices: every column outside the active group.
        notg_indices = torch.cat([
            torch.arange(0, start_idx * vec_sz, device=device),
            torch.arange(end_idx * vec_sz, d * vec_sz, device=device)
        ])

        # Cross term of the Hessian between the active group and the rest.
        H_g_notg = H_[indices, :][:, notg_indices]  # Shape: (g_cd * vec_sz, group_size - g_cd * vec_sz)

        delta_w_org_notg = delta_w_org[:, notg_indices].to(device)  # Shape: (b, group_size - g_cd * vec_sz)

        # Compute quadratic and linear terms of the objective restricted
        # to the active group (the rest of the row is held fixed).
        quadratic = H_[indices, :][:, indices]  # Shape: (g_cd * vec_sz, g_cd * vec_sz)
        linear = 2 * torch.einsum('gd,id->ig', H_g_notg, delta_w_org_notg)  # Shape: (b, g_cd * vec_sz)
        W_g = W[:, indices]  # Shape: (b, g_cd * vec_sz)

        W_g_hat_options = W_g_hat_options.permute(0, 1, 3, 2).view(b, g_cd * vec_sz, num_options)  # Shape: (b, g_cd * vec_sz, num_options)
        # Objective function computation
        cur_obj_value = parallel_objective_function_sub(W_g, quadratic, linear, W_g_hat_options)  # Shape: (b, num_options)

        # Update assignments with the per-row best option.
        min_obj, argmin_obj = cur_obj_value.min(dim=1, keepdim=True)  # min value itself is unused
        expanded_argmin_obj = argmin_obj.unsqueeze(1).expand(-1, g_cd, -1).to(device)
        assignments[:, indices_assignments] = assignments_array.gather(dim=2, index=expanded_argmin_obj).squeeze(-1)  # Shape: (row_count * group_count, g_cd)

    num_changed = (assignments_prev != assignments).sum().item()
    total_assignments = assignments_prev.numel()
    percentage_changed = num_changed / total_assignments * 100
    if verbose:
        logging.info(f"Percentage of assignments changed: {percentage_changed:.2f}%%")

    # Convert assignments to one-hot encoding to create new P
    P = torch.zeros((b, d, n_cluster), dtype=torch.float32, device=assignments.device)
    P.scatter_(2, assignments.long().unsqueeze(-1), 1.0)

    return P
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def update_P(
    W: torch.Tensor,  # Shape: (row_count * group_count, group_size)
    H: torch.Tensor,  # Shape: (blk_num, group_size, group_size)
    P: torch.Tensor,  # Shape: (row_count * group_count, group_size//vec_sz, n_cluster)
    C: torch.Tensor,  # Shape: (n_cluster, vec_sz)
    iteration: int,
    g_cd: int = 1,
    cd_cycles: int = 4,
):
    """Update the assignment matrix P in GPU-sized batches of output rows."""
    n_cluster = C.shape[0]
    batch_output_size = 4096  # * 32 // max(32, n_cluster)
    gpu = torch.device("cuda")
    pieces = []

    n_batches = (W.size(0) - 1) // batch_output_size + 1
    pb = get_progress_bar(n_batches, f"Updating P (cd_cycles={cd_cycles})")
    for lo in range(0, W.size(0), batch_output_size):
        torch.cuda.reset_peak_memory_stats()  # Reset memory stats at start of iteration

        hi = lo + batch_output_size
        chunk_W = W[lo:hi].to(gpu)
        chunk_P = P[lo:hi].to(gpu)
        chunk_C = C.to(gpu)

        verbose = False  # (out_idx == 0)

        chunk_result = update_batch_P(
            chunk_W, H, chunk_P, chunk_C, iteration,
            g_cd=g_cd, cd_cycles=cd_cycles, verbose=verbose,
        ).cpu()
        pieces.append(chunk_result)
        pb.update(1)
    pb.close()

    # Stitch the per-batch results back into a single assignment tensor.
    return torch.cat(pieces, dim=0)
|
| 204 |
+
|
| 205 |
+
def project_to_pd(H, eps=1e-2):
    """Project a square matrix onto the symmetric positive-definite cone.

    Symmetrizes H, clamps its eigenvalues to at least ``eps``, and rebuilds
    the matrix; the result is re-symmetrized and cast back to H's dtype.
    """
    sym = (H + H.T) / 2
    evals, evecs = torch.linalg.eigh(sym)
    evals = torch.clamp(evals, min=eps)
    rebuilt = evecs @ torch.diag(evals) @ evecs.T
    # eigh round-trips can leave tiny asymmetries; symmetrize again.
    rebuilt = (rebuilt + rebuilt.T) / 2
    return rebuilt.to(H.dtype)
|
| 213 |
+
|
| 214 |
+
import torch
|
| 215 |
+
def kron_with_identity_vec(P: torch.Tensor, vec_sz: int) -> torch.Tensor:
    """Batched Kronecker product of P with the vec_sz identity.

    Expands P of shape (B, d, c) to (B, d * vec_sz, c * vec_sz) so that each
    scalar P[b, i, j] becomes the block P[b, i, j] * I_{vec_sz}.
    """
    B, d, c = P.shape
    eye = torch.eye(vec_sz, vec_sz).to(P.device)
    # Broadcast the identity against every (b, i, j) entry.
    expanded = P[:, :, :, None, None] * eye  # (B, d, c, vec_sz, vec_sz)
    blocky = expanded.permute(0, 1, 3, 2, 4)  # (B, d, vec_sz, c, vec_sz)
    return blocky.reshape(B, d * vec_sz, c * vec_sz)
|
| 225 |
+
|
| 226 |
+
def update_C(
    W: torch.Tensor,  # (row, gs)
    H: torch.Tensor,  # (1, gs, gs); only the first (and only) block is used
    P: torch.Tensor,  # (row, gs//vec_sz, n_cluster)
    C: torch.Tensor,  # (n_cluster, vec_sz)
    batch_size: int = 256
):
    """Closed-form Hessian-weighted least-squares update of the centroids C.

    With H = L L^T, minimizing sum_r ||L^T (W_r - X_r c)||^2 over the flattened
    centroid vector c (where X_r = P_r kron I_vec_sz) gives the normal
    equations A c = b; A and b are accumulated over row batches to bound
    memory, then solved once.
    """
    device = W.device
    dtype = W.dtype

    # Cholesky factor of the Hessian: whitening by L^T makes the plain
    # Euclidean residual equal to the H-weighted one.
    L = torch.linalg.cholesky(H[0])  # (gs, gs)
    LT = L.transpose(0, 1)  # (gs, gs)

    row, gs = W.shape
    n_cluster, vec_sz = C.shape

    # Normal-equation accumulators for the flattened centroid vector.
    A = torch.zeros(n_cluster * vec_sz, n_cluster * vec_sz, device=device, dtype=dtype)
    b = torch.zeros(n_cluster * vec_sz, device=device, dtype=dtype)

    for start in range(0, row, batch_size):
        end = min(start + batch_size, row)

        # (B, gs // vec_sz, n_cluster)
        P_chunk = P[start:end].to(device)
        # (B, gs)
        W_chunk = W[start:end].to(device)
        B = P_chunk.shape[0]  # batch size; not otherwise used below

        # kronecker product with identity: maps centroid vector to weights.
        P_chunk_expanded = kron_with_identity_vec(P_chunk, vec_sz)  # Shape: (B, gs, n_cluster * vec_sz)

        # Whiten both sides with L^T.
        X_temp = torch.einsum('ij,bjk->bik', LT, P_chunk_expanded)  # Shape: (B, gs, n_cluster * vec_sz)
        W_temp = torch.einsum('ij,bj->bi', LT, W_chunk)  # Shape: (B, gs)

        A += torch.einsum('bik,bil->kl', X_temp, X_temp)  # Shape: (n_cluster * vec_sz, n_cluster * vec_sz)
        b += torch.einsum('bik,bi->k', X_temp, W_temp)  # Shape: (n_cluster * vec_sz)

    C_flat = torch.linalg.solve(A, b)  # Shape: (n_cluster * vec_sz)
    C = C_flat.view(n_cluster, vec_sz)  # Shape: (n_cluster, vec_sz)
    return C
|
| 266 |
+
|
| 267 |
+
def train_least_squares(
    W: np.ndarray,  # Shape: (row_count * group_count, group_size)
    init_P: np.ndarray,  # Shape: (row_count * group_count, group_size//vec_sz, n_cluster)
    init_centroids: np.ndarray,  # Shape: (n_cluster, vec_sz)
    H: np.ndarray,  # Shape: (blk_num, group_size, group_size)
    num_iterations: int = 3,
    cd_cycles: int = 4,
    eig_threshold: float = 1e-3,
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
    """Alternating least-squares training of assignments P and centroids C.

    Each round updates P by coordinate descent (the first round scores the
    provided initialization as-is) and then solves for C in closed form.
    Training stops early as soon as a C update fails to improve the
    Hessian-weighted objective; the best (P, C) seen is returned.

    Args:
        W: weight matrix, (row_count * group_count, group_size).
        init_P: initial one-hot assignments.
        init_centroids: initial centroids, (n_cluster, vec_sz).
        H: Hessian blocks, (blk_num, group_size, group_size).
        num_iterations: maximum number of alternation rounds.
        cd_cycles: coordinate-descent sweeps per P update.
        eig_threshold: unused; retained for interface compatibility.

    Returns:
        (P, C, log_dict): best assignments and centroids on CPU, plus a dict
        with the per-step "objective" and "iteration" traces.
    """
    device = torch.device("cuda")
    # Run-level timer. The previous code reused the *last iteration's*
    # start_time for the final log (wrong total, and a NameError when
    # num_iterations == 0).
    train_start = time.time()

    P = torch.tensor(init_P, dtype=torch.float32, device="cpu")
    C = torch.tensor(init_centroids, dtype=torch.float32, device="cpu")
    W = torch.tensor(W, dtype=torch.float32).to(device)
    H = torch.tensor(H, dtype=torch.float32).to(device)

    # Dampen each Hessian block's diagonal until Cholesky succeeds, i.e. the
    # block is numerically positive definite.
    diag = torch.arange(H.shape[1], device=device)
    for i in range(H.shape[0]):
        avg_diag = torch.mean(torch.diag(H[i]))
        damp, prev_damp = 1e-7, 0.
        while True:
            try:
                torch.linalg.cholesky(H[i])
                logging.info(f"{i+1}-th H is PD, dampening factor={prev_damp:.2e}")
                break
            except Exception as e:
                print(e)
                logging.info(f"{i+1}-th H is not PD, try dampening with factor={damp:.2e}")
                # Only add the increment over the previously applied damping.
                H[i, diag, diag] += (damp - prev_damp) * avg_diag
                prev_damp = damp
                damp *= 10
                if damp > 1e0:
                    # Damping past 1.0x the average diagonal: irreparable.
                    exit()

    best_obj_value = objective_function(W, H, P, C).item()
    best_P, best_C = P.detach().cpu().clone(), C.detach().cpu().clone()
    logging.info(f"Initial objective: {best_obj_value:.6f}")

    log_dict = {"objective": [best_obj_value], "iteration": [0]}

    for iteration in range(num_iterations):
        start_time = time.time()

        ######### Update P #########
        # The first iteration evaluates the provided initialization as-is.
        if iteration > 0:
            P = update_P(W, H, P, C, iteration, cd_cycles=cd_cycles)

        # Compute objective value for logging
        obj_value = objective_function(W, H, P, C).item()
        logging.info(f"Iteration {iteration + 1} (P update): Objective: {obj_value:.4f}")
        log_dict["objective"].append(obj_value)
        log_dict["iteration"].append(iteration + 1)

        ######### Update C #########
        C = update_C(W, H, P, C)

        # Check if the objective value improved
        current_obj_value = objective_function(W, H, P, C).item()
        log_dict["objective"].append(current_obj_value)
        log_dict["iteration"].append(iteration + 1)
        if current_obj_value < best_obj_value:
            best_obj_value = current_obj_value
            best_P, best_C = P.detach().cpu().clone(), C.detach().cpu().clone()
            logging.info(f"Iteration {iteration + 1} (C update): Objective: {current_obj_value:.4f} | Improved and using this one.")
        else:
            logging.info(f"Iteration {iteration + 1} (C update): Objective: {current_obj_value:.4f} | Not improved. Using previous best values.")
            P, C = best_P, best_C
            break  # Early stopping

        end_time = time.time()

        logging.info(f"Iteration {iteration + 1} / {num_iterations} completed. "
                     f"Update time: {end_time - start_time:.2f} sec")

    # Total wall-clock for the whole call (see train_start above).
    logging.info(f"Least squares training time: {time.time() - train_start:.2f} seconds")

    P = P.detach().cpu()
    C = C.detach().cpu().to(torch.float32)

    return P, C, log_dict
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def test():
    """Smoke-test train_least_squares on random data with a k-means init."""
    from lib.utils.kmeans import fit_kmeans
    # Make the run reproducible.
    torch.manual_seed(0)
    np.random.seed(0)
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    vec_sz = 4
    lut_size = 1 << (2 * vec_sz)
    W = torch.randn(4096, 4096)
    H = torch.randn(4096, 4096)

    samples = torch.randn(10000, vec_sz)
    C = fit_kmeans(samples, lut_size)[0]
    print("C", C.shape)

    # Nearest-centroid assignment for every vec_sz-wide weight slice.
    groups = W.view(4096, 4096 // vec_sz, vec_sz).unsqueeze(0)  # (1, 4096, 1024, vec_sz)
    centers = C.unsqueeze(1).unsqueeze(1)                       # (lut, 1, 1, vec_sz)
    nearest = (groups - centers).pow(2).sum(-1).argmin(dim=0)   # (4096, 1024)
    init_P = torch.zeros(4096, 4096 // vec_sz, lut_size)
    init_P.scatter_(2, nearest.unsqueeze(-1), 1)

    # Build a (mildly regularized) PSD Hessian with one block.
    H = H @ H.T + 1e-6 * torch.eye(4096, 4096)
    H = H.unsqueeze(0)

    P, C, log_dict = train_least_squares(
        W=W.numpy(),
        init_P=init_P.numpy(),
        init_centroids=C.numpy(),
        H=H.numpy(),
        num_iterations=10,
        cd_cycles=4,
        eig_threshold=1e-3,
    )

    for it, obj in zip(log_dict["iteration"], log_dict["objective"]):
        logging.info(f"Iteration {it}: Objective: {obj:.4f}")
    print("P", P.shape)
    print("C", C.shape)

    # Plain reconstruction MSE.
    W_hat = torch.einsum('ijc,ck->ijk', P, C)
    W_hat = W_hat.view(W_hat.shape[0], -1)
    err = (W - W_hat).pow(2).mean()
    print("err", err.item())

    # Hessian-weighted error in trace form.
    resid = W - W_hat
    err_tr = torch.trace(resid @ H[0] @ resid.T) / H.shape[1]
    print("err_tr", err_tr.item())
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
    # Timestamped INFO-level logging for the smoke test below.
    _LOG_FORMAT = '[%(levelname)s] %(asctime)s - %(message)s'
    _LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
    logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)
    test()
|
| 431 |
+
|
lib/quantizer/pack_op.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numba
|
| 2 |
+
import numpy as np
|
| 3 |
+
@numba.njit(cache=True)
def general_pack(unpacked, nbits, codeT, code_n):
    '''
    Sequentially pack nbits-wide codes into a codeT-typed array
    (little-endian bit order within each word).

    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        codeT: numpy unsigned type of the output words
            NOTE(review): the body calls codeT(...) and reads codeT.itemsize,
            so codeT must behave as a numpy scalar type under numba's typing
            - confirm against the call sites.
        code_n: int, number of output words

    return: out_code (code_n,) dtype: codeT
    '''
    n_unpacked = unpacked.shape[0]
    codeT_sz = codeT.itemsize * 8  # word width in bits
    assert n_unpacked * nbits / codeT_sz == code_n, "code_n must be equal to n_unpacked * nbits / codeT_sz"
    out_code = np.zeros(code_n, dtype=codeT)
    for i in range(n_unpacked):
        val = codeT(unpacked[i])
        offset = i * nbits  # absolute bit offset of this code in the stream
        wIndex = offset // codeT_sz  # word holding the code's low bits
        bIndex = offset % codeT_sz  # bit position within that word
        # Deposit the low bits; the mask confines the result to the word width.
        out_code[wIndex] |= (val << bIndex) & np.iinfo(codeT).max

        bits_in_word = codeT_sz - bIndex
        if bits_in_word < nbits:
            # The code straddles a word boundary: spill high bits onward.
            upper = val >> bits_in_word
            out_code[wIndex + 1] |= upper & np.iinfo(codeT).max

    return out_code
|
| 31 |
+
|
| 32 |
+
@numba.njit(cache=True)
def general_pack_8(unpacked, nbits, code_n):
    '''
    Pack nbits-wide codes sequentially into uint8 words
    (little-endian bit order within each byte).

    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        code_n: int, number of uint8 output words

    return: out_code (code_n,) dtype: np.uint8
    '''
    count = unpacked.shape[0]
    assert count * nbits / 8 == code_n, "code_n must be equal to n_unpacked * nbits / 8"
    packed = np.zeros(code_n, dtype=np.uint8)
    word_mask = np.iinfo(np.uint8).max
    for pos in range(count):
        code = unpacked[pos]
        bit_off = pos * nbits
        word = bit_off // 8
        shift = bit_off % 8
        packed[word] |= (code << shift) & word_mask
        # Spill the high bits onward when the code crosses a byte boundary.
        room = 8 - shift
        if room < nbits:
            packed[word + 1] |= (code >> room) & word_mask
    return packed
|
| 58 |
+
|
| 59 |
+
@numba.njit(cache=True)
def general_pack_16(unpacked, nbits, code_n):
    '''
    Pack nbits-wide codes sequentially into uint16 words
    (little-endian bit order within each word).

    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        code_n: int, number of uint16 output words

    return: out_code (code_n,) dtype: np.uint16
    '''
    count = unpacked.shape[0]
    assert count * nbits / 16 == code_n, "code_n must be equal to n_unpacked * nbits / 16"
    packed = np.zeros(code_n, dtype=np.uint16)
    word_mask = np.iinfo(np.uint16).max
    for pos in range(count):
        code = unpacked[pos]
        bit_off = pos * nbits
        word = bit_off // 16
        shift = bit_off % 16
        packed[word] |= (code << shift) & word_mask
        # Spill the high bits onward when the code crosses a word boundary.
        room = 16 - shift
        if room < nbits:
            packed[word + 1] |= (code >> room) & word_mask
    return packed
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@numba.njit(cache=True)
def general_pack_32(unpacked, nbits, code_n):
    '''
    Pack nbits-wide codes sequentially into uint32 words
    (little-endian bit order within each word).

    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        code_n: int, number of uint32 output words

    return: out_code (code_n,) dtype: np.uint32
    '''
    count = unpacked.shape[0]
    assert count * nbits / 32 == code_n, "code_n must be equal to n_unpacked * nbits / 32"
    packed = np.zeros(code_n, dtype=np.uint32)
    word_mask = np.iinfo(np.uint32).max
    for pos in range(count):
        code = unpacked[pos]
        bit_off = pos * nbits
        word = bit_off // 32
        shift = bit_off % 32
        packed[word] |= (code << shift) & word_mask
        # Spill the high bits onward when the code crosses a word boundary.
        room = 32 - shift
        if room < nbits:
            packed[word + 1] |= (code >> room) & word_mask
    return packed
|
| 115 |
+
|
| 116 |
+
@numba.njit(cache=True)
def general_pack_64(unpacked, nbits, code_n):
    '''
    Pack nbits-wide codes sequentially into uint64 words
    (little-endian bit order within each word).

    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        code_n: int, number of uint64 output words

    return: out_code (code_n,) dtype: np.uint64
    '''
    count = unpacked.shape[0]
    assert count * nbits / 64 == code_n, "code_n must be equal to n_unpacked * nbits / 64"
    packed = np.zeros(code_n, dtype=np.uint64)
    word_mask = np.iinfo(np.uint64).max
    for pos in range(count):
        code = unpacked[pos]
        bit_off = pos * nbits
        word = bit_off // 64
        shift = bit_off % 64
        packed[word] |= (code << shift) & word_mask
        # Spill the high bits onward when the code crosses a word boundary.
        room = 64 - shift
        if room < nbits:
            packed[word + 1] |= (code >> room) & word_mask
    return packed
|
| 143 |
+
|
| 144 |
+
@numba.njit(cache=True)
def pack_codes_8(codes, nbits, code_n):
    '''
    Pack a flat stream of codes into rows of uint8 words.

    args:
        codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bits per code
        code_n: int, packed words per output row

    return:
        packed_codes (-1, code_n) dtype: np.uint8
    '''
    # Number of raw codes that fit in one packed row of code_n bytes.
    per_row = code_n * 8 // nbits
    n_rows = codes.shape[0] // per_row
    packed_codes = np.zeros((n_rows, code_n), dtype=np.uint8)
    for row in range(n_rows):
        chunk = codes[row * per_row: (row + 1) * per_row]
        packed_codes[row] = general_pack_8(chunk, nbits, code_n)
    return packed_codes
|
| 164 |
+
|
| 165 |
+
@numba.njit(cache=True)
def pack_codes_16(codes, nbits, code_n):
    '''
    Pack a flat stream of codes into rows of uint16 words.

    args:
        codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bits per code
        code_n: int, packed words per output row

    return:
        packed_codes (-1, code_n) dtype: np.uint16
    '''
    # Number of raw codes that fit in one packed row of code_n words.
    per_row = code_n * 16 // nbits
    n_rows = codes.shape[0] // per_row
    packed_codes = np.zeros((n_rows, code_n), dtype=np.uint16)
    for row in range(n_rows):
        chunk = codes[row * per_row: (row + 1) * per_row]
        packed_codes[row] = general_pack_16(chunk, nbits, code_n)
    return packed_codes
|
| 185 |
+
|
| 186 |
+
@numba.njit(cache=True)
def pack_codes_32(codes, nbits, code_n):
    '''
    Pack a flat stream of codes into rows of uint32 words.

    args:
        codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bits per code
        code_n: int, packed words per output row

    return:
        packed_codes (-1, code_n) dtype: np.uint32
    '''
    # Number of raw codes that fit in one packed row of code_n words.
    per_row = code_n * 32 // nbits
    n_rows = codes.shape[0] // per_row
    packed_codes = np.zeros((n_rows, code_n), dtype=np.uint32)
    for row in range(n_rows):
        chunk = codes[row * per_row: (row + 1) * per_row]
        packed_codes[row] = general_pack_32(chunk, nbits, code_n)
    return packed_codes
|
| 206 |
+
|
| 207 |
+
@numba.njit(cache=True)
def pack_codes_64(codes, nbits, code_n):
    '''
    Pack a flat stream of codes into rows of uint64 words.

    args:
        codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bits per code
        code_n: int, packed words per output row

    return:
        packed_codes (-1, code_n) dtype: np.uint64
    '''
    # Number of raw codes that fit in one packed row of code_n words.
    per_row = code_n * 64 // nbits
    n_rows = codes.shape[0] // per_row
    packed_codes = np.zeros((n_rows, code_n), dtype=np.uint64)
    for row in range(n_rows):
        chunk = codes[row * per_row: (row + 1) * per_row]
        packed_codes[row] = general_pack_64(chunk, nbits, code_n)
    return packed_codes
|
| 227 |
+
|
| 228 |
+
def pack_codes(codes, nbits, code_n, codeT_sz):
    """Route to the packer matching the word size (8, 16, 32 or 64 bits)."""
    if codeT_sz == 64:
        return pack_codes_64(codes, nbits, code_n)
    if codeT_sz == 32:
        return pack_codes_32(codes, nbits, code_n)
    if codeT_sz == 16:
        return pack_codes_16(codes, nbits, code_n)
    if codeT_sz == 8:
        return pack_codes_8(codes, nbits, code_n)
    raise ValueError(f"Unsupported codeT_sz: {codeT_sz}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@numba.njit(cache=True)
def pack_32(cluster_idx: np.ndarray, nbits: int) -> np.ndarray:
    """
    NumPy version of the pack_32 function.

    Parameters
    ----------
    cluster_idx : np.ndarray of shape (32,), dtype=int
        Array of 32 integers (same role as ``const int* cluster_idx``
        in the original C code).
    nbits : int
        Number of bits used to store each integer.

    Returns
    -------
    out_code : np.ndarray of shape (out_size,), dtype=np.uint32
        The contiguous bit stream of 32 values (nbits bits each),
        split into 32-bit words.
    """

    # 32 values of nbits each need 32*nbits bits in total; round up
    # to whole 32-bit words.
    out_size = (32 * nbits + 31) // 32  # ceiling division

    # Output buffer (np.uint32 words).
    out_code = np.zeros(out_size, dtype=np.uint32)

    for i in range(32):
        # Treat cluster_idx[i] as unsigned.
        val = np.uint32(cluster_idx[i])

        offset = i * nbits
        wIndex = offset // 32  # which output word
        bIndex = offset % 32   # starting bit within that word

        # Store (part or all of) the value starting at bIndex of the word.
        out_code[wIndex] |= (val << bIndex) & np.uint32(0xFFFFFFFF)

        # If the remaining bits do not fit in the current word, store
        # them in the next word.
        bits_in_word = 32 - bIndex
        if bits_in_word < nbits:
            upper = val >> bits_in_word
            out_code[wIndex + 1] |= upper & np.uint32(0xFFFFFFFF)

    return out_code
|
| 286 |
+
|
| 287 |
+
@numba.njit(cache=True)
def pack_for_sq_pack_kernel(unpacked_code: np.ndarray,
                            nbits: int,
                            blockDimX: int=32) -> np.ndarray:
    """
    Python version of pack_for_sq_pack_kernel.

    Packs per-row integer codes into uint32 words, laid out in the
    warp-interleaved order the SQ pack kernel expects.

    unpacked_code: shape = (N, K), dtype=uint32
    nbits: int, bits per code
    blockDimX: int, emulated warp width
    return: Bcode (1D np.uint32 array, length = N * (K*nbits//32))
    """

    N, K = unpacked_code.shape
    out_size = N * (K * nbits // 32)
    Bcode = np.zeros(out_size, dtype=np.uint32)

    K_iter = int(np.ceil(K / (32 * blockDimX)))

    # Scratch buffer holding the 32 values handled by one "thread".
    unpacked_Bcode_row = np.zeros(32, dtype=np.uint32)

    for n_ in range(N):
        eff_warp_size = blockDimX
        for k_ in range(K_iter):
            for thx in range(blockDimX):
                # On the last (partial) iteration only part of the warp
                # has work; shrink the effective warp size accordingly.
                if k_ == K // (32 * blockDimX):
                    eff_warp_size = (K % (32 * blockDimX)) // 32
                if thx >= eff_warp_size:
                    break
                k_val = k_ * 32 * blockDimX + 8 * thx
                k_code = k_ * nbits * blockDimX + thx

                # Gather 32 values: 4 groups of 8 consecutive entries,
                # groups strided by 8 * eff_warp_size across the warp.
                idx_out = 0
                for j in range(4):
                    k_val_idx = k_val + 8 * j * eff_warp_size
                    for i in range(8):
                        unpacked_Bcode_row[idx_out] = unpacked_code[n_, k_val_idx + i]
                        idx_out += 1

                # Bit-pack the 32 gathered values into nbits uint32 words.
                Bcode_row = pack_32(unpacked_Bcode_row, nbits)

                # Scatter the packed words into the output, interleaved
                # across the (effective) warp.
                for j in range(nbits):
                    k_code_idx = k_code + j * eff_warp_size
                    Bcode[n_ * (K * nbits // 32) + k_code_idx] = Bcode_row[j]

    return Bcode
|
lib/quantizer/quant_op.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import lib.utils as utils
from lib.quantizer.pack_op import pack_codes, pack_for_sq_pack_kernel
import torch

# Fixed permutation of the 256 positions of a 16x16 tile, built by
# reshaping to (2, 8, 2, 4, 2) and permuting axes. Used when packing
# weights into the tensor-core layout (see pack_qweight_routine).
_PERMUTE = torch.arange(256).reshape(2, 8, 2, 4, 2).permute(1, 3, 2, 0,
                                                            4).flatten()
# Inverse of _PERMUTE: _INV_PERMUTE[_PERMUTE[i]] == i. Used to undo the
# tile permutation during dequantization.
_INV_PERMUTE = torch.zeros(256, dtype=torch.int64)
_INV_PERMUTE[_PERMUTE] = torch.arange(256)
|
| 8 |
+
|
| 9 |
+
def random_mat(N, K):
    """Return a random (N, K) float16 matrix on the GPU."""
    mat = torch.randn(N, K, dtype=torch.float16)
    return mat.cuda()
|
| 11 |
+
|
| 12 |
+
def random_lut(nbits, vec_sz):
    """Return a random float16 LUT of shape (2**nbits, vec_sz) on the GPU."""
    n_entries = 2 ** nbits
    lut = torch.randn(n_entries, vec_sz, dtype=torch.float16)
    return lut.cuda()
|
| 14 |
+
|
| 15 |
+
def vq_pack_reshape_pack_routine(packed_codes, warp_size, code_n, N):
    '''
    Interleave packed code groups warp-wise.

    packed_codes: torch.Tensor, viewed as (N, -1, code_n)
    warp_size: int, number of groups interleaved together
    code_n: int, words per group
    N: int, number of rows

    return: (N, -1) tensor with each warp's groups transposed so the
    code_n words of a group are strided across the warp.
    '''
    packed_codes = packed_codes.reshape(N, -1, code_n)
    n_groups = packed_codes.shape[1]
    remainder = n_groups % warp_size
    if remainder == 0:
        # All groups form complete warps.
        return packed_codes.reshape(N, -1, warp_size, code_n) \
                           .permute(0, 1, 3, 2).reshape(N, -1)
    # Split off the trailing partial warp and interleave it with its
    # own (smaller) effective warp size.
    split = n_groups - remainder
    head = packed_codes[:, :split] \
        .reshape(N, -1, warp_size, code_n).permute(0, 1, 3, 2).reshape(N, -1)
    tail = packed_codes[:, split:] \
        .reshape(N, 1, remainder, code_n).permute(0, 1, 3, 2).reshape(N, -1)
    return torch.cat([head, tail], dim=1)
|
| 32 |
+
|
| 33 |
+
def reshape_mat(mat, chunk_size, warp_size=32):
    '''
    Reorder mat for warp-interleaved access.

    mat: torch.Tensor, shape = (N, K)
    chunk_size: int, must be a multiple of 8
    warp_size: int

    return: tensor of shape (N, K // (chunk_size * warp_size), warp_size,
    chunk_size) with each chunk's 8-element sub-vectors interleaved
    across the warp.
    '''
    N, K = mat.shape
    assert K % 8 == 0 and chunk_size % 8 == 0, "K and chunk_size must be divisible by 8"
    assert K % (chunk_size * warp_size) == 0, "K must be divisible by chunk_size * warp_size"
    iters = K // (chunk_size * warp_size)
    # Break each chunk into 8-wide sub-vectors, then swap the sub-vector
    # and warp axes so consecutive lanes read consecutive sub-vectors.
    reordered = mat.reshape(N, iters, chunk_size // 8, warp_size, 8)
    reordered = reordered.permute(0, 1, 3, 2, 4)
    return reordered.reshape(N, iters, warp_size, chunk_size)
|
| 46 |
+
|
| 47 |
+
def vq_pack_reshape_mat_routine(mat, chunk_size, vec_sz, warp_size=32):
    '''
    Split mat into vec_sz-sized vectors in warp-interleaved order.

    mat: torch.Tensor, shape = (N, K)
    chunk_size: int
    vec_sz: int, length of each output vector
    warp_size: int

    return: vecs, shape (-1, vec_sz)
    '''
    N, K = mat.shape
    assert K % chunk_size == 0, "K must be divisible by chunk_size"
    tail_len = K % (chunk_size * warp_size)
    if tail_len == 0:
        # Whole matrix fits in complete warps.
        return reshape_mat(mat, chunk_size, warp_size).reshape(-1, vec_sz)
    # Handle the trailing partial warp separately with a reduced
    # effective warp size, then concatenate in row order.
    head_len = K - tail_len
    head = reshape_mat(mat[:, :head_len], chunk_size, warp_size)
    tail_warp = tail_len // chunk_size
    tail = reshape_mat(mat[:, head_len:], chunk_size, tail_warp)
    parts = [head.reshape(N, -1, vec_sz), tail.reshape(N, -1, vec_sz)]
    return torch.cat(parts, dim=1).reshape(-1, vec_sz)
|
| 68 |
+
|
| 69 |
+
def pack_qweight_vq_simt(P, lut_bits, vec_sz, code_n, codeT_sz=32):
    '''
    Pack VQ code indices for the SIMT kernel layout.

    P: (N, K // vec_sz) integer codes in [0, 2 ** lut_bits)
    lut_bits: int, bits per code
    vec_sz: int, vector quantization group size
    code_n: int, packed words per group
    codeT_sz: int, packed word width in bits

    return: (N, -1) CUDA tensor of packed codes.
    '''
    N = P.shape[0]
    # Repeat each code vec_sz times so the reshape routine sees one code
    # per scalar element; only one copy per vector is kept afterwards.
    expanded_P = P.unsqueeze(-1).expand(-1, -1, vec_sz).reshape(N, -1)
    reshaped_P = vq_pack_reshape_mat_routine(expanded_P, chunk_size=int(32*vec_sz), vec_sz=vec_sz)[:, 0].contiguous()
    # Bit-pack on CPU (numba path), then move the result back to GPU.
    packed_codes = torch.from_numpy(pack_codes(reshaped_P.view(-1).cpu().numpy(), lut_bits, code_n, codeT_sz)).reshape(N, -1).cuda()
    # Final warp-interleave of the packed words.
    packed_codes = vq_pack_reshape_pack_routine(packed_codes, 32, code_n, N)
    return packed_codes
|
| 79 |
+
|
| 80 |
+
def pack_qweight_sq_simt(P, lut_bits):
    '''
    Pack scalar-quantized code indices for the SIMT kernel layout.

    P: (N, K) integer codes in [0, 2 ** lut_bits)
    lut_bits: int, bits per code

    return: (N, -1) CUDA tensor of packed codes.
    '''
    N = P.shape[0]
    # Bit-pack on CPU (numba path), then move the result back to GPU.
    np_codes = pack_for_sq_pack_kernel(P.cpu().numpy(), lut_bits)
    return torch.from_numpy(np_codes).reshape(N, -1).cuda()
|
| 87 |
+
|
| 88 |
+
# for tensor core
|
| 89 |
+
def pack_qweight(P, vec_sz, lut_bits, td_x=16, td_y=16, batch_size=1024):
|
| 90 |
+
'''
|
| 91 |
+
P: (N, K // vec_sz, 2 ** lut_bits) 0, 1
|
| 92 |
+
'''
|
| 93 |
+
N = P.shape[0]
|
| 94 |
+
mat_packed = []
|
| 95 |
+
for i in range(0, N, batch_size):
|
| 96 |
+
sidx, eidx = i, min(i + batch_size, N)
|
| 97 |
+
cur_size = eidx - sidx
|
| 98 |
+
mat_packed.append(pack_qweight_routine(P[sidx:eidx], vec_sz, lut_bits, td_x, td_y))
|
| 99 |
+
return torch.cat(mat_packed, dim=0)
|
| 100 |
+
|
| 101 |
+
def pack_qweight_routine(P, vec_sz, lut_bits, td_x=16, td_y=16):
    '''
    Pack one batch of quantized weights into the tensor-core layout.

    P: (N, K // vec_sz, 2 ** lut_bits) one-hot, or (N, K // vec_sz)
    integer indices. Only vec_sz in {1, 2} is handled.
    '''
    if vec_sz == 1:
        if len(P.shape) == 3:
            # One-hot input: recover the code index per position.
            P_ind = P.argmax(dim=-1) # (N, K)
        elif len(P.shape) == 2:
            P_ind = P
        N, K = P_ind.shape
        # Tile into (td_x, td_y) blocks, one tile per row.
        P_tiled = P_ind.reshape(N // td_x, td_x, K // td_y, td_y) \
            .permute(0, 2, 1, 3) \
            .reshape(-1, td_x * td_y)
        # Apply the fixed intra-tile permutation.
        P_tiled_permuted = P_tiled[..., _PERMUTE]
    elif vec_sz == 2:
        if len(P.shape) == 3:
            P_ind = P.argmax(dim=-1) # (N, K // vec_sz)
        else:
            P_ind = P
        # Duplicate each code across its vec_sz scalar positions, tile,
        # permute, then keep one copy per vector again.
        P_ind = P_ind.unsqueeze(-1).expand(-1, -1, vec_sz)
        P_ind = P_ind.reshape(P_ind.shape[0], -1).contiguous()
        N, K = P_ind.shape
        P_tiled = P_ind.reshape(N // td_x, td_x, K // td_y, td_y) \
            .permute(0, 2, 1, 3) \
            .reshape(-1, td_x * td_y)
        # permute and flatten
        P_tiled_permuted = P_tiled[..., _PERMUTE].contiguous().view(-1, vec_sz)
        # The permutation must keep vector pairs adjacent, so the two
        # copies of a code stay together.
        assert torch.allclose(P_tiled_permuted[:, 0], P_tiled_permuted[:, 1]), \
            "P_tiled_permuted[:, 0] and P_tiled_permuted[:, 1] are not the same"
        P_tiled_permuted = P_tiled_permuted[:, 0].contiguous()
        P_tiled_permuted = P_tiled_permuted.reshape((N * K) // (td_x * td_y), td_x * td_y // vec_sz)

    m = P_tiled_permuted.shape[0]
    c = (td_x * td_y) // vec_sz

    # Expand each code into its lut_bits bit-plane booleans.
    K_mask = 2 ** torch.arange(lut_bits, device=P.device).view(1, 1, -1) # => [1,1,lut_bits]
    bits_bool = (P_tiled_permuted.unsqueeze(-1) & K_mask) > 0 # => [m, c, lut_bits]
    if vec_sz == 1:

        # group 4 bytes => 1 uint32
        # group 8 bits => 1 byte
        bits_bool_8 = bits_bool.reshape(m, (c * lut_bits) // 8, 8) # => [m, c*lut_bits/8, 8]
        uint_mask = (2 ** torch.arange(8, device=P.device, dtype=torch.int16)).view(1, 1, 8)
        packed_8 = (bits_bool_8.to(torch.int16) * uint_mask).sum(dim=-1).to(torch.uint8) # => [m, (c*lut_bits)//8]

        # Interleave 2x2 super-tiles, then reinterpret bytes as uint32.
        mat_packed = packed_8.reshape(N // td_x // 2, 2, K // td_y // 2, 2, td_x * td_y // 8, lut_bits) \
            .permute(0, 2, 4, 3, 1, 5).contiguous().flatten().view(torch.uint32)\
            .reshape((N * K) // (td_x * td_y), (td_x * td_y * lut_bits) // (32 * vec_sz))
    elif vec_sz == 2:
        # group 4 bits => 1 nibble (stored one per byte for now)
        bits_bool_4 = bits_bool.reshape(m, (c * lut_bits) // 4, 4) # => [m, c*nbits/4, 4]
        uint_mask = (2 ** torch.arange(4, device=bits_bool_4.device, dtype=torch.int16)).view(1, 1, 4)
        packed_4 = (bits_bool_4.to(torch.int16) * uint_mask).sum(dim=-1).to(torch.uint8) # => [m, (c*nbits)//4]

        # Interleave 2x2 super-tiles of nibble values.
        mat_packed_48 = packed_4.reshape(N // td_x // 2, 2, K // td_y // 2, 2, td_x * td_y // 8, lut_bits) \
            .permute(0, 2, 4, 3, 1, 5).contiguous().flatten()
        # uint 4 packed in uint 8 to uint 32: fuse nibble pairs into
        # bytes, then reinterpret bytes as uint32.
        packing_mask = torch.Tensor([1, 2 ** 4]).to(torch.int8).view(1,2).cuda()
        mat_packed8 = (mat_packed_48.reshape(-1, 2) * packing_mask).sum(dim=-1).to(torch.uint8).contiguous().flatten()

        mat_packed = mat_packed8.view(torch.uint32).reshape((N * K) // (td_x * td_y), (td_x * td_y * lut_bits) // (32 * vec_sz))
    return mat_packed.view(N, -1)
|
| 163 |
+
|
| 164 |
+
def load_hessian(in_hess_path, sigma_reg=0.01):
    """Load a flat-packed Hessian from disk, fold in the mean outer
    product when present, regularize, and return it as float64 with a
    trailing singleton group axis (n, n, 1)."""
    payload = torch.load(in_hess_path, map_location=torch.device('cpu'))
    hess = utils.flat_to_sym(payload['flatH'], payload['n'])
    if 'mu' in payload:
        mean = payload['mu']
        hess += mean[None, :] * mean[:, None]
        del mean
    del payload
    hess = utils.regularize_H(hess, sigma_reg)
    assert len(hess.shape) == 2 and hess.shape[0] == hess.shape[1], "H must be a square matrix"
    return hess.to(torch.float64).unsqueeze(-1)
|
| 175 |
+
|
| 176 |
+
def load_group_hessian(in_hess_path, sigma_reg=0.01, layer_key=None):
    """Load a per-layer group Hessian of shape (n, n, groups) and
    regularize each group slice independently."""
    payload = torch.load(in_hess_path, map_location=torch.device('cpu'))
    hess = payload[layer_key]
    for g in range(hess.shape[-1]):
        hess[:, :, g] = utils.regularize_H(hess[:, :, g], sigma_reg)
    assert len(hess.shape) == 3 and hess.shape[0] == hess.shape[1], "H must be a square matrix"
    return hess.to(torch.float64)
|
| 183 |
+
|
| 184 |
+
# deprecated func for dequantization
|
| 185 |
+
@torch.compile
|
| 186 |
+
def dequantize_mat_sq(mat_packed, lut, N, K, nbits, td_x=16, td_y=16):
|
| 187 |
+
packed = mat_packed.flatten().view(torch.uint8).reshape(N // td_x // 2,
|
| 188 |
+
K // td_y // 2,
|
| 189 |
+
td_x * td_y // 8,
|
| 190 |
+
2, 2, nbits)
|
| 191 |
+
packed_8 = packed.permute(0, 4, 1, 3, 2, 5).contiguous().reshape(N * K // (td_x * td_y), (td_x * td_y) * nbits // 8)
|
| 192 |
+
bits_mask = (2 ** torch.arange(8, device=mat_packed.device, dtype=torch.int16)).view(1, 1, 8)
|
| 193 |
+
bits_bool_8 = (packed_8.unsqueeze(-1) & bits_mask) > 0
|
| 194 |
+
bits_bool = bits_bool_8.reshape(N * K // (td_x * td_y), (td_x * td_y), nbits)
|
| 195 |
+
K_mask = 2 ** torch.arange(nbits, device=mat_packed.device).view(1, 1, -1)
|
| 196 |
+
indices = (bits_bool * K_mask).sum(dim=-1)
|
| 197 |
+
recon = lut[indices.long()].reshape(N * K // (td_x * td_y), td_x * td_y)
|
| 198 |
+
recon = recon.index_select(dim=1, index=_INV_PERMUTE.to(mat_packed.device))
|
| 199 |
+
recon = recon.reshape(N // td_x, K // td_y, td_x, td_y)
|
| 200 |
+
recon = recon.permute(0, 2, 1, 3).reshape(N, K)
|
| 201 |
+
return recon
|
| 202 |
+
|
| 203 |
+
@torch.compile
def dequantize_mat_sq_inds(mat_packed, N, K, nbits, td_x=16, td_y=16):
    """Recover the (N, K) integer code indices from tensor-core-packed
    weights (same unpacking as dequantize_mat_sq, without LUT lookup)."""
    # Undo the 2x2 super-tile interleave done at pack time.
    packed = mat_packed.flatten().view(torch.uint8).reshape(N // td_x // 2,
                                                            K // td_y // 2,
                                                            td_x * td_y // 8,
                                                            2, 2, nbits)
    packed_8 = packed.permute(0, 4, 1, 3, 2, 5).contiguous().reshape(N * K // (td_x * td_y), (td_x * td_y) * nbits // 8)
    # Expand each byte into its 8 bit booleans.
    bits_mask = (2 ** torch.arange(8, device=mat_packed.device, dtype=torch.int16)).view(1, 1, 8)
    bits_bool_8 = (packed_8.unsqueeze(-1) & bits_mask) > 0
    bits_bool = bits_bool_8.reshape(N * K // (td_x * td_y), (td_x * td_y), nbits)
    # Recombine nbits booleans per position into an integer code.
    K_mask = 2 ** torch.arange(nbits, device=mat_packed.device).view(1, 1, -1)
    indices = (bits_bool * K_mask).sum(dim=-1)
    # Undo the intra-tile permutation and the tiling.
    indices = indices.reshape(N * K // (td_x * td_y), td_x * td_y).index_select(dim=1, index=_INV_PERMUTE.to(mat_packed.device))
    indices = indices.reshape(N // td_x, K // td_y, td_x, td_y)
    indices = indices.permute(0, 2, 1, 3).reshape(N, K).contiguous()
    return indices
|
| 219 |
+
|
| 220 |
+
@torch.compile
def dequantize_mat_sq_inds_vec2(mat_packed: torch.Tensor,
                                N: int,
                                K: int,
                                lut_bits: int,
                                td_x: int = 16,
                                td_y: int = 16) -> torch.Tensor:
    """Recover the (N, K // 2) integer code indices for vec_sz == 2
    packed weights (each index covers a pair of scalar positions)."""
    # Split each byte into its two nibbles (low nibble first).
    mat_packed8 = mat_packed.view(torch.uint8).flatten().unsqueeze(-1).expand(-1, 2).contiguous()

    mat_packed48 = torch.zeros_like(mat_packed8)
    mat_packed48[:, 0] = mat_packed8[:, 0] & 0b1111
    mat_packed48[:, 1] = mat_packed8[:, 1] >> 4

    # Undo the 2x2 super-tile interleave done at pack time.
    packed_4 = mat_packed48.reshape(N // td_x // 2, K // td_y // 2, td_x * td_y // 8, 2, 2, lut_bits).permute(0,4,1,3,2,5).reshape(N * K // (td_x * td_y), -1).contiguous()
    # Expand each nibble into its 4 bit booleans.
    bits_mask = (2 ** torch.arange(4, device=mat_packed.device, dtype=torch.int16)).view(1, 1, 4)
    bits_bool_4 = (packed_4.unsqueeze(-1) & bits_mask) > 0
    bits_bool = bits_bool_4.reshape(N * K // (td_x * td_y), td_x * td_y // 2, lut_bits)
    # Recombine lut_bits booleans per position into an integer code.
    K_mask = 2 ** torch.arange(lut_bits, device=mat_packed.device).view(1, 1, -1)
    indices = (bits_bool * K_mask).sum(dim=-1)
    # Duplicate each code over its two scalar positions so the scalar
    # inverse permutation can be applied, then drop the duplicate.
    indices = indices.reshape(N * K // (td_x * td_y), td_x * td_y // 2, 1).expand(-1, -1, 2).reshape(N * K // (td_x * td_y), td_x * td_y).contiguous()
    indices = indices.index_select(dim=1, index=_INV_PERMUTE.to(mat_packed.device))
    indices = indices.reshape(N // td_x, K // td_y, td_x, td_y)
    indices = indices.permute(0, 2, 1, 3).reshape(N, K // 2, 2)
    indices = indices[:, :, 0].contiguous()
    return indices
|
| 245 |
+
|
| 246 |
+
def convert_tensor_core_to_simt(mat_packed, N, K, vec_sz, lut_bit, code_n, codeT_sz=32, td_x=16, td_y=16):
    """Repack tensor-core-layout quantized weights into the SIMT kernel
    layout, preserving the tensor's original device."""
    orig_device = mat_packed.device
    mat_packed = mat_packed.cuda()
    if vec_sz == 2:
        inds = dequantize_mat_sq_inds_vec2(mat_packed, N, K, lut_bit, td_x, td_y)
        repacked = pack_qweight_vq_simt(inds, lut_bit, vec_sz, code_n, codeT_sz)
    else:
        inds = dequantize_mat_sq_inds(mat_packed, N, K, lut_bit, td_x, td_y)
        repacked = pack_qweight_sq_simt(inds, lut_bit)
    return repacked.to(orig_device).contiguous()
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
if __name__ == "__main__":
    # Smoke test: pack random indices for both vec_sz settings and
    # round-trip them through the tensor-core -> SIMT converters.
    # (Removed a leftover `import ipdb; ipdb.set_trace()` breakpoint.)
    nbits = 5
    Qidxs = torch.randint(0, 2**nbits, (11008, 4096)).cuda()
    packed = pack_qweight(Qidxs, 1, nbits)

    indices = dequantize_mat_sq_inds(packed, 11008, 4096, nbits)

    converted = convert_tensor_core_to_simt(packed, 11008, 4096, 1, nbits, code_n=nbits)

    Qidxs2 = torch.randint(0, 2**nbits, (11008, 2048)).cuda()
    packed2 = pack_qweight(Qidxs2, 2, nbits)

    indices2 = dequantize_mat_sq_inds_vec2(packed2, 11008, 4096, nbits)

    converted2 = convert_tensor_core_to_simt(packed2, 11008, 4096, 2, nbits, code_n=nbits)

    print("vec_sz=1:", packed.shape, indices.shape, converted.shape)
    print("vec_sz=2:", packed2.shape, indices2.shape, converted2.shape)
|
lib/quantizer/tcq_quant.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from lib.utils import block_LDL, matmul_hadUt, matmul_hadUt_head
|
| 3 |
+
from lib.algo.ldlq import LDLQ
|
| 4 |
+
from lib.quantizer.quant_op import load_hessian, load_group_hessian
|
| 5 |
+
from lib.linear import QTIPLinearTCQ, IncoherentLinear
|
| 6 |
+
from lib.codebook.bitshift import bitshift_codebook
|
| 7 |
+
import torch._dynamo
|
| 8 |
+
import time
|
| 9 |
+
class Args:
    """Lightweight container for the trellis-quantization tile options
    (tile width td_x, tile height td_y, vector size V) expected by LDLQ."""

    def __init__(self, td_x, td_y, V):
        # Plain attribute assignments; no validation is performed.
        self.V = V
        self.td_y = td_y
        self.td_x = td_x
|
| 14 |
+
|
| 15 |
+
def qtip_quantize_mat(Wr, HRr, Wscale, cb, td_x=16, td_y=16, KV=4, V=2, use_hess=True):
    """LDLQ-quantize an incoherence-processed weight matrix with trellis
    codebook `cb` and bit-pack the resulting trellis codes.

    Wr: (m, n) weights (cast to float64 here); HRr: (n, n, gs) group
    Hessians; Wscale: per-row scale folded back in before computing the
    reported error. Returns (packed codes, reconstructed hatWr,
    quant_info dict)."""
    # NOTE(review): HRr_orig is computed but never used below.
    HRr_orig = HRr.clone()
    Wr = Wr.to(torch.float64)
    (m, n) = Wr.shape
    gs = HRr.shape[-1]
    LRrs = []
    diag = torch.arange(n, device=HRr.device)
    if not use_hess:
        # Identity Hessian: LDLQ degenerates to plain rounding order.
        eye = torch.eye(n, device=Wr.device, dtype=torch.float64)
        LRr, D = block_LDL(eye, td_y)
        LRr[diag, diag] = 0
        LRrs.append(LRr)
    else:
        # One block-LDL factor per Hessian group; zero the diagonal so
        # LDLQ only uses the strictly lower part as feedback.
        for i in range(gs):
            LRr, D = block_LDL(HRr[:,:,i], td_y)
            LRr[diag, diag] = 0
            LRrs.append(LRr)

    args = Args(td_x, td_y, V)

    # Quantize each row group of Wr against its own LDL factor.
    Qidxs_list = []
    hatWr_list = []
    for i in range(gs):
        cur_Wr = Wr[m // gs * i:m // gs * (i+1)]
        hatWr, Qidxs = LDLQ(cur_Wr, LRrs[i], cb.cuda(), args, for_kernel=True)
        hatWr_list.append(hatWr)
        Qidxs_list.append(Qidxs)
    hatWr = torch.cat(hatWr_list, dim=0)
    Qidxs = torch.cat(Qidxs_list, dim=0)
    assert hatWr.shape == Wr.shape, f"hatWr.shape {hatWr.shape} != Wr.shape {Wr.shape}"

    # Pack trellis codes tile-by-tile (td_x x td_y tiles, V-wide groups).
    Qidxs = Qidxs.cpu()
    packed = cb.pack_trellis(
        Qidxs.reshape(m // td_x, td_x, n // td_y,
                      td_y // V).transpose(1, 2).reshape(
                          -1, td_x * td_y // V))

    # Re-bit-order the packed stream: split into nibbles, interleave
    # 2x2 super-tiles, and fuse nibbles back into 16-bit words.
    packed_8 = packed.view(torch.uint8).view(-1, 2)
    packed_4 = torch.cat([packed_8.unsqueeze(-1) & (2 ** 4 - 1), (packed_8.unsqueeze(-1) & (2 ** 8 - 2 ** 4)) >> 4], dim=-1).view(-1, 4).flip(
        (-1, ))

    packed_4 = packed_4.reshape(m // 16 // 2, 2, n // 16 // 2, 2, 16 * 16 // 8,
                                KV).permute(0, 2, 4, 3, 1, 5).flip(
                                    (-1, )).contiguous().flatten()
    packed_8 = torch.sum(packed_4.view(-1, 2) * torch.Tensor([[1, 2 ** 4]]).to(torch.uint8), dim=-1).to(torch.uint8).contiguous()
    packed = packed_8.view(torch.int16).reshape(packed.shape).cuda()

    # Undo the per-row scaling before measuring reconstruction error.
    Wr *= Wscale.reshape(-1, 1)
    hatWr *= Wscale.reshape(-1, 1)

    orig_err = (Wr - hatWr).pow(2).mean()
    err = orig_err / Wr.pow(2).mean()
    print(
        f'err {err.item()} orig_err {orig_err.item()}'
    )
    quant_info = {
        "quantizer": "tcq_ldlq",
        "td_x": td_x,
        "td_y": td_y,
        "KV": KV,
        "V": V,
        "use_hess": use_hess,
        "orig_err": orig_err.item(),
        "err": err.item(),
    }
    return packed, hatWr, quant_info
|
| 81 |
+
|
| 82 |
+
def inc_linear_to_inc_tcq_linear(inc_linear, HRr, cb, td_x=16, td_y=16, KV=4, V=2, scale_override=0.9, use_hess=True):
    """Quantize the dense weight inside an IncoherentLinear and swap it
    for a QTIPLinearTCQ holding the packed trellis codes.

    Returns (inc_linear with its .linear replaced, quant_info)."""
    # Fold scale_override into the weight and compensate in Wscale so
    # the overall layer output is unchanged.
    Wr = inc_linear.linear.weight.data * scale_override
    Wscale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(Wscale)

    packed, hatWr, quant_info = qtip_quantize_mat(Wr, HRr, Wscale, cb, td_x=td_x, td_y=td_y, KV=KV, V=V, use_hess=use_hess)
    out_features, in_features = Wr.shape
    tcq_linear = QTIPLinearTCQ(
        in_features,
        out_features,
        td_x=16,
        td_y=16,
        L=16,
        KV=KV,
        V=V,
        tlut_bits=cb.tlut_bits,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )

    # Install the packed codes and the trellis LUT.
    tcq_linear.trellis.data.copy_(packed)
    tcq_linear.tlut.data.copy_(cb.tlut)

    inc_linear.linear = tcq_linear
    return inc_linear, quant_info
|
| 107 |
+
|
| 108 |
+
def linear_to_incoherent_for_tcq(linear, cb, HR, scale_override=0.9, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False):
    """Wrap a plain nn.Linear in an IncoherentLinear prepared for TCQ:
    apply random-sign + Hadamard incoherence to W and HR, and normalize
    W to the codebook's RMS scale.

    Returns (inc_linear, rotated Hessian HRr)."""
    # NOTE(review): `lnorm` is accepted but not used in this body.
    dtype_ = torch.float32
    device = linear.weight.device
    inc_linear = IncoherentLinear(linear.in_features, linear.out_features, hadU, hadV, linear.bias is not None, dtype_)
    # Random +-1 sign vectors for the two sides, unless provided.
    if SU is None:
        SU = ((torch.randn(linear.in_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)
    if SV is None:
        SV = ((torch.randn(linear.out_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)

    # left_only: rotate only the input side; output side stays identity.
    if left_only:
        SV = torch.ones_like(SV)

    if linear.bias is not None:
        inc_linear.bias.data.copy_(linear.bias)

    # Incoherence-process the weight: W -> hadU^T S_U W S_V hadV (or the
    # one-sided variant when left_only).
    W = linear.weight.data.clone().to(dtype_)
    Wr = matmul_hadUt_head(matmul_hadUt_head(W.T.to(device) * SV, hadV).T * SU, hadU) if not left_only else matmul_hadUt_head(W * SU, hadU)

    # Normalize rows (or the whole matrix) to the codebook's RMS,
    # shrunk by scale_override.
    if left_only:
        Wscale = Wr.to(torch.float64).square().mean(-1).sqrt().view(-1, 1).to(dtype_) / (cb.lut.to(torch.float64).square().mean().sqrt().float() * scale_override) # (out_features, 1)
    else:
        Wscale = Wr.to(torch.float64).square().mean().sqrt().view(-1, 1).to(dtype_) / (cb.lut.to(torch.float64).square().mean().sqrt().float() * scale_override) # (1, 1)
        Wscale = Wscale.repeat(Wr.shape[0], 1) # (out_features, 1)

    Wr = Wr / Wscale
    # Rotate each Hessian group with the same input-side transform
    # (using 1/SU since SU entries are +-1).
    HRr = torch.zeros_like(HR)
    for i in range(HR.shape[-1]):
        HRr[:,:,i] = matmul_hadUt_head(matmul_hadUt_head(HR[:,:,i].to(device).contiguous() * (1./ SU), hadU).T * (1./ SU), hadU)

    inc_linear.SU.data.copy_(1./SU.to(dtype_))
    inc_linear.SV.data.copy_(1./SV.to(dtype_))
    inc_linear.Wscale.data.copy_(Wscale.view(-1))
    inc_linear.linear.weight.data.copy_(Wr.to(dtype_))
    inc_linear.rot_info = rot_info
    inc_linear.apply_rot_info()
    return inc_linear, HRr
|
| 144 |
+
|
| 145 |
+
def linear_to_tcq_linear(target_layer, hess_path, cb, scale_override=0.9, KV=4, V=2, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    """End-to-end TCQ quantization of a linear layer.

    Loads (or synthesizes) the Hessian, rotates layer and Hessian into the
    incoherent basis, quantizes with the TCQ codebook, and returns the fp16
    quantized layer together with a quant_info dict.
    """
    start = time.time()
    out_features, in_features = target_layer.weight.shape

    # Pick the Hessian source: grouped file, single file, or identity fallback.
    if ghess_key != "":
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()
    elif hess_path is not None:
        HR = load_hessian(hess_path).cuda()
    else:
        HR = torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)

    inc_layer, rotated_H = linear_to_incoherent_for_tcq(
        target_layer, cb, HR, scale_override,
        SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV,
        rot_info=rot_info, left_only=left_only,
    )
    rotated_H = rotated_H.cuda()
    inc_layer = inc_layer.cuda()

    # Scale was already folded in during the incoherent transform, hence 1.0 here.
    inc_layer, quant_info = inc_linear_to_inc_tcq_linear(
        inc_layer, rotated_H, cb,
        scale_override=1.0, td_x=16, td_y=16, KV=KV, V=V, use_hess=use_hess,
    )
    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    quant_info["time"] = time.time() - start

    return inc_layer.to(torch.float16), quant_info
|
lib/quantizer/vq_quant.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from lib.quantizer.quant_op import _INV_PERMUTE, load_hessian, load_group_hessian
|
| 3 |
+
from lib.linear import IncoherentLinear
|
| 4 |
+
from lib.utils import matmul_hadUt, matmul_hadUt_head, clean
|
| 5 |
+
from lib.utils.kmeans import kmeans_sklearn, kmeans_flash1d
|
| 6 |
+
from lib.quantizer.nuq_op import train_least_squares
|
| 7 |
+
from lib.quantizer.quant_op import pack_qweight
|
| 8 |
+
from lib.linear import VQLinearPackTensorCore
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
import time
|
| 12 |
+
def simple_vq(Wr, vec_sz, lut_bits, batch_size=256):
    """Run one k-means pass over vec_sz-sized weight groups.

    Returns a pair (P, C): P is a uint8 one-hot assignment tensor of shape
    (rows, cols // vec_sz, 2**lut_bits) and C holds the float32 centroids.
    """
    # Big LUTs make the (batch, groups, 2**bits) distance tensor huge — shrink batches.
    if lut_bits >= 12:
        batch_size = 64

    n_centroids = 2 ** lut_bits
    flat = Wr.reshape(-1, vec_sz)
    if vec_sz == 1:
        centroids = kmeans_flash1d(flat, n_centroids)
    else:
        centroids = kmeans_sklearn(flat, n_centroids)

    grouped = Wr.reshape(Wr.shape[0], -1, vec_sz)  # (rows, cols // vec_sz, vec_sz)
    n_rows = Wr.shape[0]
    n_groups = Wr.shape[1] // vec_sz

    # Nearest-centroid assignment, batched over rows to bound peak memory.
    assignments = torch.zeros(n_rows, n_groups).to(Wr.device).long()
    for start in range(0, n_rows, batch_size):
        stop = min(start + batch_size, n_rows)
        d2 = ((grouped[start:stop].unsqueeze(2) - centroids.unsqueeze(0).unsqueeze(0)) ** 2).sum(dim=-1)
        assignments[start:stop] = d2.argmin(dim=-1)

    # Expand indices to a one-hot membership tensor.
    one_hot = torch.zeros(n_rows, n_groups, n_centroids, dtype=torch.uint8).to(Wr.device)
    one_hot.scatter_(2, assignments.unsqueeze(-1), 1)

    return one_hot, centroids.to(torch.float32)
|
| 34 |
+
|
| 35 |
+
def vq_quantize_mat(Wr, HRr, Wscale, vec_sz, lut_bits, iterations=6, use_hess=True):
    """Vector-quantize a (normalized) weight matrix against a learned LUT.

    Initializes codes/centroids with k-means (`simple_vq`) and, when
    `use_hess`, refines them with Hessian-weighted alternating least squares
    (`train_least_squares`). Errors are reported in the original (rescaled)
    weight domain.

    Args:
        Wr: normalized weight matrix (rows = output features).
        HRr: rotated Hessian; must be 3D square-per-slice when use_hess.
        Wscale: per-row scale used to map back to the original domain.
        vec_sz: vector length of each quantization group.
        lut_bits: LUT size is 2**lut_bits.
        iterations: refinement iterations for train_least_squares.
        use_hess: whether to run the Hessian-weighted refinement.

    Returns:
        (packed, C, hatWr, quant_info): kernel-packed codes, the centroid
        table, the dequantized (rescaled) weights, and a stats dict.
    """
    Wr = Wr.to(torch.float64)
    if use_hess:
        assert len(HRr.shape) == 3, "HRr must be a 3D tensor"
        assert HRr.shape[0] == HRr.shape[1], "HRr must be a square matrix"
        init_P, init_centroids = simple_vq(Wr, vec_sz, lut_bits)
        P, C, log_dict = train_least_squares(
            W=Wr.detach().cpu().numpy(),
            init_P=init_P.detach().cpu().to(torch.float32).numpy(),
            init_centroids=init_centroids.detach().cpu().numpy(),
            H=HRr.permute(2, 0, 1).detach().cpu().numpy(),
            num_iterations=iterations,
        )
    else:
        init_P, init_centroids = simple_vq(Wr, vec_sz, lut_bits)
        P, C = init_P, init_centroids
    P = P.to(Wr.device)
    C = C.to(Wr.device)

    # Dequantize: pick the winning centroid per group and flatten back to rows.
    P_ind = torch.argmax(P, dim=-1)
    hatWr = C[P_ind]
    hatWr = hatWr.view(hatWr.shape[0], -1)

    # Rescale out of the normalized domain before computing errors.
    # BUGFIX: out-of-place multiply — `Wr.to(torch.float64)` above returns the
    # SAME tensor when the caller already passes float64, so the previous
    # in-place `*=` would have silently corrupted the caller's tensor.
    Wr = Wr * Wscale.view(-1, 1)
    hatWr = hatWr * Wscale.view(-1, 1)

    orig_err = (Wr - hatWr).pow(2).mean()
    err = (Wr - hatWr).pow(2).mean() / (Wr.pow(2).mean())
    print(
        f'err {err.item()} orig_err {orig_err.item()}'
    )
    quant_info = {
        "quantizer": "vq_lnq",
        "vec_sz": vec_sz,
        "lut_bits": lut_bits,
        "use_hess": use_hess,
        "iterations": iterations,
        "orig_err": orig_err.item(),
        "err": err.item(),
    }

    # pack P appropriately to kernel
    packed = pack_qweight(P, vec_sz, lut_bits)
    return packed, C, hatWr, quant_info
|
| 79 |
+
|
| 80 |
+
def inc_linear_to_inc_vq_linear(inc_linear, HRr, lut_bits=4, vec_sz=2, scale_override=0.9, use_hess=True):
    """Quantize an IncoherentLinear's inner weight with VQ and swap in a packed kernel.

    The `scale_override` factor is moved from the weight into Wscale so the
    overall product is unchanged; the inner nn.Linear is then replaced by a
    VQLinearPackTensorCore holding the packed codes and LUT.
    """
    # Fold scale_override into the weight/scale pair without changing W * Wscale.
    weight = inc_linear.linear.weight.data * scale_override
    scale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(scale)

    packed, lut, dequant_W, quant_info = vq_quantize_mat(
        weight, HRr, scale, vec_sz, lut_bits, use_hess=use_hess
    )

    out_features, in_features = weight.shape
    vq_layer = VQLinearPackTensorCore(
        in_features,
        out_features,
        lut_bits=lut_bits,
        vec_sz=vec_sz,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )
    vq_layer.qweight.data.copy_(packed)
    vq_layer.lut.data.copy_(lut.view(2 ** lut_bits, vec_sz))

    inc_linear.linear = vq_layer
    return inc_linear, quant_info
|
| 100 |
+
|
| 101 |
+
def linear_to_incoherent_for_vq(linear, HR, scale_override=0.9, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False):
    """Convert a linear layer into an IncoherentLinear prepared for VQ quantization.

    Same rotation scheme as `linear_to_incoherent_for_tcq` (±1 sign flips plus
    Hadamard-style transforms on weight and Hessian), but the scale is the
    plain per-row RMS of the rotated weight — there is no codebook term here.

    Args:
        linear: source linear module; its `weight` (and optional `bias`) are read.
        HR: Hessian tensor, indexed as HR[:, :, i] per group slice.
        scale_override: divisor folded into Wscale (larger -> smaller Wr).
        SU, SV: optional ±1 sign vectors; sampled randomly when None. SV is
            forced to all-ones when `left_only` is set.
        lnorm: accepted but never read in this function.
        hadU, hadV: transform descriptors forwarded to IncoherentLinear and
            matmul_hadUt_head.
        rot_info: stored on the result and applied via `apply_rot_info()`.
        left_only: if True, rotate only the input (right) side of the weight.

    Returns:
        (inc_linear, HRr): populated IncoherentLinear and the rotated Hessian
        (same shape as HR).
    """
    dtype_ = torch.float32
    device = linear.weight.device
    inc_linear = IncoherentLinear(linear.in_features, linear.out_features, hadU, hadV, linear.bias is not None, dtype_)
    # Sample random ±1 sign vectors when the caller did not supply them.
    if SU is None:
        SU = ((torch.randn(linear.in_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)
    if SV is None:
        SV = ((torch.randn(linear.out_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)

    if left_only:
        # No output-side rotation: neutralize the output sign flips.
        SV = torch.ones_like(SV)

    if linear.bias is not None:
        inc_linear.bias.data.copy_(linear.bias)

    W = linear.weight.data.to(dtype_)
    # Two-sided rotation (SV/hadV on the output dim, SU/hadU on the input dim),
    # or input-side only when left_only. W is only read here, never mutated.
    Wr = matmul_hadUt_head(matmul_hadUt_head(W.T.to(device) * SV, hadV).T * SU, hadU) if not left_only else matmul_hadUt_head(W * SU, hadU)
    # Per-row RMS scale of the rotated weight.
    Wscale = Wr.to(torch.float64).square().mean(-1).sqrt().view(-1, 1).to(dtype_) / scale_override

    Wr = Wr / Wscale
    # Rotate each Hessian slice with the inverse input-side signs (1/SU) and
    # hadU applied on both sides — the Hessian lives in input space.
    HRr = torch.zeros_like(HR)
    for i in range(HR.shape[-1]):
        HRr[:,:,i] = matmul_hadUt_head(matmul_hadUt_head(HR[:,:,i].to(device).contiguous() * (1./ SU), hadU).T * (1./ SU), hadU)

    # Store the inverse sign vectors so the forward pass can undo the rotation.
    inc_linear.SU.data.copy_(1./SU.to(dtype_))
    inc_linear.SV.data.copy_((1./SV).to(dtype_))
    inc_linear.Wscale.data.copy_(Wscale.view(-1))
    inc_linear.linear.weight.data.copy_(Wr.to(dtype_))
    inc_linear.rot_info = rot_info
    inc_linear.apply_rot_info()
    return inc_linear, HRr
|
| 132 |
+
|
| 133 |
+
def linear_to_vq_linear(target_layer, hess_path, scale_override=0.9, lut_bits=4, vec_sz=1, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    """End-to-end VQ quantization of a linear layer.

    Loads (or synthesizes) the Hessian, rotates layer and Hessian into the
    incoherent basis, vector-quantizes the weight, and returns the fp16
    quantized layer together with a quant_info dict.
    """
    t0 = time.time()
    out_features, in_features = target_layer.weight.shape
    if ghess_key == "":
        # Fall back to an identity Hessian when no file is given.
        HR = load_hessian(hess_path).cuda() if hess_path is not None else torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)
    else:
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()
    layer, HRr = linear_to_incoherent_for_vq(target_layer, HR, scale_override, SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV, rot_info=rot_info, left_only=left_only)
    HRr = HRr.cuda()
    layer = layer.cuda()
    # Scale was already folded in during the incoherent transform, hence 1.0 here.
    layer, quant_info = inc_linear_to_inc_vq_linear(layer, HRr, scale_override=1.0, lut_bits=lut_bits, vec_sz=vec_sz, use_hess=use_hess)

    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    # FIX: compute the elapsed time once so the recorded and printed values
    # agree (previously `time.time() - t0` was evaluated twice).
    elapsed = time.time() - t0
    quant_info["time"] = elapsed
    print("elapsed time", elapsed)
    return layer.to(torch.float16), quant_info
|