bdbj commited on
Commit
d7b2a5b
·
verified ·
1 Parent(s): 3d431a5

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. config.json +2885 -0
  3. generation_config.json +9 -0
  4. lib/__init__.py +1 -0
  5. lib/__pycache__/__init__.cpython-311.pyc +0 -0
  6. lib/__pycache__/config.cpython-311.pyc +0 -0
  7. lib/algo/__init__.py +0 -0
  8. lib/algo/__pycache__/__init__.cpython-311.pyc +0 -0
  9. lib/algo/__pycache__/ldlq.cpython-311.pyc +0 -0
  10. lib/algo/ldlq.py +203 -0
  11. lib/algo/ldlq_beam_cd.py +209 -0
  12. lib/codebook/__pycache__/bitshift.cpython-311.pyc +0 -0
  13. lib/codebook/__pycache__/vq_codebook.cpython-311.pyc +0 -0
  14. lib/codebook/bitshift.py +486 -0
  15. lib/codebook/vq_codebook.py +56 -0
  16. lib/config.py +6 -0
  17. lib/linear/__init__.py +430 -0
  18. lib/linear/__pycache__/__init__.cpython-311.pyc +0 -0
  19. lib/linear/__pycache__/comb_linear.cpython-311.pyc +0 -0
  20. lib/linear/__pycache__/incoherent_linear.cpython-311.pyc +0 -0
  21. lib/linear/__pycache__/quantized_linear.cpython-311.pyc +0 -0
  22. lib/linear/__pycache__/tcq_linear.cpython-311.pyc +0 -0
  23. lib/linear/__pycache__/vq_linear.cpython-311.pyc +0 -0
  24. lib/linear/comb_linear.py +325 -0
  25. lib/linear/incoherent_linear.py +639 -0
  26. lib/linear/quantized_linear.py +154 -0
  27. lib/linear/rotation.py +16 -0
  28. lib/linear/tcq_linear.py +122 -0
  29. lib/linear/vq_linear.py +208 -0
  30. lib/quantizer/__pycache__/comb_quant.cpython-311.pyc +0 -0
  31. lib/quantizer/__pycache__/nuq_op.cpython-311.pyc +0 -0
  32. lib/quantizer/__pycache__/pack_op.cpython-311.pyc +0 -0
  33. lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.1.nbc +0 -0
  34. lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.nbi +0 -0
  35. lib/quantizer/__pycache__/pack_op.pack_32-242.py311.1.nbc +0 -0
  36. lib/quantizer/__pycache__/pack_op.pack_32-242.py311.nbi +0 -0
  37. lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.1.nbc +3 -0
  38. lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.nbi +0 -0
  39. lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.1.nbc +0 -0
  40. lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.nbi +0 -0
  41. lib/quantizer/__pycache__/quant_op.cpython-311.pyc +0 -0
  42. lib/quantizer/__pycache__/tcq_quant.cpython-311.pyc +0 -0
  43. lib/quantizer/__pycache__/vq_quant.cpython-311.pyc +0 -0
  44. lib/quantizer/__pycache__/vq_quant_ldlq.cpython-311.pyc +0 -0
  45. lib/quantizer/comb_quant.py +201 -0
  46. lib/quantizer/nuq_op.py +431 -0
  47. lib/quantizer/pack_op.py +335 -0
  48. lib/quantizer/quant_op.py +277 -0
  49. lib/quantizer/tcq_quant.py +160 -0
  50. lib/quantizer/vq_quant.py +149 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.1.nbc filter=lfs diff=lfs merge=lfs -text
37
+ lib/utils/__pycache__/matmul_had.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
38
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,2885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "meta-llama/Llama-3.2-1B",
3
+ "architectures": [
4
+ "LlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoModelForCausalLM": "qpal_modelling_llama.QPalLlamaForCausalLM"
10
+ },
11
+ "bos_token_id": 128000,
12
+ "eos_token_id": 128001,
13
+ "head_dim": 64,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 2048,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 8192,
18
+ "max_position_embeddings": 131072,
19
+ "mlp_bias": false,
20
+ "model_type": "llama",
21
+ "num_attention_heads": 32,
22
+ "num_hidden_layers": 16,
23
+ "num_key_value_heads": 8,
24
+ "pretraining_tp": 1,
25
+ "qpal_quant_config": {
26
+ "modules": {
27
+ "model.layers.0.mlp.down_proj": {
28
+ "bias": false,
29
+ "dtype": "float32",
30
+ "hadU": 8192,
31
+ "hadV": 2048,
32
+ "in_features": 8192,
33
+ "linear": {
34
+ "KV": 7,
35
+ "L": 16,
36
+ "V": 2,
37
+ "bias": false,
38
+ "in_features": 8192,
39
+ "linear_cls": "QTIPLinearTCQ",
40
+ "linear_dtype": "float32",
41
+ "out_features": 2048,
42
+ "td_x": 16,
43
+ "td_y": 16,
44
+ "tlut_bits": 9
45
+ },
46
+ "module_type": "IncoherentLinear",
47
+ "out_features": 2048,
48
+ "rot_info": "skip_r",
49
+ "scale": 32.0
50
+ },
51
+ "model.layers.0.mlp.gate_proj": {
52
+ "bias": false,
53
+ "dtype": "float32",
54
+ "hadU": 2048,
55
+ "hadV": 8192,
56
+ "in_features": 2048,
57
+ "linear": {
58
+ "KV": 5,
59
+ "L": 16,
60
+ "V": 2,
61
+ "bias": false,
62
+ "in_features": 2048,
63
+ "linear_cls": "QTIPLinearTCQ",
64
+ "linear_dtype": "float32",
65
+ "out_features": 8192,
66
+ "td_x": 16,
67
+ "td_y": 16,
68
+ "tlut_bits": 9
69
+ },
70
+ "module_type": "IncoherentLinear",
71
+ "out_features": 8192,
72
+ "rot_info": "skip_r",
73
+ "scale": 32.0
74
+ },
75
+ "model.layers.0.mlp.up_proj": {
76
+ "bias": false,
77
+ "dtype": "float32",
78
+ "hadU": 2048,
79
+ "hadV": 8192,
80
+ "in_features": 2048,
81
+ "linear": {
82
+ "KV": 6,
83
+ "L": 16,
84
+ "V": 2,
85
+ "bias": false,
86
+ "in_features": 2048,
87
+ "linear_cls": "QTIPLinearTCQ",
88
+ "linear_dtype": "float32",
89
+ "out_features": 8192,
90
+ "td_x": 16,
91
+ "td_y": 16,
92
+ "tlut_bits": 9
93
+ },
94
+ "module_type": "IncoherentLinear",
95
+ "out_features": 8192,
96
+ "rot_info": "skip_r",
97
+ "scale": 32.0
98
+ },
99
+ "model.layers.0.self_attn.k_proj": {
100
+ "bias": false,
101
+ "dtype": "float32",
102
+ "hadU": 2048,
103
+ "hadV": 512,
104
+ "in_features": 2048,
105
+ "linear": {
106
+ "KV": 7,
107
+ "L": 16,
108
+ "V": 2,
109
+ "bias": false,
110
+ "in_features": 2048,
111
+ "linear_cls": "QTIPLinearTCQ",
112
+ "linear_dtype": "float32",
113
+ "out_features": 512,
114
+ "td_x": 16,
115
+ "td_y": 16,
116
+ "tlut_bits": 9
117
+ },
118
+ "module_type": "IncoherentLinear",
119
+ "out_features": 512,
120
+ "rot_info": "skip_r",
121
+ "scale": 32.0
122
+ },
123
+ "model.layers.0.self_attn.o_proj": {
124
+ "bias": false,
125
+ "dtype": "float32",
126
+ "hadU": 2048,
127
+ "hadV": 2048,
128
+ "in_features": 2048,
129
+ "linear": {
130
+ "KV": 7,
131
+ "L": 16,
132
+ "V": 2,
133
+ "bias": false,
134
+ "in_features": 2048,
135
+ "linear_cls": "QTIPLinearTCQ",
136
+ "linear_dtype": "float32",
137
+ "out_features": 2048,
138
+ "td_x": 16,
139
+ "td_y": 16,
140
+ "tlut_bits": 9
141
+ },
142
+ "module_type": "IncoherentLinear",
143
+ "out_features": 2048,
144
+ "rot_info": "skip_r",
145
+ "scale": 32.0
146
+ },
147
+ "model.layers.0.self_attn.q_proj": {
148
+ "bias": false,
149
+ "dtype": "float32",
150
+ "hadU": 2048,
151
+ "hadV": 2048,
152
+ "in_features": 2048,
153
+ "linear": {
154
+ "KV": 4,
155
+ "L": 16,
156
+ "V": 2,
157
+ "bias": false,
158
+ "in_features": 2048,
159
+ "linear_cls": "QTIPLinearTCQ",
160
+ "linear_dtype": "float32",
161
+ "out_features": 2048,
162
+ "td_x": 16,
163
+ "td_y": 16,
164
+ "tlut_bits": 9
165
+ },
166
+ "module_type": "IncoherentLinear",
167
+ "out_features": 2048,
168
+ "rot_info": "skip_r",
169
+ "scale": 32.0
170
+ },
171
+ "model.layers.0.self_attn.v_proj": {
172
+ "bias": false,
173
+ "dtype": "float32",
174
+ "hadU": 2048,
175
+ "hadV": 512,
176
+ "in_features": 2048,
177
+ "linear": {
178
+ "KV": [
179
+ 9,
180
+ 10
181
+ ],
182
+ "L": 16,
183
+ "V": 2,
184
+ "bias": false,
185
+ "in_features": 2048,
186
+ "in_part": [
187
+ 1024,
188
+ 1024
189
+ ],
190
+ "linear_cls": "CombtLinearTCQ",
191
+ "linear_dtype": "float32",
192
+ "out_features": 512,
193
+ "td_x": 16,
194
+ "td_y": 16,
195
+ "tlut_bits": 11
196
+ },
197
+ "module_type": "IncoherentLinear",
198
+ "out_features": 512,
199
+ "rot_info": "skip_r",
200
+ "scale": 32.0
201
+ },
202
+ "model.layers.1.mlp.down_proj": {
203
+ "bias": false,
204
+ "dtype": "float32",
205
+ "hadU": 8192,
206
+ "hadV": 2048,
207
+ "in_features": 8192,
208
+ "linear": {
209
+ "KV": 10,
210
+ "L": 16,
211
+ "V": 2,
212
+ "bias": false,
213
+ "in_features": 8192,
214
+ "linear_cls": "QTIPLinearTCQ",
215
+ "linear_dtype": "float32",
216
+ "out_features": 2048,
217
+ "td_x": 16,
218
+ "td_y": 16,
219
+ "tlut_bits": 11
220
+ },
221
+ "module_type": "IncoherentLinear",
222
+ "out_features": 2048,
223
+ "rot_info": "skip_r",
224
+ "scale": 32.0
225
+ },
226
+ "model.layers.1.mlp.gate_proj": {
227
+ "bias": false,
228
+ "dtype": "float32",
229
+ "hadU": 2048,
230
+ "hadV": 8192,
231
+ "in_features": 2048,
232
+ "linear": {
233
+ "KV": 5,
234
+ "L": 16,
235
+ "V": 2,
236
+ "bias": false,
237
+ "in_features": 2048,
238
+ "linear_cls": "QTIPLinearTCQ",
239
+ "linear_dtype": "float32",
240
+ "out_features": 8192,
241
+ "td_x": 16,
242
+ "td_y": 16,
243
+ "tlut_bits": 9
244
+ },
245
+ "module_type": "IncoherentLinear",
246
+ "out_features": 8192,
247
+ "rot_info": "skip_r",
248
+ "scale": 32.0
249
+ },
250
+ "model.layers.1.mlp.up_proj": {
251
+ "bias": false,
252
+ "dtype": "float32",
253
+ "hadU": 2048,
254
+ "hadV": 8192,
255
+ "in_features": 2048,
256
+ "linear": {
257
+ "KV": 5,
258
+ "L": 16,
259
+ "V": 2,
260
+ "bias": false,
261
+ "in_features": 2048,
262
+ "linear_cls": "QTIPLinearTCQ",
263
+ "linear_dtype": "float32",
264
+ "out_features": 8192,
265
+ "td_x": 16,
266
+ "td_y": 16,
267
+ "tlut_bits": 9
268
+ },
269
+ "module_type": "IncoherentLinear",
270
+ "out_features": 8192,
271
+ "rot_info": "skip_r",
272
+ "scale": 32.0
273
+ },
274
+ "model.layers.1.self_attn.k_proj": {
275
+ "bias": false,
276
+ "dtype": "float32",
277
+ "hadU": 2048,
278
+ "hadV": 512,
279
+ "in_features": 2048,
280
+ "linear": {
281
+ "KV": 9,
282
+ "L": 16,
283
+ "V": 2,
284
+ "bias": false,
285
+ "in_features": 2048,
286
+ "linear_cls": "QTIPLinearTCQ",
287
+ "linear_dtype": "float32",
288
+ "out_features": 512,
289
+ "td_x": 16,
290
+ "td_y": 16,
291
+ "tlut_bits": 10
292
+ },
293
+ "module_type": "IncoherentLinear",
294
+ "out_features": 512,
295
+ "rot_info": "skip_r",
296
+ "scale": 32.0
297
+ },
298
+ "model.layers.1.self_attn.o_proj": {
299
+ "bias": false,
300
+ "dtype": "float32",
301
+ "hadU": 2048,
302
+ "hadV": 2048,
303
+ "in_features": 2048,
304
+ "linear": {
305
+ "KV": 7,
306
+ "L": 16,
307
+ "V": 2,
308
+ "bias": false,
309
+ "in_features": 2048,
310
+ "linear_cls": "QTIPLinearTCQ",
311
+ "linear_dtype": "float32",
312
+ "out_features": 2048,
313
+ "td_x": 16,
314
+ "td_y": 16,
315
+ "tlut_bits": 9
316
+ },
317
+ "module_type": "IncoherentLinear",
318
+ "out_features": 2048,
319
+ "rot_info": "skip_r",
320
+ "scale": 32.0
321
+ },
322
+ "model.layers.1.self_attn.q_proj": {
323
+ "bias": false,
324
+ "dtype": "float32",
325
+ "hadU": 2048,
326
+ "hadV": 2048,
327
+ "in_features": 2048,
328
+ "linear": {
329
+ "KV": 7,
330
+ "L": 16,
331
+ "V": 2,
332
+ "bias": false,
333
+ "in_features": 2048,
334
+ "linear_cls": "QTIPLinearTCQ",
335
+ "linear_dtype": "float32",
336
+ "out_features": 2048,
337
+ "td_x": 16,
338
+ "td_y": 16,
339
+ "tlut_bits": 9
340
+ },
341
+ "module_type": "IncoherentLinear",
342
+ "out_features": 2048,
343
+ "rot_info": "skip_r",
344
+ "scale": 32.0
345
+ },
346
+ "model.layers.1.self_attn.v_proj": {
347
+ "bias": false,
348
+ "dtype": "float32",
349
+ "hadU": 2048,
350
+ "hadV": 512,
351
+ "in_features": 2048,
352
+ "linear": {
353
+ "KV": [
354
+ 9,
355
+ 10
356
+ ],
357
+ "L": 16,
358
+ "V": 2,
359
+ "bias": false,
360
+ "in_features": 2048,
361
+ "in_part": [
362
+ 1024,
363
+ 1024
364
+ ],
365
+ "linear_cls": "CombtLinearTCQ",
366
+ "linear_dtype": "float32",
367
+ "out_features": 512,
368
+ "td_x": 16,
369
+ "td_y": 16,
370
+ "tlut_bits": 11
371
+ },
372
+ "module_type": "IncoherentLinear",
373
+ "out_features": 512,
374
+ "rot_info": "skip_r",
375
+ "scale": 32.0
376
+ },
377
+ "model.layers.10.mlp.down_proj": {
378
+ "bias": false,
379
+ "dtype": "float32",
380
+ "hadU": 8192,
381
+ "hadV": 2048,
382
+ "in_features": 8192,
383
+ "linear": {
384
+ "KV": 7,
385
+ "L": 16,
386
+ "V": 2,
387
+ "bias": false,
388
+ "in_features": 8192,
389
+ "linear_cls": "QTIPLinearTCQ",
390
+ "linear_dtype": "float32",
391
+ "out_features": 2048,
392
+ "td_x": 16,
393
+ "td_y": 16,
394
+ "tlut_bits": 9
395
+ },
396
+ "module_type": "IncoherentLinear",
397
+ "out_features": 2048,
398
+ "rot_info": "skip_r",
399
+ "scale": 32.0
400
+ },
401
+ "model.layers.10.mlp.gate_proj": {
402
+ "bias": false,
403
+ "dtype": "float32",
404
+ "hadU": 2048,
405
+ "hadV": 8192,
406
+ "in_features": 2048,
407
+ "linear": {
408
+ "KV": 6,
409
+ "L": 16,
410
+ "V": 2,
411
+ "bias": false,
412
+ "in_features": 2048,
413
+ "linear_cls": "QTIPLinearTCQ",
414
+ "linear_dtype": "float32",
415
+ "out_features": 8192,
416
+ "td_x": 16,
417
+ "td_y": 16,
418
+ "tlut_bits": 9
419
+ },
420
+ "module_type": "IncoherentLinear",
421
+ "out_features": 8192,
422
+ "rot_info": "skip_r",
423
+ "scale": 32.0
424
+ },
425
+ "model.layers.10.mlp.up_proj": {
426
+ "bias": false,
427
+ "dtype": "float32",
428
+ "hadU": 2048,
429
+ "hadV": 8192,
430
+ "in_features": 2048,
431
+ "linear": {
432
+ "KV": 6,
433
+ "L": 16,
434
+ "V": 2,
435
+ "bias": false,
436
+ "in_features": 2048,
437
+ "linear_cls": "QTIPLinearTCQ",
438
+ "linear_dtype": "float32",
439
+ "out_features": 8192,
440
+ "td_x": 16,
441
+ "td_y": 16,
442
+ "tlut_bits": 9
443
+ },
444
+ "module_type": "IncoherentLinear",
445
+ "out_features": 8192,
446
+ "rot_info": "skip_r",
447
+ "scale": 32.0
448
+ },
449
+ "model.layers.10.self_attn.k_proj": {
450
+ "bias": false,
451
+ "dtype": "float32",
452
+ "hadU": 2048,
453
+ "hadV": 512,
454
+ "in_features": 2048,
455
+ "linear": {
456
+ "KV": 9,
457
+ "L": 16,
458
+ "V": 2,
459
+ "bias": false,
460
+ "in_features": 2048,
461
+ "linear_cls": "QTIPLinearTCQ",
462
+ "linear_dtype": "float32",
463
+ "out_features": 512,
464
+ "td_x": 16,
465
+ "td_y": 16,
466
+ "tlut_bits": 10
467
+ },
468
+ "module_type": "IncoherentLinear",
469
+ "out_features": 512,
470
+ "rot_info": "skip_r",
471
+ "scale": 32.0
472
+ },
473
+ "model.layers.10.self_attn.o_proj": {
474
+ "bias": false,
475
+ "dtype": "float32",
476
+ "hadU": 2048,
477
+ "hadV": 2048,
478
+ "in_features": 2048,
479
+ "linear": {
480
+ "KV": 8,
481
+ "L": 16,
482
+ "V": 2,
483
+ "bias": false,
484
+ "in_features": 2048,
485
+ "linear_cls": "QTIPLinearTCQ",
486
+ "linear_dtype": "float32",
487
+ "out_features": 2048,
488
+ "td_x": 16,
489
+ "td_y": 16,
490
+ "tlut_bits": 9
491
+ },
492
+ "module_type": "IncoherentLinear",
493
+ "out_features": 2048,
494
+ "rot_info": "skip_r",
495
+ "scale": 32.0
496
+ },
497
+ "model.layers.10.self_attn.q_proj": {
498
+ "bias": false,
499
+ "dtype": "float32",
500
+ "hadU": 2048,
501
+ "hadV": 2048,
502
+ "in_features": 2048,
503
+ "linear": {
504
+ "KV": 8,
505
+ "L": 16,
506
+ "V": 2,
507
+ "bias": false,
508
+ "in_features": 2048,
509
+ "linear_cls": "QTIPLinearTCQ",
510
+ "linear_dtype": "float32",
511
+ "out_features": 2048,
512
+ "td_x": 16,
513
+ "td_y": 16,
514
+ "tlut_bits": 9
515
+ },
516
+ "module_type": "IncoherentLinear",
517
+ "out_features": 2048,
518
+ "rot_info": "skip_r",
519
+ "scale": 32.0
520
+ },
521
+ "model.layers.10.self_attn.v_proj": {
522
+ "bias": false,
523
+ "dtype": "float32",
524
+ "hadU": 2048,
525
+ "hadV": 512,
526
+ "in_features": 2048,
527
+ "linear": {
528
+ "KV": [
529
+ 9,
530
+ 10
531
+ ],
532
+ "L": 16,
533
+ "V": 2,
534
+ "bias": false,
535
+ "in_features": 2048,
536
+ "in_part": [
537
+ 1024,
538
+ 1024
539
+ ],
540
+ "linear_cls": "CombtLinearTCQ",
541
+ "linear_dtype": "float32",
542
+ "out_features": 512,
543
+ "td_x": 16,
544
+ "td_y": 16,
545
+ "tlut_bits": 11
546
+ },
547
+ "module_type": "IncoherentLinear",
548
+ "out_features": 512,
549
+ "rot_info": "skip_r",
550
+ "scale": 32.0
551
+ },
552
+ "model.layers.11.mlp.down_proj": {
553
+ "bias": false,
554
+ "dtype": "float32",
555
+ "hadU": 8192,
556
+ "hadV": 2048,
557
+ "in_features": 8192,
558
+ "linear": {
559
+ "KV": 7,
560
+ "L": 16,
561
+ "V": 2,
562
+ "bias": false,
563
+ "in_features": 8192,
564
+ "linear_cls": "QTIPLinearTCQ",
565
+ "linear_dtype": "float32",
566
+ "out_features": 2048,
567
+ "td_x": 16,
568
+ "td_y": 16,
569
+ "tlut_bits": 9
570
+ },
571
+ "module_type": "IncoherentLinear",
572
+ "out_features": 2048,
573
+ "rot_info": "skip_r",
574
+ "scale": 32.0
575
+ },
576
+ "model.layers.11.mlp.gate_proj": {
577
+ "bias": false,
578
+ "dtype": "float32",
579
+ "hadU": 2048,
580
+ "hadV": 8192,
581
+ "in_features": 2048,
582
+ "linear": {
583
+ "KV": 6,
584
+ "L": 16,
585
+ "V": 2,
586
+ "bias": false,
587
+ "in_features": 2048,
588
+ "linear_cls": "QTIPLinearTCQ",
589
+ "linear_dtype": "float32",
590
+ "out_features": 8192,
591
+ "td_x": 16,
592
+ "td_y": 16,
593
+ "tlut_bits": 9
594
+ },
595
+ "module_type": "IncoherentLinear",
596
+ "out_features": 8192,
597
+ "rot_info": "skip_r",
598
+ "scale": 32.0
599
+ },
600
+ "model.layers.11.mlp.up_proj": {
601
+ "bias": false,
602
+ "dtype": "float32",
603
+ "hadU": 2048,
604
+ "hadV": 8192,
605
+ "in_features": 2048,
606
+ "linear": {
607
+ "KV": 6,
608
+ "L": 16,
609
+ "V": 2,
610
+ "bias": false,
611
+ "in_features": 2048,
612
+ "linear_cls": "QTIPLinearTCQ",
613
+ "linear_dtype": "float32",
614
+ "out_features": 8192,
615
+ "td_x": 16,
616
+ "td_y": 16,
617
+ "tlut_bits": 9
618
+ },
619
+ "module_type": "IncoherentLinear",
620
+ "out_features": 8192,
621
+ "rot_info": "skip_r",
622
+ "scale": 32.0
623
+ },
624
+ "model.layers.11.self_attn.k_proj": {
625
+ "bias": false,
626
+ "dtype": "float32",
627
+ "hadU": 2048,
628
+ "hadV": 512,
629
+ "in_features": 2048,
630
+ "linear": {
631
+ "KV": [
632
+ 8,
633
+ 9
634
+ ],
635
+ "L": 16,
636
+ "V": 2,
637
+ "bias": false,
638
+ "in_features": 2048,
639
+ "in_part": [
640
+ 1024,
641
+ 1024
642
+ ],
643
+ "linear_cls": "CombtLinearTCQ",
644
+ "linear_dtype": "float32",
645
+ "out_features": 512,
646
+ "td_x": 16,
647
+ "td_y": 16,
648
+ "tlut_bits": 10
649
+ },
650
+ "module_type": "IncoherentLinear",
651
+ "out_features": 512,
652
+ "rot_info": "skip_r",
653
+ "scale": 32.0
654
+ },
655
+ "model.layers.11.self_attn.o_proj": {
656
+ "bias": false,
657
+ "dtype": "float32",
658
+ "hadU": 2048,
659
+ "hadV": 2048,
660
+ "in_features": 2048,
661
+ "linear": {
662
+ "KV": 7,
663
+ "L": 16,
664
+ "V": 2,
665
+ "bias": false,
666
+ "in_features": 2048,
667
+ "linear_cls": "QTIPLinearTCQ",
668
+ "linear_dtype": "float32",
669
+ "out_features": 2048,
670
+ "td_x": 16,
671
+ "td_y": 16,
672
+ "tlut_bits": 9
673
+ },
674
+ "module_type": "IncoherentLinear",
675
+ "out_features": 2048,
676
+ "rot_info": "skip_r",
677
+ "scale": 32.0
678
+ },
679
+ "model.layers.11.self_attn.q_proj": {
680
+ "bias": false,
681
+ "dtype": "float32",
682
+ "hadU": 2048,
683
+ "hadV": 2048,
684
+ "in_features": 2048,
685
+ "linear": {
686
+ "KV": 7,
687
+ "L": 16,
688
+ "V": 2,
689
+ "bias": false,
690
+ "in_features": 2048,
691
+ "linear_cls": "QTIPLinearTCQ",
692
+ "linear_dtype": "float32",
693
+ "out_features": 2048,
694
+ "td_x": 16,
695
+ "td_y": 16,
696
+ "tlut_bits": 9
697
+ },
698
+ "module_type": "IncoherentLinear",
699
+ "out_features": 2048,
700
+ "rot_info": "skip_r",
701
+ "scale": 32.0
702
+ },
703
+ "model.layers.11.self_attn.v_proj": {
704
+ "bias": false,
705
+ "dtype": "float32",
706
+ "hadU": 2048,
707
+ "hadV": 512,
708
+ "in_features": 2048,
709
+ "linear": {
710
+ "KV": 9,
711
+ "L": 16,
712
+ "V": 2,
713
+ "bias": false,
714
+ "in_features": 2048,
715
+ "linear_cls": "QTIPLinearTCQ",
716
+ "linear_dtype": "float32",
717
+ "out_features": 512,
718
+ "td_x": 16,
719
+ "td_y": 16,
720
+ "tlut_bits": 10
721
+ },
722
+ "module_type": "IncoherentLinear",
723
+ "out_features": 512,
724
+ "rot_info": "skip_r",
725
+ "scale": 32.0
726
+ },
727
+ "model.layers.12.mlp.down_proj": {
728
+ "bias": false,
729
+ "dtype": "float32",
730
+ "hadU": 8192,
731
+ "hadV": 2048,
732
+ "in_features": 8192,
733
+ "linear": {
734
+ "KV": 6,
735
+ "L": 16,
736
+ "V": 2,
737
+ "bias": false,
738
+ "in_features": 8192,
739
+ "linear_cls": "QTIPLinearTCQ",
740
+ "linear_dtype": "float32",
741
+ "out_features": 2048,
742
+ "td_x": 16,
743
+ "td_y": 16,
744
+ "tlut_bits": 9
745
+ },
746
+ "module_type": "IncoherentLinear",
747
+ "out_features": 2048,
748
+ "rot_info": "skip_r",
749
+ "scale": 32.0
750
+ },
751
+ "model.layers.12.mlp.gate_proj": {
752
+ "bias": false,
753
+ "dtype": "float32",
754
+ "hadU": 2048,
755
+ "hadV": 8192,
756
+ "in_features": 2048,
757
+ "linear": {
758
+ "KV": 6,
759
+ "L": 16,
760
+ "V": 2,
761
+ "bias": false,
762
+ "in_features": 2048,
763
+ "linear_cls": "QTIPLinearTCQ",
764
+ "linear_dtype": "float32",
765
+ "out_features": 8192,
766
+ "td_x": 16,
767
+ "td_y": 16,
768
+ "tlut_bits": 9
769
+ },
770
+ "module_type": "IncoherentLinear",
771
+ "out_features": 8192,
772
+ "rot_info": "skip_r",
773
+ "scale": 32.0
774
+ },
775
+ "model.layers.12.mlp.up_proj": {
776
+ "bias": false,
777
+ "dtype": "float32",
778
+ "hadU": 2048,
779
+ "hadV": 8192,
780
+ "in_features": 2048,
781
+ "linear": {
782
+ "KV": 6,
783
+ "L": 16,
784
+ "V": 2,
785
+ "bias": false,
786
+ "in_features": 2048,
787
+ "linear_cls": "QTIPLinearTCQ",
788
+ "linear_dtype": "float32",
789
+ "out_features": 8192,
790
+ "td_x": 16,
791
+ "td_y": 16,
792
+ "tlut_bits": 9
793
+ },
794
+ "module_type": "IncoherentLinear",
795
+ "out_features": 8192,
796
+ "rot_info": "skip_r",
797
+ "scale": 32.0
798
+ },
799
+ "model.layers.12.self_attn.k_proj": {
800
+ "bias": false,
801
+ "dtype": "float32",
802
+ "hadU": 2048,
803
+ "hadV": 512,
804
+ "in_features": 2048,
805
+ "linear": {
806
+ "KV": 9,
807
+ "L": 16,
808
+ "V": 2,
809
+ "bias": false,
810
+ "in_features": 2048,
811
+ "linear_cls": "QTIPLinearTCQ",
812
+ "linear_dtype": "float32",
813
+ "out_features": 512,
814
+ "td_x": 16,
815
+ "td_y": 16,
816
+ "tlut_bits": 10
817
+ },
818
+ "module_type": "IncoherentLinear",
819
+ "out_features": 512,
820
+ "rot_info": "skip_r",
821
+ "scale": 32.0
822
+ },
823
+ "model.layers.12.self_attn.o_proj": {
824
+ "bias": false,
825
+ "dtype": "float32",
826
+ "hadU": 2048,
827
+ "hadV": 2048,
828
+ "in_features": 2048,
829
+ "linear": {
830
+ "KV": 7,
831
+ "L": 16,
832
+ "V": 2,
833
+ "bias": false,
834
+ "in_features": 2048,
835
+ "linear_cls": "QTIPLinearTCQ",
836
+ "linear_dtype": "float32",
837
+ "out_features": 2048,
838
+ "td_x": 16,
839
+ "td_y": 16,
840
+ "tlut_bits": 9
841
+ },
842
+ "module_type": "IncoherentLinear",
843
+ "out_features": 2048,
844
+ "rot_info": "skip_r",
845
+ "scale": 32.0
846
+ },
847
+ "model.layers.12.self_attn.q_proj": {
848
+ "bias": false,
849
+ "dtype": "float32",
850
+ "hadU": 2048,
851
+ "hadV": 2048,
852
+ "in_features": 2048,
853
+ "linear": {
854
+ "KV": 7,
855
+ "L": 16,
856
+ "V": 2,
857
+ "bias": false,
858
+ "in_features": 2048,
859
+ "linear_cls": "QTIPLinearTCQ",
860
+ "linear_dtype": "float32",
861
+ "out_features": 2048,
862
+ "td_x": 16,
863
+ "td_y": 16,
864
+ "tlut_bits": 9
865
+ },
866
+ "module_type": "IncoherentLinear",
867
+ "out_features": 2048,
868
+ "rot_info": "skip_r",
869
+ "scale": 32.0
870
+ },
871
+ "model.layers.12.self_attn.v_proj": {
872
+ "bias": false,
873
+ "dtype": "float32",
874
+ "hadU": 2048,
875
+ "hadV": 512,
876
+ "in_features": 2048,
877
+ "linear": {
878
+ "KV": 9,
879
+ "L": 16,
880
+ "V": 2,
881
+ "bias": false,
882
+ "in_features": 2048,
883
+ "linear_cls": "QTIPLinearTCQ",
884
+ "linear_dtype": "float32",
885
+ "out_features": 512,
886
+ "td_x": 16,
887
+ "td_y": 16,
888
+ "tlut_bits": 10
889
+ },
890
+ "module_type": "IncoherentLinear",
891
+ "out_features": 512,
892
+ "rot_info": "skip_r",
893
+ "scale": 32.0
894
+ },
895
+ "model.layers.13.mlp.down_proj": {
896
+ "bias": false,
897
+ "dtype": "float32",
898
+ "hadU": 8192,
899
+ "hadV": 2048,
900
+ "in_features": 8192,
901
+ "linear": {
902
+ "KV": 6,
903
+ "L": 16,
904
+ "V": 2,
905
+ "bias": false,
906
+ "in_features": 8192,
907
+ "linear_cls": "QTIPLinearTCQ",
908
+ "linear_dtype": "float32",
909
+ "out_features": 2048,
910
+ "td_x": 16,
911
+ "td_y": 16,
912
+ "tlut_bits": 9
913
+ },
914
+ "module_type": "IncoherentLinear",
915
+ "out_features": 2048,
916
+ "rot_info": "skip_r",
917
+ "scale": 32.0
918
+ },
919
+ "model.layers.13.mlp.gate_proj": {
920
+ "bias": false,
921
+ "dtype": "float32",
922
+ "hadU": 2048,
923
+ "hadV": 8192,
924
+ "in_features": 2048,
925
+ "linear": {
926
+ "KV": 6,
927
+ "L": 16,
928
+ "V": 2,
929
+ "bias": false,
930
+ "in_features": 2048,
931
+ "linear_cls": "QTIPLinearTCQ",
932
+ "linear_dtype": "float32",
933
+ "out_features": 8192,
934
+ "td_x": 16,
935
+ "td_y": 16,
936
+ "tlut_bits": 9
937
+ },
938
+ "module_type": "IncoherentLinear",
939
+ "out_features": 8192,
940
+ "rot_info": "skip_r",
941
+ "scale": 32.0
942
+ },
943
+ "model.layers.13.mlp.up_proj": {
944
+ "bias": false,
945
+ "dtype": "float32",
946
+ "hadU": 2048,
947
+ "hadV": 8192,
948
+ "in_features": 2048,
949
+ "linear": {
950
+ "KV": 6,
951
+ "L": 16,
952
+ "V": 2,
953
+ "bias": false,
954
+ "in_features": 2048,
955
+ "linear_cls": "QTIPLinearTCQ",
956
+ "linear_dtype": "float32",
957
+ "out_features": 8192,
958
+ "td_x": 16,
959
+ "td_y": 16,
960
+ "tlut_bits": 9
961
+ },
962
+ "module_type": "IncoherentLinear",
963
+ "out_features": 8192,
964
+ "rot_info": "skip_r",
965
+ "scale": 32.0
966
+ },
967
+ "model.layers.13.self_attn.k_proj": {
968
+ "bias": false,
969
+ "dtype": "float32",
970
+ "hadU": 2048,
971
+ "hadV": 512,
972
+ "in_features": 2048,
973
+ "linear": {
974
+ "KV": 8,
975
+ "L": 16,
976
+ "V": 2,
977
+ "bias": false,
978
+ "in_features": 2048,
979
+ "linear_cls": "QTIPLinearTCQ",
980
+ "linear_dtype": "float32",
981
+ "out_features": 512,
982
+ "td_x": 16,
983
+ "td_y": 16,
984
+ "tlut_bits": 9
985
+ },
986
+ "module_type": "IncoherentLinear",
987
+ "out_features": 512,
988
+ "rot_info": "skip_r",
989
+ "scale": 32.0
990
+ },
991
+ "model.layers.13.self_attn.o_proj": {
992
+ "bias": false,
993
+ "dtype": "float32",
994
+ "hadU": 2048,
995
+ "hadV": 2048,
996
+ "in_features": 2048,
997
+ "linear": {
998
+ "KV": 7,
999
+ "L": 16,
1000
+ "V": 2,
1001
+ "bias": false,
1002
+ "in_features": 2048,
1003
+ "linear_cls": "QTIPLinearTCQ",
1004
+ "linear_dtype": "float32",
1005
+ "out_features": 2048,
1006
+ "td_x": 16,
1007
+ "td_y": 16,
1008
+ "tlut_bits": 9
1009
+ },
1010
+ "module_type": "IncoherentLinear",
1011
+ "out_features": 2048,
1012
+ "rot_info": "skip_r",
1013
+ "scale": 32.0
1014
+ },
1015
+ "model.layers.13.self_attn.q_proj": {
1016
+ "bias": false,
1017
+ "dtype": "float32",
1018
+ "hadU": 2048,
1019
+ "hadV": 2048,
1020
+ "in_features": 2048,
1021
+ "linear": {
1022
+ "KV": 8,
1023
+ "L": 16,
1024
+ "V": 2,
1025
+ "bias": false,
1026
+ "in_features": 2048,
1027
+ "linear_cls": "QTIPLinearTCQ",
1028
+ "linear_dtype": "float32",
1029
+ "out_features": 2048,
1030
+ "td_x": 16,
1031
+ "td_y": 16,
1032
+ "tlut_bits": 9
1033
+ },
1034
+ "module_type": "IncoherentLinear",
1035
+ "out_features": 2048,
1036
+ "rot_info": "skip_r",
1037
+ "scale": 32.0
1038
+ },
1039
+ "model.layers.13.self_attn.v_proj": {
1040
+ "bias": false,
1041
+ "dtype": "float32",
1042
+ "hadU": 2048,
1043
+ "hadV": 512,
1044
+ "in_features": 2048,
1045
+ "linear": {
1046
+ "KV": 9,
1047
+ "L": 16,
1048
+ "V": 2,
1049
+ "bias": false,
1050
+ "in_features": 2048,
1051
+ "linear_cls": "QTIPLinearTCQ",
1052
+ "linear_dtype": "float32",
1053
+ "out_features": 512,
1054
+ "td_x": 16,
1055
+ "td_y": 16,
1056
+ "tlut_bits": 10
1057
+ },
1058
+ "module_type": "IncoherentLinear",
1059
+ "out_features": 512,
1060
+ "rot_info": "skip_r",
1061
+ "scale": 32.0
1062
+ },
1063
+ "model.layers.14.mlp.down_proj": {
1064
+ "bias": false,
1065
+ "dtype": "float32",
1066
+ "hadU": 8192,
1067
+ "hadV": 2048,
1068
+ "in_features": 8192,
1069
+ "linear": {
1070
+ "KV": 7,
1071
+ "L": 16,
1072
+ "V": 2,
1073
+ "bias": false,
1074
+ "in_features": 8192,
1075
+ "linear_cls": "QTIPLinearTCQ",
1076
+ "linear_dtype": "float32",
1077
+ "out_features": 2048,
1078
+ "td_x": 16,
1079
+ "td_y": 16,
1080
+ "tlut_bits": 9
1081
+ },
1082
+ "module_type": "IncoherentLinear",
1083
+ "out_features": 2048,
1084
+ "rot_info": "skip_r",
1085
+ "scale": 32.0
1086
+ },
1087
+ "model.layers.14.mlp.gate_proj": {
1088
+ "bias": false,
1089
+ "dtype": "float32",
1090
+ "hadU": 2048,
1091
+ "hadV": 8192,
1092
+ "in_features": 2048,
1093
+ "linear": {
1094
+ "KV": 6,
1095
+ "L": 16,
1096
+ "V": 2,
1097
+ "bias": false,
1098
+ "in_features": 2048,
1099
+ "linear_cls": "QTIPLinearTCQ",
1100
+ "linear_dtype": "float32",
1101
+ "out_features": 8192,
1102
+ "td_x": 16,
1103
+ "td_y": 16,
1104
+ "tlut_bits": 9
1105
+ },
1106
+ "module_type": "IncoherentLinear",
1107
+ "out_features": 8192,
1108
+ "rot_info": "skip_r",
1109
+ "scale": 32.0
1110
+ },
1111
+ "model.layers.14.mlp.up_proj": {
1112
+ "bias": false,
1113
+ "dtype": "float32",
1114
+ "hadU": 2048,
1115
+ "hadV": 8192,
1116
+ "in_features": 2048,
1117
+ "linear": {
1118
+ "KV": 6,
1119
+ "L": 16,
1120
+ "V": 2,
1121
+ "bias": false,
1122
+ "in_features": 2048,
1123
+ "linear_cls": "QTIPLinearTCQ",
1124
+ "linear_dtype": "float32",
1125
+ "out_features": 8192,
1126
+ "td_x": 16,
1127
+ "td_y": 16,
1128
+ "tlut_bits": 9
1129
+ },
1130
+ "module_type": "IncoherentLinear",
1131
+ "out_features": 8192,
1132
+ "rot_info": "skip_r",
1133
+ "scale": 32.0
1134
+ },
1135
+ "model.layers.14.self_attn.k_proj": {
1136
+ "bias": false,
1137
+ "dtype": "float32",
1138
+ "hadU": 2048,
1139
+ "hadV": 512,
1140
+ "in_features": 2048,
1141
+ "linear": {
1142
+ "KV": [
1143
+ 8,
1144
+ 9
1145
+ ],
1146
+ "L": 16,
1147
+ "V": 2,
1148
+ "bias": false,
1149
+ "in_features": 2048,
1150
+ "in_part": [
1151
+ 1024,
1152
+ 1024
1153
+ ],
1154
+ "linear_cls": "CombtLinearTCQ",
1155
+ "linear_dtype": "float32",
1156
+ "out_features": 512,
1157
+ "td_x": 16,
1158
+ "td_y": 16,
1159
+ "tlut_bits": 10
1160
+ },
1161
+ "module_type": "IncoherentLinear",
1162
+ "out_features": 512,
1163
+ "rot_info": "skip_r",
1164
+ "scale": 32.0
1165
+ },
1166
+ "model.layers.14.self_attn.o_proj": {
1167
+ "bias": false,
1168
+ "dtype": "float32",
1169
+ "hadU": 2048,
1170
+ "hadV": 2048,
1171
+ "in_features": 2048,
1172
+ "linear": {
1173
+ "KV": 7,
1174
+ "L": 16,
1175
+ "V": 2,
1176
+ "bias": false,
1177
+ "in_features": 2048,
1178
+ "linear_cls": "QTIPLinearTCQ",
1179
+ "linear_dtype": "float32",
1180
+ "out_features": 2048,
1181
+ "td_x": 16,
1182
+ "td_y": 16,
1183
+ "tlut_bits": 9
1184
+ },
1185
+ "module_type": "IncoherentLinear",
1186
+ "out_features": 2048,
1187
+ "rot_info": "skip_r",
1188
+ "scale": 32.0
1189
+ },
1190
+ "model.layers.14.self_attn.q_proj": {
1191
+ "bias": false,
1192
+ "dtype": "float32",
1193
+ "hadU": 2048,
1194
+ "hadV": 2048,
1195
+ "in_features": 2048,
1196
+ "linear": {
1197
+ "KV": 6,
1198
+ "L": 16,
1199
+ "V": 2,
1200
+ "bias": false,
1201
+ "in_features": 2048,
1202
+ "linear_cls": "QTIPLinearTCQ",
1203
+ "linear_dtype": "float32",
1204
+ "out_features": 2048,
1205
+ "td_x": 16,
1206
+ "td_y": 16,
1207
+ "tlut_bits": 9
1208
+ },
1209
+ "module_type": "IncoherentLinear",
1210
+ "out_features": 2048,
1211
+ "rot_info": "skip_r",
1212
+ "scale": 32.0
1213
+ },
1214
+ "model.layers.14.self_attn.v_proj": {
1215
+ "bias": false,
1216
+ "dtype": "float32",
1217
+ "hadU": 2048,
1218
+ "hadV": 512,
1219
+ "in_features": 2048,
1220
+ "linear": {
1221
+ "KV": [
1222
+ 9,
1223
+ 10
1224
+ ],
1225
+ "L": 16,
1226
+ "V": 2,
1227
+ "bias": false,
1228
+ "in_features": 2048,
1229
+ "in_part": [
1230
+ 1024,
1231
+ 1024
1232
+ ],
1233
+ "linear_cls": "CombtLinearTCQ",
1234
+ "linear_dtype": "float32",
1235
+ "out_features": 512,
1236
+ "td_x": 16,
1237
+ "td_y": 16,
1238
+ "tlut_bits": 11
1239
+ },
1240
+ "module_type": "IncoherentLinear",
1241
+ "out_features": 512,
1242
+ "rot_info": "skip_r",
1243
+ "scale": 32.0
1244
+ },
1245
+ "model.layers.15.mlp.down_proj": {
1246
+ "bias": false,
1247
+ "dtype": "float32",
1248
+ "hadU": 8192,
1249
+ "hadV": 2048,
1250
+ "in_features": 8192,
1251
+ "linear": {
1252
+ "KV": [
1253
+ 8,
1254
+ 9
1255
+ ],
1256
+ "L": 16,
1257
+ "V": 2,
1258
+ "bias": false,
1259
+ "in_features": 8192,
1260
+ "in_part": [
1261
+ 4096,
1262
+ 4096
1263
+ ],
1264
+ "linear_cls": "CombtLinearTCQ",
1265
+ "linear_dtype": "float32",
1266
+ "out_features": 2048,
1267
+ "td_x": 16,
1268
+ "td_y": 16,
1269
+ "tlut_bits": 10
1270
+ },
1271
+ "module_type": "IncoherentLinear",
1272
+ "out_features": 2048,
1273
+ "rot_info": "skip_r",
1274
+ "scale": 32.0
1275
+ },
1276
+ "model.layers.15.mlp.gate_proj": {
1277
+ "bias": false,
1278
+ "dtype": "float32",
1279
+ "hadU": 2048,
1280
+ "hadV": 8192,
1281
+ "in_features": 2048,
1282
+ "linear": {
1283
+ "KV": 6,
1284
+ "L": 16,
1285
+ "V": 2,
1286
+ "bias": false,
1287
+ "in_features": 2048,
1288
+ "linear_cls": "QTIPLinearTCQ",
1289
+ "linear_dtype": "float32",
1290
+ "out_features": 8192,
1291
+ "td_x": 16,
1292
+ "td_y": 16,
1293
+ "tlut_bits": 9
1294
+ },
1295
+ "module_type": "IncoherentLinear",
1296
+ "out_features": 8192,
1297
+ "rot_info": "skip_r",
1298
+ "scale": 32.0
1299
+ },
1300
+ "model.layers.15.mlp.up_proj": {
1301
+ "bias": false,
1302
+ "dtype": "float32",
1303
+ "hadU": 2048,
1304
+ "hadV": 8192,
1305
+ "in_features": 2048,
1306
+ "linear": {
1307
+ "KV": 7,
1308
+ "L": 16,
1309
+ "V": 2,
1310
+ "bias": false,
1311
+ "in_features": 2048,
1312
+ "linear_cls": "QTIPLinearTCQ",
1313
+ "linear_dtype": "float32",
1314
+ "out_features": 8192,
1315
+ "td_x": 16,
1316
+ "td_y": 16,
1317
+ "tlut_bits": 9
1318
+ },
1319
+ "module_type": "IncoherentLinear",
1320
+ "out_features": 8192,
1321
+ "rot_info": "skip_r",
1322
+ "scale": 32.0
1323
+ },
1324
+ "model.layers.15.self_attn.k_proj": {
1325
+ "bias": false,
1326
+ "dtype": "float32",
1327
+ "hadU": 2048,
1328
+ "hadV": 512,
1329
+ "in_features": 2048,
1330
+ "linear": {
1331
+ "KV": [
1332
+ 8,
1333
+ 9
1334
+ ],
1335
+ "L": 16,
1336
+ "V": 2,
1337
+ "bias": false,
1338
+ "in_features": 2048,
1339
+ "in_part": [
1340
+ 1024,
1341
+ 1024
1342
+ ],
1343
+ "linear_cls": "CombtLinearTCQ",
1344
+ "linear_dtype": "float32",
1345
+ "out_features": 512,
1346
+ "td_x": 16,
1347
+ "td_y": 16,
1348
+ "tlut_bits": 10
1349
+ },
1350
+ "module_type": "IncoherentLinear",
1351
+ "out_features": 512,
1352
+ "rot_info": "skip_r",
1353
+ "scale": 32.0
1354
+ },
1355
+ "model.layers.15.self_attn.o_proj": {
1356
+ "bias": false,
1357
+ "dtype": "float32",
1358
+ "hadU": 2048,
1359
+ "hadV": 2048,
1360
+ "in_features": 2048,
1361
+ "linear": {
1362
+ "KV": [
1363
+ 8,
1364
+ 9
1365
+ ],
1366
+ "L": 16,
1367
+ "V": 2,
1368
+ "bias": false,
1369
+ "in_features": 2048,
1370
+ "in_part": [
1371
+ 1024,
1372
+ 1024
1373
+ ],
1374
+ "linear_cls": "CombtLinearTCQ",
1375
+ "linear_dtype": "float32",
1376
+ "out_features": 2048,
1377
+ "td_x": 16,
1378
+ "td_y": 16,
1379
+ "tlut_bits": 10
1380
+ },
1381
+ "module_type": "IncoherentLinear",
1382
+ "out_features": 2048,
1383
+ "rot_info": "skip_r",
1384
+ "scale": 32.0
1385
+ },
1386
+ "model.layers.15.self_attn.q_proj": {
1387
+ "bias": false,
1388
+ "dtype": "float32",
1389
+ "hadU": 2048,
1390
+ "hadV": 2048,
1391
+ "in_features": 2048,
1392
+ "linear": {
1393
+ "KV": 6,
1394
+ "L": 16,
1395
+ "V": 2,
1396
+ "bias": false,
1397
+ "in_features": 2048,
1398
+ "linear_cls": "QTIPLinearTCQ",
1399
+ "linear_dtype": "float32",
1400
+ "out_features": 2048,
1401
+ "td_x": 16,
1402
+ "td_y": 16,
1403
+ "tlut_bits": 9
1404
+ },
1405
+ "module_type": "IncoherentLinear",
1406
+ "out_features": 2048,
1407
+ "rot_info": "skip_r",
1408
+ "scale": 32.0
1409
+ },
1410
+ "model.layers.15.self_attn.v_proj": {
1411
+ "bias": false,
1412
+ "dtype": "float32",
1413
+ "hadU": 2048,
1414
+ "hadV": 512,
1415
+ "in_features": 2048,
1416
+ "linear": {
1417
+ "KV": 9,
1418
+ "L": 16,
1419
+ "V": 2,
1420
+ "bias": false,
1421
+ "in_features": 2048,
1422
+ "linear_cls": "QTIPLinearTCQ",
1423
+ "linear_dtype": "float32",
1424
+ "out_features": 512,
1425
+ "td_x": 16,
1426
+ "td_y": 16,
1427
+ "tlut_bits": 10
1428
+ },
1429
+ "module_type": "IncoherentLinear",
1430
+ "out_features": 512,
1431
+ "rot_info": "skip_r",
1432
+ "scale": 32.0
1433
+ },
1434
+ "model.layers.2.mlp.down_proj": {
1435
+ "bias": false,
1436
+ "dtype": "float32",
1437
+ "hadU": 8192,
1438
+ "hadV": 2048,
1439
+ "in_features": 8192,
1440
+ "linear": {
1441
+ "KV": 6,
1442
+ "L": 16,
1443
+ "V": 2,
1444
+ "bias": false,
1445
+ "in_features": 8192,
1446
+ "linear_cls": "QTIPLinearTCQ",
1447
+ "linear_dtype": "float32",
1448
+ "out_features": 2048,
1449
+ "td_x": 16,
1450
+ "td_y": 16,
1451
+ "tlut_bits": 9
1452
+ },
1453
+ "module_type": "IncoherentLinear",
1454
+ "out_features": 2048,
1455
+ "rot_info": "skip_r",
1456
+ "scale": 32.0
1457
+ },
1458
+ "model.layers.2.mlp.gate_proj": {
1459
+ "bias": false,
1460
+ "dtype": "float32",
1461
+ "hadU": 2048,
1462
+ "hadV": 8192,
1463
+ "in_features": 2048,
1464
+ "linear": {
1465
+ "KV": 5,
1466
+ "L": 16,
1467
+ "V": 2,
1468
+ "bias": false,
1469
+ "in_features": 2048,
1470
+ "linear_cls": "QTIPLinearTCQ",
1471
+ "linear_dtype": "float32",
1472
+ "out_features": 8192,
1473
+ "td_x": 16,
1474
+ "td_y": 16,
1475
+ "tlut_bits": 9
1476
+ },
1477
+ "module_type": "IncoherentLinear",
1478
+ "out_features": 8192,
1479
+ "rot_info": "skip_r",
1480
+ "scale": 32.0
1481
+ },
1482
+ "model.layers.2.mlp.up_proj": {
1483
+ "bias": false,
1484
+ "dtype": "float32",
1485
+ "hadU": 2048,
1486
+ "hadV": 8192,
1487
+ "in_features": 2048,
1488
+ "linear": {
1489
+ "KV": 6,
1490
+ "L": 16,
1491
+ "V": 2,
1492
+ "bias": false,
1493
+ "in_features": 2048,
1494
+ "linear_cls": "QTIPLinearTCQ",
1495
+ "linear_dtype": "float32",
1496
+ "out_features": 8192,
1497
+ "td_x": 16,
1498
+ "td_y": 16,
1499
+ "tlut_bits": 9
1500
+ },
1501
+ "module_type": "IncoherentLinear",
1502
+ "out_features": 8192,
1503
+ "rot_info": "skip_r",
1504
+ "scale": 32.0
1505
+ },
1506
+ "model.layers.2.self_attn.k_proj": {
1507
+ "bias": false,
1508
+ "dtype": "float32",
1509
+ "hadU": 2048,
1510
+ "hadV": 512,
1511
+ "in_features": 2048,
1512
+ "linear": {
1513
+ "KV": [
1514
+ 8,
1515
+ 9
1516
+ ],
1517
+ "L": 16,
1518
+ "V": 2,
1519
+ "bias": false,
1520
+ "in_features": 2048,
1521
+ "in_part": [
1522
+ 1024,
1523
+ 1024
1524
+ ],
1525
+ "linear_cls": "CombtLinearTCQ",
1526
+ "linear_dtype": "float32",
1527
+ "out_features": 512,
1528
+ "td_x": 16,
1529
+ "td_y": 16,
1530
+ "tlut_bits": 10
1531
+ },
1532
+ "module_type": "IncoherentLinear",
1533
+ "out_features": 512,
1534
+ "rot_info": "skip_r",
1535
+ "scale": 32.0
1536
+ },
1537
+ "model.layers.2.self_attn.o_proj": {
1538
+ "bias": false,
1539
+ "dtype": "float32",
1540
+ "hadU": 2048,
1541
+ "hadV": 2048,
1542
+ "in_features": 2048,
1543
+ "linear": {
1544
+ "KV": 7,
1545
+ "L": 16,
1546
+ "V": 2,
1547
+ "bias": false,
1548
+ "in_features": 2048,
1549
+ "linear_cls": "QTIPLinearTCQ",
1550
+ "linear_dtype": "float32",
1551
+ "out_features": 2048,
1552
+ "td_x": 16,
1553
+ "td_y": 16,
1554
+ "tlut_bits": 9
1555
+ },
1556
+ "module_type": "IncoherentLinear",
1557
+ "out_features": 2048,
1558
+ "rot_info": "skip_r",
1559
+ "scale": 32.0
1560
+ },
1561
+ "model.layers.2.self_attn.q_proj": {
1562
+ "bias": false,
1563
+ "dtype": "float32",
1564
+ "hadU": 2048,
1565
+ "hadV": 2048,
1566
+ "in_features": 2048,
1567
+ "linear": {
1568
+ "KV": 7,
1569
+ "L": 16,
1570
+ "V": 2,
1571
+ "bias": false,
1572
+ "in_features": 2048,
1573
+ "linear_cls": "QTIPLinearTCQ",
1574
+ "linear_dtype": "float32",
1575
+ "out_features": 2048,
1576
+ "td_x": 16,
1577
+ "td_y": 16,
1578
+ "tlut_bits": 9
1579
+ },
1580
+ "module_type": "IncoherentLinear",
1581
+ "out_features": 2048,
1582
+ "rot_info": "skip_r",
1583
+ "scale": 32.0
1584
+ },
1585
+ "model.layers.2.self_attn.v_proj": {
1586
+ "bias": false,
1587
+ "dtype": "float32",
1588
+ "hadU": 2048,
1589
+ "hadV": 512,
1590
+ "in_features": 2048,
1591
+ "linear": {
1592
+ "KV": [
1593
+ 9,
1594
+ 10
1595
+ ],
1596
+ "L": 16,
1597
+ "V": 2,
1598
+ "bias": false,
1599
+ "in_features": 2048,
1600
+ "in_part": [
1601
+ 1024,
1602
+ 1024
1603
+ ],
1604
+ "linear_cls": "CombtLinearTCQ",
1605
+ "linear_dtype": "float32",
1606
+ "out_features": 512,
1607
+ "td_x": 16,
1608
+ "td_y": 16,
1609
+ "tlut_bits": 11
1610
+ },
1611
+ "module_type": "IncoherentLinear",
1612
+ "out_features": 512,
1613
+ "rot_info": "skip_r",
1614
+ "scale": 32.0
1615
+ },
1616
+ "model.layers.3.mlp.down_proj": {
1617
+ "bias": false,
1618
+ "dtype": "float32",
1619
+ "hadU": 8192,
1620
+ "hadV": 2048,
1621
+ "in_features": 8192,
1622
+ "linear": {
1623
+ "KV": 6,
1624
+ "L": 16,
1625
+ "V": 2,
1626
+ "bias": false,
1627
+ "in_features": 8192,
1628
+ "linear_cls": "QTIPLinearTCQ",
1629
+ "linear_dtype": "float32",
1630
+ "out_features": 2048,
1631
+ "td_x": 16,
1632
+ "td_y": 16,
1633
+ "tlut_bits": 9
1634
+ },
1635
+ "module_type": "IncoherentLinear",
1636
+ "out_features": 2048,
1637
+ "rot_info": "skip_r",
1638
+ "scale": 32.0
1639
+ },
1640
+ "model.layers.3.mlp.gate_proj": {
1641
+ "bias": false,
1642
+ "dtype": "float32",
1643
+ "hadU": 2048,
1644
+ "hadV": 8192,
1645
+ "in_features": 2048,
1646
+ "linear": {
1647
+ "KV": 5,
1648
+ "L": 16,
1649
+ "V": 2,
1650
+ "bias": false,
1651
+ "in_features": 2048,
1652
+ "linear_cls": "QTIPLinearTCQ",
1653
+ "linear_dtype": "float32",
1654
+ "out_features": 8192,
1655
+ "td_x": 16,
1656
+ "td_y": 16,
1657
+ "tlut_bits": 9
1658
+ },
1659
+ "module_type": "IncoherentLinear",
1660
+ "out_features": 8192,
1661
+ "rot_info": "skip_r",
1662
+ "scale": 32.0
1663
+ },
1664
+ "model.layers.3.mlp.up_proj": {
1665
+ "bias": false,
1666
+ "dtype": "float32",
1667
+ "hadU": 2048,
1668
+ "hadV": 8192,
1669
+ "in_features": 2048,
1670
+ "linear": {
1671
+ "KV": 6,
1672
+ "L": 16,
1673
+ "V": 2,
1674
+ "bias": false,
1675
+ "in_features": 2048,
1676
+ "linear_cls": "QTIPLinearTCQ",
1677
+ "linear_dtype": "float32",
1678
+ "out_features": 8192,
1679
+ "td_x": 16,
1680
+ "td_y": 16,
1681
+ "tlut_bits": 9
1682
+ },
1683
+ "module_type": "IncoherentLinear",
1684
+ "out_features": 8192,
1685
+ "rot_info": "skip_r",
1686
+ "scale": 32.0
1687
+ },
1688
+ "model.layers.3.self_attn.k_proj": {
1689
+ "bias": false,
1690
+ "dtype": "float32",
1691
+ "hadU": 2048,
1692
+ "hadV": 512,
1693
+ "in_features": 2048,
1694
+ "linear": {
1695
+ "KV": [
1696
+ 8,
1697
+ 9
1698
+ ],
1699
+ "L": 16,
1700
+ "V": 2,
1701
+ "bias": false,
1702
+ "in_features": 2048,
1703
+ "in_part": [
1704
+ 1024,
1705
+ 1024
1706
+ ],
1707
+ "linear_cls": "CombtLinearTCQ",
1708
+ "linear_dtype": "float32",
1709
+ "out_features": 512,
1710
+ "td_x": 16,
1711
+ "td_y": 16,
1712
+ "tlut_bits": 10
1713
+ },
1714
+ "module_type": "IncoherentLinear",
1715
+ "out_features": 512,
1716
+ "rot_info": "skip_r",
1717
+ "scale": 32.0
1718
+ },
1719
+ "model.layers.3.self_attn.o_proj": {
1720
+ "bias": false,
1721
+ "dtype": "float32",
1722
+ "hadU": 2048,
1723
+ "hadV": 2048,
1724
+ "in_features": 2048,
1725
+ "linear": {
1726
+ "KV": 7,
1727
+ "L": 16,
1728
+ "V": 2,
1729
+ "bias": false,
1730
+ "in_features": 2048,
1731
+ "linear_cls": "QTIPLinearTCQ",
1732
+ "linear_dtype": "float32",
1733
+ "out_features": 2048,
1734
+ "td_x": 16,
1735
+ "td_y": 16,
1736
+ "tlut_bits": 9
1737
+ },
1738
+ "module_type": "IncoherentLinear",
1739
+ "out_features": 2048,
1740
+ "rot_info": "skip_r",
1741
+ "scale": 32.0
1742
+ },
1743
+ "model.layers.3.self_attn.q_proj": {
1744
+ "bias": false,
1745
+ "dtype": "float32",
1746
+ "hadU": 2048,
1747
+ "hadV": 2048,
1748
+ "in_features": 2048,
1749
+ "linear": {
1750
+ "KV": 7,
1751
+ "L": 16,
1752
+ "V": 2,
1753
+ "bias": false,
1754
+ "in_features": 2048,
1755
+ "linear_cls": "QTIPLinearTCQ",
1756
+ "linear_dtype": "float32",
1757
+ "out_features": 2048,
1758
+ "td_x": 16,
1759
+ "td_y": 16,
1760
+ "tlut_bits": 9
1761
+ },
1762
+ "module_type": "IncoherentLinear",
1763
+ "out_features": 2048,
1764
+ "rot_info": "skip_r",
1765
+ "scale": 32.0
1766
+ },
1767
+ "model.layers.3.self_attn.v_proj": {
1768
+ "bias": false,
1769
+ "dtype": "float32",
1770
+ "hadU": 2048,
1771
+ "hadV": 512,
1772
+ "in_features": 2048,
1773
+ "linear": {
1774
+ "KV": 10,
1775
+ "L": 16,
1776
+ "V": 2,
1777
+ "bias": false,
1778
+ "in_features": 2048,
1779
+ "linear_cls": "QTIPLinearTCQ",
1780
+ "linear_dtype": "float32",
1781
+ "out_features": 512,
1782
+ "td_x": 16,
1783
+ "td_y": 16,
1784
+ "tlut_bits": 11
1785
+ },
1786
+ "module_type": "IncoherentLinear",
1787
+ "out_features": 512,
1788
+ "rot_info": "skip_r",
1789
+ "scale": 32.0
1790
+ },
1791
+ "model.layers.4.mlp.down_proj": {
1792
+ "bias": false,
1793
+ "dtype": "float32",
1794
+ "hadU": 8192,
1795
+ "hadV": 2048,
1796
+ "in_features": 8192,
1797
+ "linear": {
1798
+ "KV": 7,
1799
+ "L": 16,
1800
+ "V": 2,
1801
+ "bias": false,
1802
+ "in_features": 8192,
1803
+ "linear_cls": "QTIPLinearTCQ",
1804
+ "linear_dtype": "float32",
1805
+ "out_features": 2048,
1806
+ "td_x": 16,
1807
+ "td_y": 16,
1808
+ "tlut_bits": 9
1809
+ },
1810
+ "module_type": "IncoherentLinear",
1811
+ "out_features": 2048,
1812
+ "rot_info": "skip_r",
1813
+ "scale": 32.0
1814
+ },
1815
+ "model.layers.4.mlp.gate_proj": {
1816
+ "bias": false,
1817
+ "dtype": "float32",
1818
+ "hadU": 2048,
1819
+ "hadV": 8192,
1820
+ "in_features": 2048,
1821
+ "linear": {
1822
+ "KV": 6,
1823
+ "L": 16,
1824
+ "V": 2,
1825
+ "bias": false,
1826
+ "in_features": 2048,
1827
+ "linear_cls": "QTIPLinearTCQ",
1828
+ "linear_dtype": "float32",
1829
+ "out_features": 8192,
1830
+ "td_x": 16,
1831
+ "td_y": 16,
1832
+ "tlut_bits": 9
1833
+ },
1834
+ "module_type": "IncoherentLinear",
1835
+ "out_features": 8192,
1836
+ "rot_info": "skip_r",
1837
+ "scale": 32.0
1838
+ },
1839
+ "model.layers.4.mlp.up_proj": {
1840
+ "bias": false,
1841
+ "dtype": "float32",
1842
+ "hadU": 2048,
1843
+ "hadV": 8192,
1844
+ "in_features": 2048,
1845
+ "linear": {
1846
+ "KV": 6,
1847
+ "L": 16,
1848
+ "V": 2,
1849
+ "bias": false,
1850
+ "in_features": 2048,
1851
+ "linear_cls": "QTIPLinearTCQ",
1852
+ "linear_dtype": "float32",
1853
+ "out_features": 8192,
1854
+ "td_x": 16,
1855
+ "td_y": 16,
1856
+ "tlut_bits": 9
1857
+ },
1858
+ "module_type": "IncoherentLinear",
1859
+ "out_features": 8192,
1860
+ "rot_info": "skip_r",
1861
+ "scale": 32.0
1862
+ },
1863
+ "model.layers.4.self_attn.k_proj": {
1864
+ "bias": false,
1865
+ "dtype": "float32",
1866
+ "hadU": 2048,
1867
+ "hadV": 512,
1868
+ "in_features": 2048,
1869
+ "linear": {
1870
+ "KV": [
1871
+ 8,
1872
+ 9
1873
+ ],
1874
+ "L": 16,
1875
+ "V": 2,
1876
+ "bias": false,
1877
+ "in_features": 2048,
1878
+ "in_part": [
1879
+ 1024,
1880
+ 1024
1881
+ ],
1882
+ "linear_cls": "CombtLinearTCQ",
1883
+ "linear_dtype": "float32",
1884
+ "out_features": 512,
1885
+ "td_x": 16,
1886
+ "td_y": 16,
1887
+ "tlut_bits": 10
1888
+ },
1889
+ "module_type": "IncoherentLinear",
1890
+ "out_features": 512,
1891
+ "rot_info": "skip_r",
1892
+ "scale": 32.0
1893
+ },
1894
+ "model.layers.4.self_attn.o_proj": {
1895
+ "bias": false,
1896
+ "dtype": "float32",
1897
+ "hadU": 2048,
1898
+ "hadV": 2048,
1899
+ "in_features": 2048,
1900
+ "linear": {
1901
+ "KV": 8,
1902
+ "L": 16,
1903
+ "V": 2,
1904
+ "bias": false,
1905
+ "in_features": 2048,
1906
+ "linear_cls": "QTIPLinearTCQ",
1907
+ "linear_dtype": "float32",
1908
+ "out_features": 2048,
1909
+ "td_x": 16,
1910
+ "td_y": 16,
1911
+ "tlut_bits": 9
1912
+ },
1913
+ "module_type": "IncoherentLinear",
1914
+ "out_features": 2048,
1915
+ "rot_info": "skip_r",
1916
+ "scale": 32.0
1917
+ },
1918
+ "model.layers.4.self_attn.q_proj": {
1919
+ "bias": false,
1920
+ "dtype": "float32",
1921
+ "hadU": 2048,
1922
+ "hadV": 2048,
1923
+ "in_features": 2048,
1924
+ "linear": {
1925
+ "KV": 7,
1926
+ "L": 16,
1927
+ "V": 2,
1928
+ "bias": false,
1929
+ "in_features": 2048,
1930
+ "linear_cls": "QTIPLinearTCQ",
1931
+ "linear_dtype": "float32",
1932
+ "out_features": 2048,
1933
+ "td_x": 16,
1934
+ "td_y": 16,
1935
+ "tlut_bits": 9
1936
+ },
1937
+ "module_type": "IncoherentLinear",
1938
+ "out_features": 2048,
1939
+ "rot_info": "skip_r",
1940
+ "scale": 32.0
1941
+ },
1942
+ "model.layers.4.self_attn.v_proj": {
1943
+ "bias": false,
1944
+ "dtype": "float32",
1945
+ "hadU": 2048,
1946
+ "hadV": 512,
1947
+ "in_features": 2048,
1948
+ "linear": {
1949
+ "KV": [
1950
+ 9,
1951
+ 10
1952
+ ],
1953
+ "L": 16,
1954
+ "V": 2,
1955
+ "bias": false,
1956
+ "in_features": 2048,
1957
+ "in_part": [
1958
+ 1024,
1959
+ 1024
1960
+ ],
1961
+ "linear_cls": "CombtLinearTCQ",
1962
+ "linear_dtype": "float32",
1963
+ "out_features": 512,
1964
+ "td_x": 16,
1965
+ "td_y": 16,
1966
+ "tlut_bits": 11
1967
+ },
1968
+ "module_type": "IncoherentLinear",
1969
+ "out_features": 512,
1970
+ "rot_info": "skip_r",
1971
+ "scale": 32.0
1972
+ },
1973
+ "model.layers.5.mlp.down_proj": {
1974
+ "bias": false,
1975
+ "dtype": "float32",
1976
+ "hadU": 8192,
1977
+ "hadV": 2048,
1978
+ "in_features": 8192,
1979
+ "linear": {
1980
+ "KV": 6,
1981
+ "L": 16,
1982
+ "V": 2,
1983
+ "bias": false,
1984
+ "in_features": 8192,
1985
+ "linear_cls": "QTIPLinearTCQ",
1986
+ "linear_dtype": "float32",
1987
+ "out_features": 2048,
1988
+ "td_x": 16,
1989
+ "td_y": 16,
1990
+ "tlut_bits": 9
1991
+ },
1992
+ "module_type": "IncoherentLinear",
1993
+ "out_features": 2048,
1994
+ "rot_info": "skip_r",
1995
+ "scale": 32.0
1996
+ },
1997
+ "model.layers.5.mlp.gate_proj": {
1998
+ "bias": false,
1999
+ "dtype": "float32",
2000
+ "hadU": 2048,
2001
+ "hadV": 8192,
2002
+ "in_features": 2048,
2003
+ "linear": {
2004
+ "KV": 6,
2005
+ "L": 16,
2006
+ "V": 2,
2007
+ "bias": false,
2008
+ "in_features": 2048,
2009
+ "linear_cls": "QTIPLinearTCQ",
2010
+ "linear_dtype": "float32",
2011
+ "out_features": 8192,
2012
+ "td_x": 16,
2013
+ "td_y": 16,
2014
+ "tlut_bits": 9
2015
+ },
2016
+ "module_type": "IncoherentLinear",
2017
+ "out_features": 8192,
2018
+ "rot_info": "skip_r",
2019
+ "scale": 32.0
2020
+ },
2021
+ "model.layers.5.mlp.up_proj": {
2022
+ "bias": false,
2023
+ "dtype": "float32",
2024
+ "hadU": 2048,
2025
+ "hadV": 8192,
2026
+ "in_features": 2048,
2027
+ "linear": {
2028
+ "KV": 6,
2029
+ "L": 16,
2030
+ "V": 2,
2031
+ "bias": false,
2032
+ "in_features": 2048,
2033
+ "linear_cls": "QTIPLinearTCQ",
2034
+ "linear_dtype": "float32",
2035
+ "out_features": 8192,
2036
+ "td_x": 16,
2037
+ "td_y": 16,
2038
+ "tlut_bits": 9
2039
+ },
2040
+ "module_type": "IncoherentLinear",
2041
+ "out_features": 8192,
2042
+ "rot_info": "skip_r",
2043
+ "scale": 32.0
2044
+ },
2045
+ "model.layers.5.self_attn.k_proj": {
2046
+ "bias": false,
2047
+ "dtype": "float32",
2048
+ "hadU": 2048,
2049
+ "hadV": 512,
2050
+ "in_features": 2048,
2051
+ "linear": {
2052
+ "KV": 9,
2053
+ "L": 16,
2054
+ "V": 2,
2055
+ "bias": false,
2056
+ "in_features": 2048,
2057
+ "linear_cls": "QTIPLinearTCQ",
2058
+ "linear_dtype": "float32",
2059
+ "out_features": 512,
2060
+ "td_x": 16,
2061
+ "td_y": 16,
2062
+ "tlut_bits": 10
2063
+ },
2064
+ "module_type": "IncoherentLinear",
2065
+ "out_features": 512,
2066
+ "rot_info": "skip_r",
2067
+ "scale": 32.0
2068
+ },
2069
+ "model.layers.5.self_attn.o_proj": {
2070
+ "bias": false,
2071
+ "dtype": "float32",
2072
+ "hadU": 2048,
2073
+ "hadV": 2048,
2074
+ "in_features": 2048,
2075
+ "linear": {
2076
+ "KV": [
2077
+ 8,
2078
+ 9
2079
+ ],
2080
+ "L": 16,
2081
+ "V": 2,
2082
+ "bias": false,
2083
+ "in_features": 2048,
2084
+ "in_part": [
2085
+ 1024,
2086
+ 1024
2087
+ ],
2088
+ "linear_cls": "CombtLinearTCQ",
2089
+ "linear_dtype": "float32",
2090
+ "out_features": 2048,
2091
+ "td_x": 16,
2092
+ "td_y": 16,
2093
+ "tlut_bits": 10
2094
+ },
2095
+ "module_type": "IncoherentLinear",
2096
+ "out_features": 2048,
2097
+ "rot_info": "skip_r",
2098
+ "scale": 32.0
2099
+ },
2100
+ "model.layers.5.self_attn.q_proj": {
2101
+ "bias": false,
2102
+ "dtype": "float32",
2103
+ "hadU": 2048,
2104
+ "hadV": 2048,
2105
+ "in_features": 2048,
2106
+ "linear": {
2107
+ "KV": 8,
2108
+ "L": 16,
2109
+ "V": 2,
2110
+ "bias": false,
2111
+ "in_features": 2048,
2112
+ "linear_cls": "QTIPLinearTCQ",
2113
+ "linear_dtype": "float32",
2114
+ "out_features": 2048,
2115
+ "td_x": 16,
2116
+ "td_y": 16,
2117
+ "tlut_bits": 9
2118
+ },
2119
+ "module_type": "IncoherentLinear",
2120
+ "out_features": 2048,
2121
+ "rot_info": "skip_r",
2122
+ "scale": 32.0
2123
+ },
2124
+ "model.layers.5.self_attn.v_proj": {
2125
+ "bias": false,
2126
+ "dtype": "float32",
2127
+ "hadU": 2048,
2128
+ "hadV": 512,
2129
+ "in_features": 2048,
2130
+ "linear": {
2131
+ "KV": [
2132
+ 9,
2133
+ 10
2134
+ ],
2135
+ "L": 16,
2136
+ "V": 2,
2137
+ "bias": false,
2138
+ "in_features": 2048,
2139
+ "in_part": [
2140
+ 1024,
2141
+ 1024
2142
+ ],
2143
+ "linear_cls": "CombtLinearTCQ",
2144
+ "linear_dtype": "float32",
2145
+ "out_features": 512,
2146
+ "td_x": 16,
2147
+ "td_y": 16,
2148
+ "tlut_bits": 11
2149
+ },
2150
+ "module_type": "IncoherentLinear",
2151
+ "out_features": 512,
2152
+ "rot_info": "skip_r",
2153
+ "scale": 32.0
2154
+ },
2155
+ "model.layers.6.mlp.down_proj": {
2156
+ "bias": false,
2157
+ "dtype": "float32",
2158
+ "hadU": 8192,
2159
+ "hadV": 2048,
2160
+ "in_features": 8192,
2161
+ "linear": {
2162
+ "KV": 6,
2163
+ "L": 16,
2164
+ "V": 2,
2165
+ "bias": false,
2166
+ "in_features": 8192,
2167
+ "linear_cls": "QTIPLinearTCQ",
2168
+ "linear_dtype": "float32",
2169
+ "out_features": 2048,
2170
+ "td_x": 16,
2171
+ "td_y": 16,
2172
+ "tlut_bits": 9
2173
+ },
2174
+ "module_type": "IncoherentLinear",
2175
+ "out_features": 2048,
2176
+ "rot_info": "skip_r",
2177
+ "scale": 32.0
2178
+ },
2179
+ "model.layers.6.mlp.gate_proj": {
2180
+ "bias": false,
2181
+ "dtype": "float32",
2182
+ "hadU": 2048,
2183
+ "hadV": 8192,
2184
+ "in_features": 2048,
2185
+ "linear": {
2186
+ "KV": 6,
2187
+ "L": 16,
2188
+ "V": 2,
2189
+ "bias": false,
2190
+ "in_features": 2048,
2191
+ "linear_cls": "QTIPLinearTCQ",
2192
+ "linear_dtype": "float32",
2193
+ "out_features": 8192,
2194
+ "td_x": 16,
2195
+ "td_y": 16,
2196
+ "tlut_bits": 9
2197
+ },
2198
+ "module_type": "IncoherentLinear",
2199
+ "out_features": 8192,
2200
+ "rot_info": "skip_r",
2201
+ "scale": 32.0
2202
+ },
2203
+ "model.layers.6.mlp.up_proj": {
2204
+ "bias": false,
2205
+ "dtype": "float32",
2206
+ "hadU": 2048,
2207
+ "hadV": 8192,
2208
+ "in_features": 2048,
2209
+ "linear": {
2210
+ "KV": 6,
2211
+ "L": 16,
2212
+ "V": 2,
2213
+ "bias": false,
2214
+ "in_features": 2048,
2215
+ "linear_cls": "QTIPLinearTCQ",
2216
+ "linear_dtype": "float32",
2217
+ "out_features": 8192,
2218
+ "td_x": 16,
2219
+ "td_y": 16,
2220
+ "tlut_bits": 9
2221
+ },
2222
+ "module_type": "IncoherentLinear",
2223
+ "out_features": 8192,
2224
+ "rot_info": "skip_r",
2225
+ "scale": 32.0
2226
+ },
2227
+ "model.layers.6.self_attn.k_proj": {
2228
+ "bias": false,
2229
+ "dtype": "float32",
2230
+ "hadU": 2048,
2231
+ "hadV": 512,
2232
+ "in_features": 2048,
2233
+ "linear": {
2234
+ "KV": [
2235
+ 8,
2236
+ 9
2237
+ ],
2238
+ "L": 16,
2239
+ "V": 2,
2240
+ "bias": false,
2241
+ "in_features": 2048,
2242
+ "in_part": [
2243
+ 1024,
2244
+ 1024
2245
+ ],
2246
+ "linear_cls": "CombtLinearTCQ",
2247
+ "linear_dtype": "float32",
2248
+ "out_features": 512,
2249
+ "td_x": 16,
2250
+ "td_y": 16,
2251
+ "tlut_bits": 10
2252
+ },
2253
+ "module_type": "IncoherentLinear",
2254
+ "out_features": 512,
2255
+ "rot_info": "skip_r",
2256
+ "scale": 32.0
2257
+ },
2258
+ "model.layers.6.self_attn.o_proj": {
2259
+ "bias": false,
2260
+ "dtype": "float32",
2261
+ "hadU": 2048,
2262
+ "hadV": 2048,
2263
+ "in_features": 2048,
2264
+ "linear": {
2265
+ "KV": 9,
2266
+ "L": 16,
2267
+ "V": 2,
2268
+ "bias": false,
2269
+ "in_features": 2048,
2270
+ "linear_cls": "QTIPLinearTCQ",
2271
+ "linear_dtype": "float32",
2272
+ "out_features": 2048,
2273
+ "td_x": 16,
2274
+ "td_y": 16,
2275
+ "tlut_bits": 10
2276
+ },
2277
+ "module_type": "IncoherentLinear",
2278
+ "out_features": 2048,
2279
+ "rot_info": "skip_r",
2280
+ "scale": 32.0
2281
+ },
2282
+ "model.layers.6.self_attn.q_proj": {
2283
+ "bias": false,
2284
+ "dtype": "float32",
2285
+ "hadU": 2048,
2286
+ "hadV": 2048,
2287
+ "in_features": 2048,
2288
+ "linear": {
2289
+ "KV": 7,
2290
+ "L": 16,
2291
+ "V": 2,
2292
+ "bias": false,
2293
+ "in_features": 2048,
2294
+ "linear_cls": "QTIPLinearTCQ",
2295
+ "linear_dtype": "float32",
2296
+ "out_features": 2048,
2297
+ "td_x": 16,
2298
+ "td_y": 16,
2299
+ "tlut_bits": 9
2300
+ },
2301
+ "module_type": "IncoherentLinear",
2302
+ "out_features": 2048,
2303
+ "rot_info": "skip_r",
2304
+ "scale": 32.0
2305
+ },
2306
+ "model.layers.6.self_attn.v_proj": {
2307
+ "bias": false,
2308
+ "dtype": "float32",
2309
+ "hadU": 2048,
2310
+ "hadV": 512,
2311
+ "in_features": 2048,
2312
+ "linear": {
2313
+ "KV": [
2314
+ 9,
2315
+ 10
2316
+ ],
2317
+ "L": 16,
2318
+ "V": 2,
2319
+ "bias": false,
2320
+ "in_features": 2048,
2321
+ "in_part": [
2322
+ 1024,
2323
+ 1024
2324
+ ],
2325
+ "linear_cls": "CombtLinearTCQ",
2326
+ "linear_dtype": "float32",
2327
+ "out_features": 512,
2328
+ "td_x": 16,
2329
+ "td_y": 16,
2330
+ "tlut_bits": 11
2331
+ },
2332
+ "module_type": "IncoherentLinear",
2333
+ "out_features": 512,
2334
+ "rot_info": "skip_r",
2335
+ "scale": 32.0
2336
+ },
2337
+ "model.layers.7.mlp.down_proj": {
2338
+ "bias": false,
2339
+ "dtype": "float32",
2340
+ "hadU": 8192,
2341
+ "hadV": 2048,
2342
+ "in_features": 8192,
2343
+ "linear": {
2344
+ "KV": [
2345
+ 6,
2346
+ 7
2347
+ ],
2348
+ "L": 16,
2349
+ "V": 2,
2350
+ "bias": false,
2351
+ "in_features": 8192,
2352
+ "in_part": [
2353
+ 4096,
2354
+ 4096
2355
+ ],
2356
+ "linear_cls": "CombtLinearTCQ",
2357
+ "linear_dtype": "float32",
2358
+ "out_features": 2048,
2359
+ "td_x": 16,
2360
+ "td_y": 16,
2361
+ "tlut_bits": 9
2362
+ },
2363
+ "module_type": "IncoherentLinear",
2364
+ "out_features": 2048,
2365
+ "rot_info": "skip_r",
2366
+ "scale": 32.0
2367
+ },
2368
+ "model.layers.7.mlp.gate_proj": {
2369
+ "bias": false,
2370
+ "dtype": "float32",
2371
+ "hadU": 2048,
2372
+ "hadV": 8192,
2373
+ "in_features": 2048,
2374
+ "linear": {
2375
+ "KV": 6,
2376
+ "L": 16,
2377
+ "V": 2,
2378
+ "bias": false,
2379
+ "in_features": 2048,
2380
+ "linear_cls": "QTIPLinearTCQ",
2381
+ "linear_dtype": "float32",
2382
+ "out_features": 8192,
2383
+ "td_x": 16,
2384
+ "td_y": 16,
2385
+ "tlut_bits": 9
2386
+ },
2387
+ "module_type": "IncoherentLinear",
2388
+ "out_features": 8192,
2389
+ "rot_info": "skip_r",
2390
+ "scale": 32.0
2391
+ },
2392
+ "model.layers.7.mlp.up_proj": {
2393
+ "bias": false,
2394
+ "dtype": "float32",
2395
+ "hadU": 2048,
2396
+ "hadV": 8192,
2397
+ "in_features": 2048,
2398
+ "linear": {
2399
+ "KV": 6,
2400
+ "L": 16,
2401
+ "V": 2,
2402
+ "bias": false,
2403
+ "in_features": 2048,
2404
+ "linear_cls": "QTIPLinearTCQ",
2405
+ "linear_dtype": "float32",
2406
+ "out_features": 8192,
2407
+ "td_x": 16,
2408
+ "td_y": 16,
2409
+ "tlut_bits": 9
2410
+ },
2411
+ "module_type": "IncoherentLinear",
2412
+ "out_features": 8192,
2413
+ "rot_info": "skip_r",
2414
+ "scale": 32.0
2415
+ },
2416
+ "model.layers.7.self_attn.k_proj": {
2417
+ "bias": false,
2418
+ "dtype": "float32",
2419
+ "hadU": 2048,
2420
+ "hadV": 512,
2421
+ "in_features": 2048,
2422
+ "linear": {
2423
+ "KV": [
2424
+ 8,
2425
+ 9
2426
+ ],
2427
+ "L": 16,
2428
+ "V": 2,
2429
+ "bias": false,
2430
+ "in_features": 2048,
2431
+ "in_part": [
2432
+ 1024,
2433
+ 1024
2434
+ ],
2435
+ "linear_cls": "CombtLinearTCQ",
2436
+ "linear_dtype": "float32",
2437
+ "out_features": 512,
2438
+ "td_x": 16,
2439
+ "td_y": 16,
2440
+ "tlut_bits": 10
2441
+ },
2442
+ "module_type": "IncoherentLinear",
2443
+ "out_features": 512,
2444
+ "rot_info": "skip_r",
2445
+ "scale": 32.0
2446
+ },
2447
+ "model.layers.7.self_attn.o_proj": {
2448
+ "bias": false,
2449
+ "dtype": "float32",
2450
+ "hadU": 2048,
2451
+ "hadV": 2048,
2452
+ "in_features": 2048,
2453
+ "linear": {
2454
+ "KV": 9,
2455
+ "L": 16,
2456
+ "V": 2,
2457
+ "bias": false,
2458
+ "in_features": 2048,
2459
+ "linear_cls": "QTIPLinearTCQ",
2460
+ "linear_dtype": "float32",
2461
+ "out_features": 2048,
2462
+ "td_x": 16,
2463
+ "td_y": 16,
2464
+ "tlut_bits": 10
2465
+ },
2466
+ "module_type": "IncoherentLinear",
2467
+ "out_features": 2048,
2468
+ "rot_info": "skip_r",
2469
+ "scale": 32.0
2470
+ },
2471
+ "model.layers.7.self_attn.q_proj": {
2472
+ "bias": false,
2473
+ "dtype": "float32",
2474
+ "hadU": 2048,
2475
+ "hadV": 2048,
2476
+ "in_features": 2048,
2477
+ "linear": {
2478
+ "KV": 7,
2479
+ "L": 16,
2480
+ "V": 2,
2481
+ "bias": false,
2482
+ "in_features": 2048,
2483
+ "linear_cls": "QTIPLinearTCQ",
2484
+ "linear_dtype": "float32",
2485
+ "out_features": 2048,
2486
+ "td_x": 16,
2487
+ "td_y": 16,
2488
+ "tlut_bits": 9
2489
+ },
2490
+ "module_type": "IncoherentLinear",
2491
+ "out_features": 2048,
2492
+ "rot_info": "skip_r",
2493
+ "scale": 32.0
2494
+ },
2495
+ "model.layers.7.self_attn.v_proj": {
2496
+ "bias": false,
2497
+ "dtype": "float32",
2498
+ "hadU": 2048,
2499
+ "hadV": 512,
2500
+ "in_features": 2048,
2501
+ "linear": {
2502
+ "KV": 10,
2503
+ "L": 16,
2504
+ "V": 2,
2505
+ "bias": false,
2506
+ "in_features": 2048,
2507
+ "linear_cls": "QTIPLinearTCQ",
2508
+ "linear_dtype": "float32",
2509
+ "out_features": 512,
2510
+ "td_x": 16,
2511
+ "td_y": 16,
2512
+ "tlut_bits": 11
2513
+ },
2514
+ "module_type": "IncoherentLinear",
2515
+ "out_features": 512,
2516
+ "rot_info": "skip_r",
2517
+ "scale": 32.0
2518
+ },
2519
+ "model.layers.8.mlp.down_proj": {
2520
+ "bias": false,
2521
+ "dtype": "float32",
2522
+ "hadU": 8192,
2523
+ "hadV": 2048,
2524
+ "in_features": 8192,
2525
+ "linear": {
2526
+ "KV": 7,
2527
+ "L": 16,
2528
+ "V": 2,
2529
+ "bias": false,
2530
+ "in_features": 8192,
2531
+ "linear_cls": "QTIPLinearTCQ",
2532
+ "linear_dtype": "float32",
2533
+ "out_features": 2048,
2534
+ "td_x": 16,
2535
+ "td_y": 16,
2536
+ "tlut_bits": 9
2537
+ },
2538
+ "module_type": "IncoherentLinear",
2539
+ "out_features": 2048,
2540
+ "rot_info": "skip_r",
2541
+ "scale": 32.0
2542
+ },
2543
+ "model.layers.8.mlp.gate_proj": {
2544
+ "bias": false,
2545
+ "dtype": "float32",
2546
+ "hadU": 2048,
2547
+ "hadV": 8192,
2548
+ "in_features": 2048,
2549
+ "linear": {
2550
+ "KV": 6,
2551
+ "L": 16,
2552
+ "V": 2,
2553
+ "bias": false,
2554
+ "in_features": 2048,
2555
+ "linear_cls": "QTIPLinearTCQ",
2556
+ "linear_dtype": "float32",
2557
+ "out_features": 8192,
2558
+ "td_x": 16,
2559
+ "td_y": 16,
2560
+ "tlut_bits": 9
2561
+ },
2562
+ "module_type": "IncoherentLinear",
2563
+ "out_features": 8192,
2564
+ "rot_info": "skip_r",
2565
+ "scale": 32.0
2566
+ },
2567
+ "model.layers.8.mlp.up_proj": {
2568
+ "bias": false,
2569
+ "dtype": "float32",
2570
+ "hadU": 2048,
2571
+ "hadV": 8192,
2572
+ "in_features": 2048,
2573
+ "linear": {
2574
+ "KV": 7,
2575
+ "L": 16,
2576
+ "V": 2,
2577
+ "bias": false,
2578
+ "in_features": 2048,
2579
+ "linear_cls": "QTIPLinearTCQ",
2580
+ "linear_dtype": "float32",
2581
+ "out_features": 8192,
2582
+ "td_x": 16,
2583
+ "td_y": 16,
2584
+ "tlut_bits": 9
2585
+ },
2586
+ "module_type": "IncoherentLinear",
2587
+ "out_features": 8192,
2588
+ "rot_info": "skip_r",
2589
+ "scale": 32.0
2590
+ },
2591
+ "model.layers.8.self_attn.k_proj": {
2592
+ "bias": false,
2593
+ "dtype": "float32",
2594
+ "hadU": 2048,
2595
+ "hadV": 512,
2596
+ "in_features": 2048,
2597
+ "linear": {
2598
+ "KV": [
2599
+ 8,
2600
+ 9
2601
+ ],
2602
+ "L": 16,
2603
+ "V": 2,
2604
+ "bias": false,
2605
+ "in_features": 2048,
2606
+ "in_part": [
2607
+ 1024,
2608
+ 1024
2609
+ ],
2610
+ "linear_cls": "CombtLinearTCQ",
2611
+ "linear_dtype": "float32",
2612
+ "out_features": 512,
2613
+ "td_x": 16,
2614
+ "td_y": 16,
2615
+ "tlut_bits": 10
2616
+ },
2617
+ "module_type": "IncoherentLinear",
2618
+ "out_features": 512,
2619
+ "rot_info": "skip_r",
2620
+ "scale": 32.0
2621
+ },
2622
+ "model.layers.8.self_attn.o_proj": {
2623
+ "bias": false,
2624
+ "dtype": "float32",
2625
+ "hadU": 2048,
2626
+ "hadV": 2048,
2627
+ "in_features": 2048,
2628
+ "linear": {
2629
+ "KV": 9,
2630
+ "L": 16,
2631
+ "V": 2,
2632
+ "bias": false,
2633
+ "in_features": 2048,
2634
+ "linear_cls": "QTIPLinearTCQ",
2635
+ "linear_dtype": "float32",
2636
+ "out_features": 2048,
2637
+ "td_x": 16,
2638
+ "td_y": 16,
2639
+ "tlut_bits": 10
2640
+ },
2641
+ "module_type": "IncoherentLinear",
2642
+ "out_features": 2048,
2643
+ "rot_info": "skip_r",
2644
+ "scale": 32.0
2645
+ },
2646
+ "model.layers.8.self_attn.q_proj": {
2647
+ "bias": false,
2648
+ "dtype": "float32",
2649
+ "hadU": 2048,
2650
+ "hadV": 2048,
2651
+ "in_features": 2048,
2652
+ "linear": {
2653
+ "KV": 7,
2654
+ "L": 16,
2655
+ "V": 2,
2656
+ "bias": false,
2657
+ "in_features": 2048,
2658
+ "linear_cls": "QTIPLinearTCQ",
2659
+ "linear_dtype": "float32",
2660
+ "out_features": 2048,
2661
+ "td_x": 16,
2662
+ "td_y": 16,
2663
+ "tlut_bits": 9
2664
+ },
2665
+ "module_type": "IncoherentLinear",
2666
+ "out_features": 2048,
2667
+ "rot_info": "skip_r",
2668
+ "scale": 32.0
2669
+ },
2670
+ "model.layers.8.self_attn.v_proj": {
2671
+ "bias": false,
2672
+ "dtype": "float32",
2673
+ "hadU": 2048,
2674
+ "hadV": 512,
2675
+ "in_features": 2048,
2676
+ "linear": {
2677
+ "KV": 10,
2678
+ "L": 16,
2679
+ "V": 2,
2680
+ "bias": false,
2681
+ "in_features": 2048,
2682
+ "linear_cls": "QTIPLinearTCQ",
2683
+ "linear_dtype": "float32",
2684
+ "out_features": 512,
2685
+ "td_x": 16,
2686
+ "td_y": 16,
2687
+ "tlut_bits": 11
2688
+ },
2689
+ "module_type": "IncoherentLinear",
2690
+ "out_features": 512,
2691
+ "rot_info": "skip_r",
2692
+ "scale": 32.0
2693
+ },
2694
+ "model.layers.9.mlp.down_proj": {
2695
+ "bias": false,
2696
+ "dtype": "float32",
2697
+ "hadU": 8192,
2698
+ "hadV": 2048,
2699
+ "in_features": 8192,
2700
+ "linear": {
2701
+ "KV": 7,
2702
+ "L": 16,
2703
+ "V": 2,
2704
+ "bias": false,
2705
+ "in_features": 8192,
2706
+ "linear_cls": "QTIPLinearTCQ",
2707
+ "linear_dtype": "float32",
2708
+ "out_features": 2048,
2709
+ "td_x": 16,
2710
+ "td_y": 16,
2711
+ "tlut_bits": 9
2712
+ },
2713
+ "module_type": "IncoherentLinear",
2714
+ "out_features": 2048,
2715
+ "rot_info": "skip_r",
2716
+ "scale": 32.0
2717
+ },
2718
+ "model.layers.9.mlp.gate_proj": {
2719
+ "bias": false,
2720
+ "dtype": "float32",
2721
+ "hadU": 2048,
2722
+ "hadV": 8192,
2723
+ "in_features": 2048,
2724
+ "linear": {
2725
+ "KV": 6,
2726
+ "L": 16,
2727
+ "V": 2,
2728
+ "bias": false,
2729
+ "in_features": 2048,
2730
+ "linear_cls": "QTIPLinearTCQ",
2731
+ "linear_dtype": "float32",
2732
+ "out_features": 8192,
2733
+ "td_x": 16,
2734
+ "td_y": 16,
2735
+ "tlut_bits": 9
2736
+ },
2737
+ "module_type": "IncoherentLinear",
2738
+ "out_features": 8192,
2739
+ "rot_info": "skip_r",
2740
+ "scale": 32.0
2741
+ },
2742
+ "model.layers.9.mlp.up_proj": {
2743
+ "bias": false,
2744
+ "dtype": "float32",
2745
+ "hadU": 2048,
2746
+ "hadV": 8192,
2747
+ "in_features": 2048,
2748
+ "linear": {
2749
+ "KV": 7,
2750
+ "L": 16,
2751
+ "V": 2,
2752
+ "bias": false,
2753
+ "in_features": 2048,
2754
+ "linear_cls": "QTIPLinearTCQ",
2755
+ "linear_dtype": "float32",
2756
+ "out_features": 8192,
2757
+ "td_x": 16,
2758
+ "td_y": 16,
2759
+ "tlut_bits": 9
2760
+ },
2761
+ "module_type": "IncoherentLinear",
2762
+ "out_features": 8192,
2763
+ "rot_info": "skip_r",
2764
+ "scale": 32.0
2765
+ },
2766
+ "model.layers.9.self_attn.k_proj": {
2767
+ "bias": false,
2768
+ "dtype": "float32",
2769
+ "hadU": 2048,
2770
+ "hadV": 512,
2771
+ "in_features": 2048,
2772
+ "linear": {
2773
+ "KV": 9,
2774
+ "L": 16,
2775
+ "V": 2,
2776
+ "bias": false,
2777
+ "in_features": 2048,
2778
+ "linear_cls": "QTIPLinearTCQ",
2779
+ "linear_dtype": "float32",
2780
+ "out_features": 512,
2781
+ "td_x": 16,
2782
+ "td_y": 16,
2783
+ "tlut_bits": 10
2784
+ },
2785
+ "module_type": "IncoherentLinear",
2786
+ "out_features": 512,
2787
+ "rot_info": "skip_r",
2788
+ "scale": 32.0
2789
+ },
2790
+ "model.layers.9.self_attn.o_proj": {
2791
+ "bias": false,
2792
+ "dtype": "float32",
2793
+ "hadU": 2048,
2794
+ "hadV": 2048,
2795
+ "in_features": 2048,
2796
+ "linear": {
2797
+ "KV": [
2798
+ 8,
2799
+ 9
2800
+ ],
2801
+ "L": 16,
2802
+ "V": 2,
2803
+ "bias": false,
2804
+ "in_features": 2048,
2805
+ "in_part": [
2806
+ 1024,
2807
+ 1024
2808
+ ],
2809
+ "linear_cls": "CombtLinearTCQ",
2810
+ "linear_dtype": "float32",
2811
+ "out_features": 2048,
2812
+ "td_x": 16,
2813
+ "td_y": 16,
2814
+ "tlut_bits": 10
2815
+ },
2816
+ "module_type": "IncoherentLinear",
2817
+ "out_features": 2048,
2818
+ "rot_info": "skip_r",
2819
+ "scale": 32.0
2820
+ },
2821
+ "model.layers.9.self_attn.q_proj": {
2822
+ "bias": false,
2823
+ "dtype": "float32",
2824
+ "hadU": 2048,
2825
+ "hadV": 2048,
2826
+ "in_features": 2048,
2827
+ "linear": {
2828
+ "KV": 7,
2829
+ "L": 16,
2830
+ "V": 2,
2831
+ "bias": false,
2832
+ "in_features": 2048,
2833
+ "linear_cls": "QTIPLinearTCQ",
2834
+ "linear_dtype": "float32",
2835
+ "out_features": 2048,
2836
+ "td_x": 16,
2837
+ "td_y": 16,
2838
+ "tlut_bits": 9
2839
+ },
2840
+ "module_type": "IncoherentLinear",
2841
+ "out_features": 2048,
2842
+ "rot_info": "skip_r",
2843
+ "scale": 32.0
2844
+ },
2845
+ "model.layers.9.self_attn.v_proj": {
2846
+ "bias": false,
2847
+ "dtype": "float32",
2848
+ "hadU": 2048,
2849
+ "hadV": 512,
2850
+ "in_features": 2048,
2851
+ "linear": {
2852
+ "KV": 10,
2853
+ "L": 16,
2854
+ "V": 2,
2855
+ "bias": false,
2856
+ "in_features": 2048,
2857
+ "linear_cls": "QTIPLinearTCQ",
2858
+ "linear_dtype": "float32",
2859
+ "out_features": 512,
2860
+ "td_x": 16,
2861
+ "td_y": 16,
2862
+ "tlut_bits": 11
2863
+ },
2864
+ "module_type": "IncoherentLinear",
2865
+ "out_features": 512,
2866
+ "rot_info": "skip_r",
2867
+ "scale": 32.0
2868
+ }
2869
+ }
2870
+ },
2871
+ "rms_norm_eps": 1e-05,
2872
+ "rope_scaling": {
2873
+ "factor": 32.0,
2874
+ "high_freq_factor": 4.0,
2875
+ "low_freq_factor": 1.0,
2876
+ "original_max_position_embeddings": 8192,
2877
+ "rope_type": "llama3"
2878
+ },
2879
+ "rope_theta": 500000.0,
2880
+ "tie_word_embeddings": true,
2881
+ "torch_dtype": "float16",
2882
+ "transformers_version": "4.45.2",
2883
+ "use_cache": true,
2884
+ "vocab_size": 128256
2885
+ }
generation_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 128000,
4
+ "do_sample": true,
5
+ "eos_token_id": 128001,
6
+ "temperature": 0.6,
7
+ "top_p": 0.9,
8
+ "transformers_version": "4.45.2"
9
+ }
lib/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
lib/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (158 Bytes). View file
 
lib/__pycache__/config.cpython-311.pyc ADDED
Binary file (318 Bytes). View file
 
lib/algo/__init__.py ADDED
File without changes
lib/algo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (163 Bytes). View file
 
lib/algo/__pycache__/ldlq.cpython-311.pyc ADDED
Binary file (13.3 kB). View file
 
lib/algo/ldlq.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+
4
+ import glog
5
+ import torch
6
+ from tqdm import tqdm
7
+ import time
8
+ from lib import utils
9
+
10
+ _PERMUTE = torch.arange(256).reshape(2, 8, 2, 4, 2).permute(1, 3, 2, 0,
11
+ 4).flatten()
12
+ _INV_PERMUTE = torch.zeros(256, dtype=torch.int64)
13
+ _INV_PERMUTE[_PERMUTE] = torch.arange(256)
14
+
15
+
16
+ def LDLQ_VQ(Wr, L, cb, buf_cols=128):
17
+ buf_cols = max(buf_cols, cb.vec_sz)
18
+ (m, n) = Wr.shape
19
+ assert buf_cols % cb.vec_sz == 0
20
+ assert n % buf_cols == 0
21
+ buf_size = buf_cols // cb.vec_sz
22
+
23
+ hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
24
+ Qidxs_T = torch.zeros(n // cb.vec_sz, m, dtype=cb.idx_dtype, device=L.device)
25
+
26
+ device = Wr.device
27
+ Wr = Wr.cpu()
28
+ utils.clean()
29
+ Wr_T = Wr.T.contiguous().to(device)
30
+
31
+ prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)
32
+ for cur_col in tqdm(range(n // cb.vec_sz, 0, -buf_size)):
33
+ b_Wr_T = Wr_T[cb.vec_sz * (cur_col - buf_size):cb.vec_sz * cur_col]
34
+ b_hatWr_T = hatWr_T[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
35
+ cur_col]
36
+ b_L = L[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
37
+ cur_col].contiguous()
38
+ b_prod = prod_cache[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
39
+ cur_col]
40
+ b_Qidxs_T = Qidxs_T[(cur_col - buf_size):cur_col]
41
+ L_offset = cb.vec_sz * (cur_col - buf_size)
42
+ for i in reversed(range(buf_size)):
43
+ WXWX = b_Wr_T[cb.vec_sz * i : cb.vec_sz * (i + 1)] + \
44
+ b_L[cb.vec_sz * (i + 1):, L_offset + cb.vec_sz * i : L_offset + cb.vec_sz * (i + 1)].T @ \
45
+ (b_Wr_T[cb.vec_sz * (i + 1):] - b_hatWr_T[cb.vec_sz * (i + 1):]) + \
46
+ b_prod[cb.vec_sz * i : cb.vec_sz * (i + 1)]
47
+
48
+ q_out = cb.quantize(WXWX.T)
49
+ b_hatWr_T[cb.vec_sz * i:cb.vec_sz * (i + 1)] = q_out[0].T
50
+ b_Qidxs_T[i:(i + 1)] = q_out[1].T
51
+
52
+ prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
53
+ hatWr_T[cb.vec_sz * (cur_col - buf_size):cb.vec_sz *
54
+ cur_col] = b_hatWr_T
55
+
56
+ del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
57
+ utils.clean()
58
+ return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
59
+
60
+
61
+ def LDLQ(Wr, L, cb, args, buf_cols=128, for_kernel=True):
62
+ if for_kernel:
63
+ assert args.td_x == 16 and args.td_y == 16
64
+ buf_cols = max(buf_cols, args.td_y)
65
+ trellissz = args.td_x * args.td_y
66
+ (m, n) = Wr.shape
67
+ assert buf_cols % args.td_y == 0
68
+ assert n % buf_cols == 0
69
+ assert args.td_y % args.V == 0
70
+ buf_size = buf_cols // args.td_y
71
+
72
+ hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
73
+ Qidxs_T = torch.zeros(n // args.V, m, dtype=cb.idx_dtype, device=L.device)
74
+
75
+ device = Wr.device
76
+ Wr = Wr.cpu()
77
+ utils.clean()
78
+ Wr_T = Wr.T.contiguous().to(device)
79
+
80
+ # quip
81
+ prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)
82
+ for cur_col in tqdm(range(n // args.td_y, 0, -buf_size)):
83
+ b_Wr_T = Wr_T[args.td_y * (cur_col - buf_size):args.td_y * cur_col]
84
+ b_hatWr_T = hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
85
+ cur_col]
86
+ b_L = L[args.td_y * (cur_col - buf_size):args.td_y *
87
+ cur_col].contiguous()
88
+ b_prod = prod_cache[args.td_y * (cur_col - buf_size):args.td_y *
89
+ cur_col]
90
+ b_Qidxs_T = Qidxs_T[args.td_y * (cur_col - buf_size) //
91
+ args.V:args.td_y * cur_col // args.V]
92
+ L_offset = args.td_y * (cur_col - buf_size)
93
+ for i in reversed(range(buf_size)):
94
+ WXWX = b_Wr_T[args.td_y * i : args.td_y * (i + 1)] + \
95
+ b_L[args.td_y * (i + 1):, L_offset + args.td_y * i : L_offset + args.td_y * (i + 1)].T @ \
96
+ (b_Wr_T[args.td_y * (i + 1):] - b_hatWr_T[args.td_y * (i + 1):]) + \
97
+ b_prod[args.td_y * i : args.td_y * (i + 1)]
98
+ if trellissz > -1:
99
+ WXWXshape = WXWX.shape
100
+ thing = WXWX.T.reshape(-1, trellissz)
101
+ if for_kernel:
102
+ thing = thing[..., _PERMUTE]
103
+ q_out = cb.quantize(thing)
104
+ if for_kernel:
105
+ thing = q_out[0][..., _INV_PERMUTE].reshape(
106
+ WXWXshape[1], WXWXshape[0])
107
+ else:
108
+ thing = q_out[0].reshape(WXWXshape[1], WXWXshape[0])
109
+ idxs = q_out[1].reshape(WXWXshape[1], WXWXshape[0] // args.V)
110
+ b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = thing.T
111
+ b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
112
+ (i + 1)] = idxs.T
113
+ else:
114
+ q_out = cb.quantize(WXWX.T)
115
+ b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = q_out[0].T
116
+ b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
117
+ (i + 1)] = q_out[1].T
118
+
119
+ prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
120
+ hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
121
+ cur_col] = b_hatWr_T
122
+
123
+ del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
124
+ utils.clean()
125
+ return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
126
+
127
+
128
+ def LDLQ_combt(Wr, L, cb1, cb2, args, buf_cols=128, for_kernel=True):
129
+ if for_kernel:
130
+ assert args.td_x == 16 and args.td_y == 16
131
+ buf_cols = max(buf_cols, args.td_y)
132
+ trellissz = args.td_x * args.td_y
133
+ (m, n) = Wr.shape
134
+ assert buf_cols % args.td_y == 0
135
+ assert n % buf_cols == 0
136
+ assert args.td_y % args.V == 0
137
+ buf_size = buf_cols // args.td_y
138
+
139
+ hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
140
+ Qidxs_T = torch.zeros(n // args.V, m, dtype=cb1.idx_dtype, device=L.device)
141
+
142
+ device = Wr.device
143
+ Wr = Wr.cpu()
144
+ utils.clean()
145
+ Wr_T = Wr.T.contiguous().to(device)
146
+
147
+ # quip
148
+ prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)
149
+
150
+ flag_for_cb1_compile = True
151
+ for cur_col in tqdm(range(n // args.td_y, 0, -buf_size)):
152
+ b_Wr_T = Wr_T[args.td_y * (cur_col - buf_size):args.td_y * cur_col]
153
+ b_hatWr_T = hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
154
+ cur_col]
155
+ b_L = L[args.td_y * (cur_col - buf_size):args.td_y *
156
+ cur_col].contiguous()
157
+ b_prod = prod_cache[args.td_y * (cur_col - buf_size):args.td_y *
158
+ cur_col]
159
+ b_Qidxs_T = Qidxs_T[args.td_y * (cur_col - buf_size) //
160
+ args.V:args.td_y * cur_col // args.V]
161
+ L_offset = args.td_y * (cur_col - buf_size)
162
+ for i in reversed(range(buf_size)):
163
+ WXWX = b_Wr_T[args.td_y * i : args.td_y * (i + 1)] + \
164
+ b_L[args.td_y * (i + 1):, L_offset + args.td_y * i : L_offset + args.td_y * (i + 1)].T @ \
165
+ (b_Wr_T[args.td_y * (i + 1):] - b_hatWr_T[args.td_y * (i + 1):]) + \
166
+ b_prod[args.td_y * i : args.td_y * (i + 1)]
167
+ if trellissz > -1:
168
+ WXWXshape = WXWX.shape
169
+ thing = WXWX.T.reshape(-1, trellissz)
170
+ if for_kernel:
171
+ thing = thing[..., _PERMUTE]
172
+ if args.td_y * (cur_col - buf_size) >= n // 2:
173
+ q_out = cb2.quantize(thing)
174
+ else:
175
+ if flag_for_cb1_compile:
176
+ torch._dynamo.reset()
177
+ flag_for_cb1_compile = False
178
+ q_out = cb1.quantize(thing)
179
+ if for_kernel:
180
+ thing = q_out[0][..., _INV_PERMUTE].reshape(
181
+ WXWXshape[1], WXWXshape[0])
182
+ else:
183
+ thing = q_out[0].reshape(WXWXshape[1], WXWXshape[0])
184
+ idxs = q_out[1].reshape(WXWXshape[1], WXWXshape[0] // args.V)
185
+ b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = thing.T
186
+ b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
187
+ (i + 1)] = idxs.T
188
+ else:
189
+ if args.td_y * (cur_col - buf_size) >= n // 2:
190
+ q_out = cb2.quantize(WXWX.T)
191
+ else:
192
+ q_out = cb1.quantize(WXWX.T)
193
+ b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = q_out[0].T
194
+ b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
195
+ (i + 1)] = q_out[1].T
196
+
197
+ prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
198
+ hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
199
+ cur_col] = b_hatWr_T
200
+
201
+ del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
202
+ utils.clean()
203
+ return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
lib/algo/ldlq_beam_cd.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+
4
+ import glog
5
+ import torch
6
+ from tqdm import tqdm
7
+ import time
8
+ from lib import utils
9
+
10
+ _PERMUTE = torch.arange(256).reshape(2, 8, 2, 4, 2).permute(1, 3, 2, 0,
11
+ 4).flatten()
12
+
13
+ _PERMUTE_HALF = torch.arange(128).reshape(2, 8, 2, 4, 1).permute(1, 3, 2, 0,
14
+ 4).flatten()
15
+ _INV_PERMUTE = torch.zeros(256, dtype=torch.int64)
16
+ _INV_PERMUTE[_PERMUTE] = torch.arange(256)
17
+ _INV_PERMUTE_HALF = torch.zeros(128, dtype=torch.int64)
18
+ _INV_PERMUTE_HALF[_PERMUTE_HALF] = torch.arange(128)
19
+
20
+ def LDLQ(Wr, L, cb, args, D=None, buf_cols=128, for_kernel=True, use_beam_search=False, use_diag=False):
21
+ if for_kernel:
22
+ assert args.td_x == 16 and args.td_y == 16
23
+ buf_cols = max(buf_cols, args.td_y)
24
+ trellissz = args.td_x * args.td_y
25
+ (m, n) = Wr.shape
26
+ assert buf_cols % args.td_y == 0
27
+ assert n % buf_cols == 0
28
+ assert args.td_y % args.V == 0
29
+ buf_size = buf_cols // args.td_y
30
+
31
+ hatWr_T = torch.zeros(n, m, dtype=L.dtype, device=L.device)
32
+ Qidxs_T = torch.zeros(n // args.V, m, dtype=cb.idx_dtype, device=L.device)
33
+
34
+ device = Wr.device
35
+ Wr = Wr.cpu()
36
+ utils.clean()
37
+ Wr_T = Wr.T.contiguous().to(device)
38
+
39
+ # quip
40
+ prod_cache = torch.zeros(n, m, dtype=Wr_T.dtype, device=Wr_T.device)
41
+ for cur_col in tqdm(range(n // args.td_y, 0, -buf_size)):
42
+ b_Wr_T = Wr_T[args.td_y * (cur_col - buf_size):args.td_y * cur_col]
43
+ b_hatWr_T = hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
44
+ cur_col]
45
+ b_L = L[args.td_y * (cur_col - buf_size):args.td_y *
46
+ cur_col].contiguous()
47
+ b_prod = prod_cache[args.td_y * (cur_col - buf_size):args.td_y *
48
+ cur_col]
49
+ b_Qidxs_T = Qidxs_T[args.td_y * (cur_col - buf_size) //
50
+ args.V:args.td_y * cur_col // args.V]
51
+ L_offset = args.td_y * (cur_col - buf_size)
52
+ for i in reversed(range(buf_size)):
53
+ WXWX = b_Wr_T[args.td_y * i : args.td_y * (i + 1)] + \
54
+ b_L[args.td_y * (i + 1):, L_offset + args.td_y * i : L_offset + args.td_y * (i + 1)].T @ \
55
+ (b_Wr_T[args.td_y * (i + 1):] - b_hatWr_T[args.td_y * (i + 1):]) + \
56
+ b_prod[args.td_y * i : args.td_y * (i + 1)]
57
+ if trellissz > -1:
58
+ WXWXshape = WXWX.shape
59
+ thing = WXWX.T.reshape(-1, trellissz)
60
+ if for_kernel:
61
+ thing = thing[..., _PERMUTE]
62
+ if use_beam_search:
63
+ # D: (n // td_y, td_y, td_y)
64
+ D_cur = D[cur_col - buf_size + i] # (td_y, td_y)
65
+ D_tiled = torch.kron(torch.eye(args.td_y, device=D_cur.device, dtype=D_cur.dtype), D_cur)
66
+ if for_kernel:
67
+ D_tiled = D_tiled[:, _PERMUTE][_PERMUTE, :]
68
+ q_out = cb.quantize_beam_search_with_hessian(thing, D_tiled, beam_sz=1024)
69
+ else:
70
+ if use_diag:
71
+ D_cur = D[cur_col - buf_size + i] # (td_y, td_y)
72
+ weight = torch.diag(D_cur).repeat(trellissz // args.td_y)[_PERMUTE]
73
+ q_out = cb.quantize(thing, w2=weight)
74
+ else:
75
+ q_out = cb.quantize(thing)
76
+ if for_kernel:
77
+ thing = q_out[0][..., _INV_PERMUTE].reshape(
78
+ WXWXshape[1], WXWXshape[0])
79
+ else:
80
+ thing = q_out[0].reshape(WXWXshape[1], WXWXshape[0])
81
+ idxs = q_out[1].reshape(WXWXshape[1], WXWXshape[0] // args.V)
82
+ b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = thing.T
83
+ b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
84
+ (i + 1)] = idxs.T
85
+ else:
86
+ raise NotImplementedError
87
+ # q_out = cb.quantize(WXWX.T)
88
+ # b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = q_out[0].T
89
+ # b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
90
+ # (i + 1)] = q_out[1].T
91
+
92
+ prod_cache += b_L.T @ (b_Wr_T - b_hatWr_T)
93
+ hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
94
+ cur_col] = b_hatWr_T
95
+
96
+ del b_Wr_T, b_hatWr_T, b_L, b_prod, L_offset, prod_cache
97
+ utils.clean()
98
+ return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
99
+
100
+ def calc_obj(hatWr_T, Wr_T, HRr):
101
+ diff_T = hatWr_T.cuda() - Wr_T.cuda()
102
+ obj = torch.trace(diff_T.T @ HRr @ diff_T)
103
+ return obj.cpu().item()
104
+
105
+ def CD(Wr, HRr, Qidxs, hatWr, cb, args, buf_cols=128, for_kernel=True, use_beam_search=False):
106
+ if for_kernel:
107
+ assert args.td_x == 16 and args.td_y == 16
108
+ buf_cols = max(buf_cols, args.td_y)
109
+ trellissz = args.td_x * args.td_y
110
+ (m, n) = Wr.shape
111
+ assert buf_cols % args.td_y == 0
112
+ assert n % buf_cols == 0
113
+ assert args.td_y % args.V == 0
114
+ buf_size = buf_cols // args.td_y
115
+
116
+ hatWr_T = hatWr.T.contiguous()
117
+ Qidxs_T = Qidxs.T.contiguous()
118
+ device = hatWr.device
119
+ hatWr = hatWr.cpu()
120
+ utils.clean()
121
+ Wr_T = Wr.T.contiguous().to(device)
122
+
123
+ # obj = calc_obj(hatWr_T, Wr_T, HRr)
124
+ # print("init obj", obj)
125
+ for cur_col in tqdm(range(n // args.td_y, 0, -buf_size)):
126
+ b_Wr_T = Wr_T[args.td_y * (cur_col - buf_size):args.td_y * cur_col]
127
+ b_hatWr_T = hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
128
+ cur_col]
129
+ b_Qidxs_T = Qidxs_T[args.td_y * (cur_col - buf_size) //
130
+ args.V:args.td_y * cur_col // args.V]
131
+ b_HRr = HRr[args.td_y * (cur_col - buf_size):args.td_y * cur_col, args.td_y * (cur_col - buf_size):args.td_y * cur_col] # (buf_size * td_y, buf_size * td_y)
132
+
133
+ # update global hessian
134
+ res_inds = torch.cat([
135
+ torch.arange(0, (cur_col - buf_size) * args.td_y, device=device),
136
+ torch.arange(cur_col * args.td_y, n, device=device)
137
+ ])
138
+ Wr_diff_T = hatWr_T - Wr_T # (n, m)
139
+ b_global_hess = torch.matmul(Wr_diff_T[res_inds].T, HRr[res_inds, args.td_y * (cur_col - buf_size):args.td_y * cur_col]) # (m, buf_size * td_y)
140
+ for i in reversed(range(buf_size)):
141
+ start_col, end_col = args.td_y * i, args.td_y * (i + 1)
142
+ WXWX = b_Wr_T[start_col:end_col] # (td_y, m)
143
+ b_Wr_diff_T = b_hatWr_T - b_Wr_T # (td_y * buf_size, m)
144
+ if trellissz > -1:
145
+ WXWXshape = WXWX.shape
146
+ thing = WXWX.T.reshape(-1, trellissz) # (-1, trellissz)
147
+ if for_kernel:
148
+ thing = thing[..., _PERMUTE]
149
+
150
+ # local hessian
151
+ HRr_cur = b_HRr[start_col:end_col, start_col:end_col].contiguous() # (td_y, td_y)
152
+ HRr_tiled = torch.kron(torch.eye(args.td_y, device=HRr_cur.device, dtype=HRr_cur.dtype), HRr_cur)
153
+
154
+ # global hessian
155
+ cur_global_hess = b_global_hess[:, start_col:end_col] # (m, td_y)
156
+ cur_res_ind = torch.cat([
157
+ torch.arange(0, start_col, device=device),
158
+ torch.arange(end_col, buf_size * args.td_y, device=device)
159
+ ]) # 나머지 indices for args.td_y * i : args.td_y * (i + 1)
160
+
161
+ cur_global_hess_res = torch.matmul(b_Wr_diff_T[cur_res_ind].T, b_HRr[cur_res_ind, start_col:end_col]) # (m, td_y)
162
+ cur_weight = cur_global_hess + cur_global_hess_res # (m, td_y)
163
+ cur_weight = cur_weight.reshape(-1, trellissz)
164
+ if for_kernel:
165
+ cur_weight = cur_weight[..., _PERMUTE] # (-1, trellissz)
166
+ HRr_tiled = HRr_tiled[:, _PERMUTE][_PERMUTE, :]
167
+
168
+ cur_hatWr_T = b_hatWr_T[start_col:end_col].T.reshape(-1, trellissz)[..., _PERMUTE].contiguous() # (-1, trellissz)
169
+ cur_qidx = b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V * (i + 1)].T.reshape(-1, trellissz // args.V) # (-1, trellissz)
170
+ diff = cur_hatWr_T - thing
171
+ obj_before = torch.diag(diff @ HRr_tiled @ diff.T) + torch.sum(cur_weight * diff, dim=-1) * 2 # (-1)
172
+
173
+ if use_beam_search:
174
+ q_out = cb.quantize_beam_search_with_hessian(thing, HRr_tiled, U=cur_weight, beam_sz=1024)
175
+ else:
176
+ q_out = cb.quantize(thing, w1=cur_weight * 2, w2=torch.diag(HRr_tiled))
177
+ diff = q_out[0] - thing
178
+ obj_after = torch.diag(diff @ HRr_tiled @ diff.T) + torch.sum(cur_weight * diff, dim=-1) * 2 # (-1)
179
+
180
+ # select only improved
181
+ improved = obj_before > obj_after
182
+ # out[i] = q_out[0][i] if improved[i] else cur_hatWr_T[i]
183
+ new_hatWr_T = torch.where(improved.unsqueeze(-1), q_out[0], cur_hatWr_T)
184
+ new_qidx = torch.where(improved.unsqueeze(-1), q_out[1], cur_qidx)
185
+
186
+ if for_kernel:
187
+ thing = new_hatWr_T[..., _INV_PERMUTE].reshape(
188
+ WXWXshape[1], WXWXshape[0])
189
+ else:
190
+ thing = new_hatWr_T.reshape(WXWXshape[1], WXWXshape[0])
191
+ idxs = new_qidx.reshape(WXWXshape[1], WXWXshape[0] // args.V)
192
+ b_hatWr_T[args.td_y * i:args.td_y * (i + 1)] = thing.T
193
+ b_Qidxs_T[args.td_y // args.V * i:args.td_y // args.V *
194
+ (i + 1)] = idxs.T
195
+
196
+
197
+ hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
198
+ cur_col] = b_hatWr_T
199
+ else:
200
+ raise NotImplementedError
201
+ hatWr_T[args.td_y * (cur_col - buf_size):args.td_y *
202
+ cur_col] = b_hatWr_T
203
+
204
+ # obj = calc_obj(hatWr_T, Wr_T, HRr)
205
+ # print("cur_col", cur_col, "obj", obj)
206
+
207
+ del b_Wr_T, b_hatWr_T
208
+ utils.clean()
209
+ return hatWr_T.T.contiguous(), Qidxs_T.T.contiguous()
lib/codebook/__pycache__/bitshift.cpython-311.pyc ADDED
Binary file (30.8 kB). View file
 
lib/codebook/__pycache__/vq_codebook.cpython-311.pyc ADDED
Binary file (3.9 kB). View file
 
lib/codebook/bitshift.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import math
3
+ import os
4
+ from functools import cache
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from tqdm import tqdm
10
+
11
+ from lib.utils.kernel_check import has_kernel
12
+ from lib.utils.kernel_decompress import decode_compressed, bitshift_linear_kernel
13
+ from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda
14
+ import time
15
+
16
def decode_1mad(x):
    """Decode 32-bit integer codes into approximately Gaussian floats.

    A single multiply-add (LCG) step mixes each code; the four bytes of
    the mixed value are summed (a sum of uniforms is roughly normal),
    centered at zero, and rescaled to unit-ish variance.
    """
    mask32 = (1 << 32) - 1
    state = (x.to(torch.int64) & mask32) * 34038481 + 76625530
    state = state & mask32
    byte_sum = sum((state >> shift) & 255 for shift in (0, 8, 16, 24))
    return (byte_sum - 510).to(torch.float32) / 147.800537109375
26
+
27
+
28
def decode_2mad(x):
    """Decode 32-bit integer codes into approximately Gaussian floats.

    Like decode_1mad but with a second mixing step (multiply, take high
    word, add back) for better pseudo-random quality before the byte-sum
    Gaussian approximation.
    """
    mask32 = (1 << 32) - 1
    state = (x.to(torch.int64) & mask32) * 264435761 + 1013904223
    state = state & mask32
    state = ((state * 1664525) >> 32) + state
    state = state & mask32
    byte_sum = sum((state >> shift) & 255 for shift in (0, 8, 16, 24))
    return (byte_sum - 510).to(torch.float32) / 147.800537109375
40
+
41
+
42
def decode_3inst(x):
    """Decode integer codes into floats via a "3 instruction" style hash.

    The code is mixed with one multiply-add, masked into two fp16 bit
    patterns (high and low halfwords), and the two fp16 values are summed.
    """

    def bfe16_to_fp16(x):
        # Reinterpret the low 16 bits as a signed int16 bit pattern, then
        # bit-cast to fp16. NOTE(review): mutates `x` in place; both call
        # sites below pass freshly computed tensors, so no aliasing occurs.
        x[torch.where(x >= 2**15)] -= 2**16
        return torch.tensor(x.to(torch.int16).numpy().view(np.float16))

    a = 89226354          # multiplier of the mixing step
    b = 64248484          # increment of the mixing step
    fpmask = 996162400    # XOR mask shaping raw bits into fp16 patterns
    x = x.to(torch.int64)
    x = x & ((1 << 32) - 1)
    x = x * a + b
    # Keep the sign bit and low 12 mantissa-ish bits of each 16-bit half.
    mask = (1 << 15) + ((1 << 12) - 1)
    mask = (mask << 16) + mask
    res = (mask & x) ^ fpmask
    top = bfe16_to_fp16(res >> 16)
    bottom = bfe16_to_fp16(res & ((1 << 16) - 1))
    return (top + bottom).float()
60
+
61
+
62
def quantlut(tlut, L, nbits):
    """Expand the small table `tlut` (2**nbits rows) into a 2**L-entry LUT.

    Each L-bit state index i is hashed as i*(i+1); bits [16-nbits, 16) of
    the hash pick a row of `tlut`, so many states share each table row.
    """
    with torch.no_grad():
        states = torch.arange(1 << L)
        hashed = states * (states + 1)
        rows = (hashed >> (16 - nbits)) & ((1 << nbits) - 1)
        expanded = tlut[rows]
    return expanded
69
+
70
+
71
def quantlut_sym(tlut, L, nbits):
    """Expand `tlut` into a sign-symmetric 2**L-entry LUT.

    Like `quantlut`, but one hashed bit flips the sign of component 0,
    symmetrically doubling the effective codebook.
    """
    with torch.no_grad():
        lut = torch.arange(1 << L, device=tlut.device)
        lut = (lut + 1) * lut
        # Bit 15 of the hash chooses the sign (+1 or -1) for component 0.
        sflp = 1 - ((lut >> 15) & 1) * 2
        # One shift fewer than `quantlut`: nbits index bits plus the sign bit
        # are consumed from the low 16 bits of the hash.
        lut = (lut >> (16 - nbits - 1)) & ((1 << nbits) - 1)
        lut = tlut[lut]
        lut[:, 0] = lut[:, 0] * sflp
    return lut
80
+
81
+
82
+ class bitshift_codebook(nn.Module):
83
+
84
    def __init__(self,
                 L=16,
                 KV=4,
                 V=2,
                 tlut_bits=16,
                 decode_mode='lut',
                 tlut=None):
        """Trellis ("bitshift") codebook.

        Args:
            L: number of state bits; the trellis has 2**L states.
            KV: bits shifted in per step, so each state has 2**KV successors.
            V: vector dimension decoded per state.
            tlut_bits: log2 size of the small underlying table `tlut`.
            decode_mode: how a state maps to values -- a stored table
                ('lut'), a computed hash ('1mad'/'2mad'/'3inst'), or a
                hashed table ('quantlut'/'quantlut_sym').
            tlut: optional externally supplied table; when given, `lut` is
                rebuilt from it via recons_lut().
        """
        super(bitshift_codebook, self).__init__()
        self.idx_dtype = torch.int32
        self.opt_scale = 1

        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.decode_mode = decode_mode

        if decode_mode == 'lut':
            if tlut is None:
                # A full random table: one V-dim vector per state.
                assert tlut_bits == L
                self.register_buffer('tlut', torch.randn(2**L, V))
                self.register_buffer('lut', self.tlut.T.contiguous())
            else:
                self.tlut = tlut
                self.recons_lut()

        elif decode_mode == '1mad':
            # Hash-based decoders are scalar only (V == 1); the LUT holds the
            # decoded value of every state.
            assert V == 1
            self.register_buffer('lut',
                                 decode_1mad(torch.arange(2**L)).unsqueeze(0))
        elif decode_mode == '2mad':
            assert V == 1
            self.register_buffer('lut',
                                 decode_2mad(torch.arange(2**L)).unsqueeze(0))
        elif decode_mode == '3inst':
            assert V == 1
            self.register_buffer('lut',
                                 decode_3inst(torch.arange(2**L)).unsqueeze(0))
        elif decode_mode == 'quantlut':
            if tlut is None:
                assert tlut_bits > 0
                if V == 1:
                    # Inverse-CDF sampling of a standard normal.
                    tlut = torch.erfinv((torch.arange(1 << tlut_bits) + 0.5) /
                                        (1 << tlut_bits) * 2 -
                                        1) * torch.tensor(2.0).sqrt()
                elif V == 2:
                    # Spiral-like 2D points: radius from a log profile,
                    # angle from the index.
                    n = 2**tlut_bits
                    tlut = torch.zeros(n)  # NOTE(review): dead assignment, overwritten below
                    R = ((n / (n - torch.arange(n))).log() * 2).sqrt()
                    tlut = torch.stack(
                        [R * torch.arange(n).sin(), R * torch.arange(n).cos()],
                        dim=-1)
                else:
                    raise Exception
                # NOTE(review): unsqueeze(-1) is applied to both the V==1 and
                # V==2 tables; confirm the intended shape for V == 2.
                self.register_buffer('tlut', tlut.unsqueeze(-1))
                self.register_buffer(
                    'lut',
                    quantlut(self.tlut, L, tlut_bits).T.contiguous())
            else:
                self.tlut = tlut
                self.recons_lut()
        elif decode_mode == 'quantlut_sym':
            if tlut is None:
                assert tlut_bits > 0
                if V == 2:
                    # K-means centroids on Gaussian data, cached on disk.
                    # NOTE(review): assumes assets/lut_cache/ exists and is
                    # writable relative to the CWD.
                    fname = f'assets/lut_cache/kmeans_{tlut_bits}_{V}.pt'
                    if not os.path.exists(fname):
                        tlut = torch.randn(2**tlut_bits, V)
                        import scipy
                        data = torch.randn(1 << 20, 2)
                        clusters = scipy.cluster.vq.kmeans(data, tlut)
                        tlut = torch.tensor(clusters[0])
                        # Normalize to the variance expected by the decoder.
                        tlut = (tlut /
                                tlut.std(unbiased=False)) * 0.9682458365518543
                        torch.save(tlut, fname)
                    else:
                        tlut = torch.load(fname)
                else:
                    raise Exception
                self.register_buffer('tlut', tlut)
                self.register_buffer(
                    'lut',
                    quantlut_sym(self.tlut, L, tlut_bits).T.contiguous())
            else:
                self.tlut = tlut
                self.recons_lut()
        else:
            raise Exception

        # Plain attribute (not a buffer), used as +inf sentinel in viterbi.
        self.fakeinf = torch.tensor(torch.inf)

        # Offsets of the 2**KV low-bit variants placed in the top KV bits.
        self.register_buffer('sumdelta',
                             torch.arange(2**(KV)) << (L - KV))
        self.sumdelta = self.sumdelta.view(1, 1, -1)

        self.register_buffer('state', torch.arange(2**L).unsqueeze(0))
        # For each truncated (L-KV)-bit prefix: the 2**KV full states that
        # can transition into states sharing that prefix.
        self.register_buffer('state_cand',
                             (self.state >>
                              (KV))[0, ::2**(KV)].unsqueeze(-1) +
                             self.sumdelta)
        # Precomputed reconstruction of every state: shape (V, 2**L).
        self.register_buffer('recons_state', self.recons(self.state))

        self.version = 0
187
+
188
+ def recons_lut(self):
189
+ if self.decode_mode == 'lut':
190
+ self.lut = self.tlut.T.contiguous()
191
+ elif self.decode_mode == 'quantlut':
192
+ self.lut = quantlut(self.tlut, self.L,
193
+ self.tlut_bits).T.contiguous()
194
+ elif self.decode_mode == 'quantlut_sym':
195
+ self.lut = quantlut_sym(self.tlut, self.L,
196
+ self.tlut_bits).T.contiguous()
197
+
198
+ def recons(self, encoded, **kwargs):
199
+ return self.lut[:,
200
+ encoded.int().to(self.lut.device)].to(encoded.device)
201
+
202
    @torch.compile
    def update(self, cost, thing):
        """One Viterbi transition step.

        Args:
            cost: (B, 2**L) accumulated path cost per state.
            thing: (V, B) next V-dim chunk of the signal to quantize.
        Returns:
            (prev_state, new_cost): chosen predecessor per truncated state
            and the updated (B, 2**L) path costs.
        """
        # Squared reconstruction error of every state against this chunk.
        state_err = (self.recons_state -
                     thing.unsqueeze(-1)).square().sum(dim=0)
        # Gather, for each truncated (L-KV)-bit prefix, the costs of the
        # 2**KV full states that can transition into it.
        cand_cost = torch.gather(
            cost.unsqueeze(-2).expand(-1, self.state_cand.shape[1], -1), -1,
            self.state_cand.expand(len(cost), -1, 2**(self.KV)))
        best = torch.min(cand_cost, dim=-1)
        # New path cost = local error + best predecessor cost, broadcast to
        # the 2**KV states sharing each truncated prefix.
        cost = state_err + best.values.unsqueeze(-1).expand(
            -1, -1, 2**(self.KV)).reshape(state_err.shape)
        # Remember which predecessor won, for backtracking in viterbi().
        prev_state = torch.gather(
            self.state_cand.expand(thing.shape[1], -1, -1), -1,
            best.indices.unsqueeze(-1))[..., 0]
        return prev_state, cost
216
+
217
    def viterbi(self, X, overlap=None):
        """Viterbi-quantize columns of X through the bitshift trellis.

        X (T, B): B independent length-T sequences (T divisible by V).
        overlap: optional per-column truncated states; when given, both the
            first and last state of each path are constrained to extend the
            corresponding overlap prefix (used to make the sequence wrap
            consistently -- see quantize()).
        Returns (T // V, B) int32 chosen states.
        """
        T, B = X.shape
        assert T % self.V == 0
        # cost is (B, 2**L)
        cost = (self.recons_state -
                X[:self.V].unsqueeze(-1)).square().sum(dim=0)

        if overlap is not None:
            # Disallow (cost = fakeinf) every starting state that does not
            # extend the overlap prefix.
            mask = torch.ones(B, 2**self.L, device=X.device) * self.fakeinf
            allow = (overlap <<
                     (self.KV)).unsqueeze(-1) + torch.arange(
                         2**(self.KV)).to(X.device).view(1, 1, -1)
            mask.scatter_(1, allow[0], 0)
            cost = torch.min(cost + mask, self.fakeinf)

        # Backpointers: best predecessor per truncated state at each step.
        from_state = torch.zeros(T // self.V,
                                 B,
                                 2**(self.L - self.KV),
                                 dtype=self.state.dtype,
                                 device=self.state.device)

        for i in range(1, T // self.V):
            from_state[i], cost = self.update(cost,
                                              X[i * self.V:(i + 1) * self.V])

        if overlap is not None:
            # Constrain the final state to the allowed overlap set as well.
            mask = torch.ones(B, 2**self.L, device=X.device) * self.fakeinf
            allow = (overlap.unsqueeze(-1) + self.sumdelta.unsqueeze(0))
            mask.scatter_(1, allow[0, 0], 0)
            cost = torch.min(cost + mask, self.fakeinf)

        # Backtrack from the cheapest final state.
        final_state = torch.zeros(T // self.V,
                                  B,
                                  dtype=self.idx_dtype,
                                  device=X.device)
        final_state[T // self.V - 1] = torch.argmin(cost, dim=-1)
        for i in range(T // self.V - 1, 0, -1):
            # The predecessor is indexed by the truncated (>> KV) state.
            final_state[i - 1] = torch.gather(
                from_state[i], -1,
                (final_state[i].to(torch.int64).unsqueeze(-1)) >>
                (self.KV))[..., 0]
        return final_state
262
+
263
    def quantize_seq(self, X, overlap=None, **kwargs):
        """Run viterbi() over X (T, NO) in batches along the column axis.

        Columns are padded to a multiple of the batch size `bs` (capped so
        that bs * 2**L stays within a memory budget), processed chunk by
        chunk, and the padding is stripped from the result.
        """
        T, NO = X.shape
        bs = min(2**(24 - self.L), NO)
        pad_amt = math.ceil(NO / bs) * bs - NO
        X = torch.nn.functional.pad(X, (0, pad_amt))
        T, N = X.shape
        # (num_chunks, T, bs)
        X = X.reshape(T, N // bs, bs).transpose(0, 1).contiguous()
        if overlap is not None:
            overlap = torch.nn.functional.pad(overlap, (0, pad_amt))
            overlap = overlap.reshape(N // bs, bs)

        Qidxs = torch.zeros(N // bs,
                            T // self.V,
                            bs,
                            dtype=self.idx_dtype,
                            device=X.device)
        for i in range(len(X)):
            b_overlap = None if overlap is None else overlap[i]
            Qidxs[i] = self.viterbi(X[i], overlap=b_overlap)
        # Back to (T // V, NO), dropping the padded columns.
        Qidxs = Qidxs.transpose(0, 1).reshape(T // self.V, N)[:, :NO]
        return Qidxs
284
+
285
    def quantize(self, X, **kwargs):
        """Quantize X (rows are sequences) and return (hatX, states).

        Two passes are made: the first quantizes a half-rotated copy of the
        sequence to extract a stable mid-point state, which then constrains
        the second (final) pass so the trellis path is consistent at the
        sequence boundary.
        """
        X = X.T.contiguous().to(torch.float16)
        T = X.shape[0]
        roll_X = torch.roll(X, T // (2 * self.V) * self.V, 0)
        state = self.quantize_seq(roll_X, overlap=None)
        # Truncated state at the midpoint becomes the overlap constraint.
        overlap = state[T // (2 * self.V)] >> self.KV
        state = self.quantize_seq(X, overlap=overlap)
        hatX = self.recons(state).transpose(0, 1).reshape(X.shape)
        return hatX.T.contiguous().to(X.device), state.T.contiguous().to(
            X.device)
295
+
296
    def pack_trellis(self, trellis):
        """Pack (B, T) trellis states into uint16 words.

        Consecutive states share L-KV bits (each step shifts in only KV new
        bits), so after storing the full first state only the KV fresh bits
        of each subsequent state are stored. The bitstream is padded to a
        multiple of 16 and folded into big-endian uint16 words.
        """
        # T is really T // self.V here
        B, T = trellis.shape
        bf = torch.zeros(B,
                         T * self.KV + self.L - self.KV,
                         dtype=bool,
                         device=trellis.device)
        # Full L bits of the first state, MSB first.
        bf[:, :self.L] = (trellis[:, 0].unsqueeze(-1) & (2**torch.arange(
            self.L, device=trellis.device).flip(dims=(-1, ))).unsqueeze(0)) > 0
        K_mask = 2**torch.arange(
            self.KV,
            device=trellis.device).flip(dims=(-1, )).unsqueeze(0)
        for i in range(1, T):
            # Invariant of the bitshift trellis: state i's high bits are
            # state i-1's low bits.
            assert ((trellis[:, i - 1] &
                     ((1 << (self.L - self.KV)) - 1)) == (
                         trellis[:, i] >> (self.KV))).all()
            # Store only the KV newly shifted-in bits.
            bf[:,
               (self.L +
                (i - 1) * self.KV):(self.L + i * self.KV)] = (
                    (trellis[:, i] &
                     ((1 <<
                       (self.KV)) - 1)).unsqueeze(-1) & K_mask) > 0

        # The trailing L-KV bits duplicate the next block's prefix; drop.
        bf = bf[:, :-(self.L - self.KV)]
        pad_amt = math.ceil(
            T * self.KV / 16) * 16 - T * self.KV
        bf = torch.nn.functional.pad(bf, (0, pad_amt)).reshape(
            -1, (T * self.KV + pad_amt) // 16, 16)

        # Fold each group of 16 bits (MSB first) into one uint16.
        uint_mask = (2**torch.arange(
            16, dtype=torch.int32,
            device=bf.device)).flip(dims=(-1, )).unsqueeze(0).unsqueeze(0)
        bf_sum = (bf.to(torch.int32) * uint_mask).sum(dim=-1)
        return bf_sum.to(torch.uint16)
330
+
331
+ class BitshiftLinear(nn.Module):
332
+
333
    def __init__(self,
                 td_x,
                 td_y,
                 L,
                 K,
                 V,
                 tlut_bits,
                 decode_mode,
                 dtype=torch.float16,
                 tlut=None,
                 has_kernel=False):
        """Linear layer whose weight is stored as a packed bitshift trellis.

        Args:
            td_x, td_y: tile dimensions of the weight matrix.
            L, K, V, tlut_bits, decode_mode, tlut: codebook parameters; K is
                passed to bitshift_codebook as its KV argument.
            dtype: compute dtype for cached-weight matmuls.
            has_kernel: use the fused CUDA decompress kernel when True.
                NOTE(review): this parameter shadows the imported
                `has_kernel` helper from lib.utils.kernel_check.
        """
        super().__init__()
        self.td_x = td_x
        self.td_y = td_y
        self.V = V
        self.cb = bitshift_codebook(L, K, V, tlut_bits, decode_mode, tlut=tlut)
        self.internal_dtype = dtype
        self.has_kernel = has_kernel
        # Fixed-point style scale applied around the Hadamard transforms.
        self.scale = 32
352
+
353
+ def get_hatW(self, unpacked_trellis, m, n):
354
+ return self.cb.recons(unpacked_trellis).transpose(0, 1).transpose(
355
+ 1, 2).reshape(m // self.td_x, n // self.td_y, self.td_x,
356
+ self.td_y).transpose(1, 2).reshape(m, n)
357
+
358
    def get_hatW_kernel(self, trellis, m, n):
        """Reconstruct the (m, n) weight matrix from the *packed* trellis
        using the CUDA decode kernel (decode_compressed)."""
        out = decode_compressed(self.cb.L, self.cb.tlut_bits, self.cb.K,
                                int(math.log2(self.V)), m, n, trellis.view(-1),
                                self.cb.lut.T)
        return out
363
+
364
    def cache_hatW(self, packed_trellis, had_left, had_right, K_left, K_right,
                   m, n, rcp, tp_rank):
        """Decode the packed weight once and cache the de-rotated matrix.

        The trellis is decoded (kernel or pure-torch path), rescaled, and
        the left/right Hadamard rotations are undone so `self.hatW` holds a
        plain weight usable as `x @ hatW.T`. `rcp` selects how the matrix
        is split across `tp_rank` tensor-parallel shards before rotating.
        NOTE(review): `self.cb.unpack_trellis` is defined outside this view
        of the file -- confirm it exists on bitshift_codebook.
        """
        if self.has_kernel:
            hatW = self.get_hatW_kernel(packed_trellis, m, n)
        else:
            hatW = self.get_hatW(
                self.cb.unpack_trellis(packed_trellis, self.td_x * self.td_y),
                m, n)
        hatW = hatW.float() / self.scale

        if rcp == 1:
            # Shard along the input dimension before the left rotation.
            self.hatW = matmul_hadU_cuda(
                matmul_hadU_cuda(hatW.reshape(tp_rank * m, n // tp_rank),
                                 had_left, K_left).reshape(m, n).T, had_right,
                K_right).T.contiguous().to(self.internal_dtype)
        elif rcp == 2:
            # Shard along the output dimension before the right rotation.
            self.hatW = matmul_hadU_cuda(
                matmul_hadU_cuda(hatW, had_left,
                                 K_left).T.reshape(tp_rank * n,
                                                   m // tp_rank), had_right,
                K_right).reshape(n, m).T.contiguous().to(self.internal_dtype)
        else:
            # No tensor-parallel reshaping.
            self.hatW = matmul_hadU_cuda(
                matmul_hadU_cuda(hatW, had_left, K_left).T, had_right,
                K_right).T.contiguous().to(self.internal_dtype)
389
+
390
    def forward(self,
                input,
                trellis,
                SU,
                SV,
                had_left,
                had_right,
                K_left,
                K_right,
                rcp,
                tp_rank,
                mode='eval',
                use_prev_kernel=False,
                **kwargs):
        """Apply the quantized linear layer.

        Pipeline: scale input by SU, rotate with the left Hadamard, multiply
        by the decoded weight (fused CUDA kernel, autograd kernel wrapper,
        cached hatW, or on-the-fly reconstruction depending on `mode` /
        `has_kernel` / batch size), rotate back with the right Hadamard, and
        scale by SV.

        mode: 'eval' (packed trellis, unpacked here), 'train-fixW' (use the
        cached self.hatW from cache_hatW), or 'train-recons' (rebuild the
        LUT from tlut each call so tlut gradients flow through fresh values).
        """
        n, m = len(SU), len(SV)
        x = input.view(-1, n).to(torch.float32)
        x = x * SU

        if mode == 'train-fixW':
            # Dense matmul against the cached, de-rotated weight.
            x = (x.to(self.internal_dtype) @ self.hatW.T).float()
        else:
            bs = x.shape[0]

            if rcp == 1:
                # Rotate each tensor-parallel shard independently.
                x = matmul_hadUt_cuda(x.reshape(-1, n // tp_rank), had_left,
                                      K_left).reshape(x.shape) / self.scale
            else:
                x = matmul_hadUt_cuda(x, had_left, K_left) / self.scale

            if bs == 1 and self.has_kernel:
                # Batch-1 fast path: fused decompress+GEMV CUDA op.
                # NOTE(review): `self.cb.K` -- bitshift_codebook stores the
                # shift width as `self.KV`, not `K`; confirm `K` is set on
                # the codebook elsewhere.
                wrapper = getattr(
                    torch.ops.quip_lib,
                    f"decompress_gemm_tcq_{m}_1_{x.numel()}_{self.cb.K}")

                x = wrapper(trellis, x, self.cb.tlut)

            else:
                if mode == 'train-recons':
                    # Rebuild lut from tlut so training sees fresh values.
                    self.cb.recons_lut()

                if self.has_kernel:
                    if use_prev_kernel:
                        # Autograd-aware wrapper (recomputes hatW backward).
                        x = BitshiftLinearKernelAG.apply(
                            x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
                            self.V, self.cb.lut).float()
                    else:
                        x = bitshift_linear_kernel(
                            x, trellis, m, n, self.cb.L, self.cb.tlut_bits, self.cb.K,
                            self.V, self.cb.lut).float()
                else:
                    if mode == 'eval':
                        trellis = self.cb.unpack_trellis(
                            trellis, self.td_x * self.td_y)
                    hatW = self.get_hatW(trellis, m, n)
                    x = (x.to(hatW.dtype) @ hatW.T).float()

            if rcp == 2:
                x = matmul_hadU_cuda(x.reshape(-1, m // tp_rank), had_right,
                                     K_right).reshape(x.shape)
            else:
                x = matmul_hadU_cuda(x, had_right, K_right)

        # Undo the 1/scale applied before the forward rotation.
        x = x.to(SV.device) * (SV * self.scale)
        return x.view(*input.shape[:-1], m).to(input.dtype)
454
+
455
+
456
+
457
class BitshiftLinearKernelAG(torch.autograd.Function):
    """Autograd wrapper around the compressed-weight matmul.

    The decoded weight hatW is not stored for backward; instead the packed
    trellis is saved and hatW is re-decoded in backward() (recompute over
    memory).
    """

    @staticmethod
    def forward(ctx, input, trellis, m, n, L, tlut_bits, K, V, lut):
        ctx.save_for_backward(trellis, lut)
        ctx.L = L
        ctx.tlut_bits = tlut_bits
        ctx.K = K
        ctx.V = V
        ctx.m = m
        ctx.n = n

        hatW = decode_compressed(L, tlut_bits, K, int(math.log2(V)),
                                 m, n, trellis.view(-1), lut.T)
        return input.to(hatW.dtype) @ hatW.T

    @staticmethod
    def backward(ctx, grad_output):
        trellis, lut = ctx.saved_tensors
        L = ctx.L
        tlut_bits = ctx.tlut_bits
        K = ctx.K
        V = ctx.V
        m = ctx.m
        n = ctx.n

        # Re-decode the weight; only the input receives a gradient.
        hatW = decode_compressed(L, tlut_bits, K, int(math.log2(V)),
                                 m, n, trellis.view(-1), lut.T)

        grad_input = grad_output.to(hatW.dtype) @ hatW
        return grad_input, None, None, None, None, None, None, None, None
lib/codebook/vq_codebook.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from lib.utils.kmeans import kmeans_flash1d, kmeans_sklearn
7
+
8
+ class vq_codebook(nn.Module):
9
    def __init__(self,
                 vec_sz=2,
                 lut_bits=8):
        """Vector-quantization codebook with 2**lut_bits entries of dim vec_sz.

        The codebook is k-means centroids fit on standard-normal data and is
        cached on disk; subsequent constructions load the cached table.
        NOTE(review): assumes assets/lut_cache/ exists relative to the CWD.
        """
        super(vq_codebook, self).__init__()
        self.idx_dtype = torch.int32
        self.vec_sz = vec_sz
        self.lut_bits = lut_bits

        fname = f'assets/lut_cache/vq_kmeans_{lut_bits}_{vec_sz}.pt'
        if not os.path.exists(fname):
            if vec_sz == 1:
                data = torch.randn(int(1e8), vec_sz)
                tlut = kmeans_flash1d(data, 2**lut_bits)
            elif vec_sz in [2,4]:
                data = torch.randn(int(1e8), vec_sz)
                # Larger codebooks use a data subsample to keep sklearn
                # k-means tractable.
                if lut_bits <= 5:
                    tlut = kmeans_sklearn(data, 2**lut_bits, max_data=int(1e8))
                else:
                    tlut = kmeans_sklearn(data, 2**lut_bits, max_data=int(1e7))
            torch.save(tlut, fname)
        else:
            tlut = torch.load(fname)
        # tlut: (2**lut_bits, vec_sz); lut: transposed copy for column lookup.
        self.register_buffer("tlut", tlut)
        self.register_buffer("lut", tlut.T.contiguous())
33
+
34
+ def recons(self, encoded, **kwargs):
35
+ return self.tlut[encoded].contiguous()
36
+
37
+ def quantize(self, X, **kwargs):
38
+ """
39
+ X : [B, vec_sz]
40
+ """
41
+ dist = torch.cdist(X, self.tlut.to(X.device, dtype=X.dtype)) # [B, 2**lut_bits]
42
+ state = torch.argmin(dist, dim=-1) # [B,] each entry is in [0, 2**lut_bits)
43
+ hatX = self.recons(state)
44
+ return hatX.to(X.device), state.to(X.device)
45
+
46
+
47
+ if __name__ == "__main__":
48
+ for vec_sz in [4]:
49
+ for lut_bits in [6,7,8,9,10,11,12]:
50
+ # for lut_bits in [1,2,3,4,5,6,7,8,9,10,11,12]:
51
+ if vec_sz == 1 and lut_bits > 8:
52
+ continue
53
+ vq = vq_codebook(vec_sz=vec_sz, lut_bits=lut_bits)
54
+ X = torch.randn(int(1e5), vec_sz)
55
+ hatX, state = vq.quantize(X)
56
+ print(f"vec_sz: {vec_sz}, lut_bits: {lut_bits}, mse: {(hatX-X).pow(2).mean()}")
lib/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Mapping from Hugging Face model identifiers to the short keys used
# elsewhere in this project (e.g. in asset/cache file names).
MODEL_KEYS = {
    "meta-llama/Llama-3.1-8B": "3_8b",
    "meta-llama/Llama-3.2-1B": "3_1b",
    "meta-llama/Llama-3.2-3B": "3_3b",
    "Qwen/Qwen2.5-7B": "qwen_7b",
}
lib/linear/__init__.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .quantized_linear import QuantizedLinear
2
+ from .incoherent_linear import IncoherentLinear
3
+ from .vq_linear import VQLinearPackSIMT, VQLinearPackTensorCore
4
+ from .tcq_linear import QTIPLinearTCQ
5
+ from .comb_linear import CombLinearTCQ, CombtLinearTCQ
6
+ import vq_tensor_kernels
7
+ import torch
8
+
9
# (m, k) = (output, input) GEMM shapes for which prebuilt kernels exist in
# the `vq_tensor_kernels` CUDA extension; ops below are registered only for
# these shapes.
kernels = [
    (53248, 16384),
    (16384, 53248),
    (1024, 16384),
    (16384, 16384),
    (4096, 14336),
    (14336, 4096),
    (28672, 4096),
    (2048, 4096),
    (5120, 4096),
    (6144, 4096),
    (1024, 4096),
    (4096, 4096),
    (4096, 11008),
    (11008, 4096),
    (22016, 4096),
    (12288, 4096),
    (8192, 4096),
    (8192, 8192),
    (10240, 8192),
    (57344, 8192),
    (8192, 1024),
    (8192, 28672),
    (28672, 8192),
    (1024, 8192),
    (5120, 5120),
    (10240, 5120),
    (15360, 5120),
    (13568, 5120),
    (27136, 5120),
    (5120, 13568),
]
41
+
42
# Register torch.library wrappers for the prebuilt VQ decompression CUDA
# kernels. For every (vtype, shape, batch size, bitrate) combination three
# op families are exposed under `ours_lib`:
#   decompress_gemm_* : fused decompress + GEMM (returns a new output)
#   decompress_gemv_* : fused decompress + GEMV writing into `out` in place
#   decompress_*      : plain weight decompression to an (m, k) fp16 tensor
# Each op gets a fake (meta) implementation so it composes with
# torch.compile / FakeTensor tracing, plus a CUDA implementation calling
# into the `vq_tensor_kernels` extension.
kdict = {}
for vtype, max_bits, bit_stride in [
    ('sq_dup', 4, 1),
    ('sq', 8, 1),
    ('vq2', 12, 1),
]:
    for m, k in kernels:
        for n in [1, 2, 4, 8]:
            for bitrate in range(2, max_bits + 1, bit_stride):

                torch.library.define(
                    f"ours_lib::decompress_gemm_{m}_{n}_{k}_{bitrate}_{vtype}",
                    "(Tensor compressed, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_{m}_{n}_{k}_{bitrate}_{vtype}"
                kernel_name = f"vq_tensor_kernels.decompress_gemm_{bitrate}_{m}_{n}_{k}_{vtype}"
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")

        for bitrate in range(2, max_bits + 1, bit_stride):
            # GEMV variant: batch size fixed to 1, result written into `out`.
            name = f"decompress_gemv_{m}_{k}_{bitrate}_{vtype}"
            kernel_name = f"vq_tensor_kernels.decompress_gemm_{bitrate}_{m}_{1}_{k}_{vtype}"
            exec(f"""\
@torch.library.custom_op("ours_lib::{name}", mutates_args={{"out"}})
def {name}(
    compressed: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor,
    out: torch.Tensor) -> torch.Tensor:
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
@{name}.register_fake
def {name}_fake(compressed, x, codebook, out):
    return None
""")

    for bitrate in range(2, max_bits + 1, bit_stride):
        # Plain decompression: the (m, k) shape is a call-time argument, so
        # this op is defined once per (bitrate, vtype), not per shape.
        torch.library.define(
            f"ours_lib::decompress_{bitrate}_{vtype}",
            "(Tensor compressed, Tensor codebook, int m, int k) -> Tensor")

        name = f"decompress_{bitrate}_{vtype}"
        kernel_name = f"vq_tensor_kernels.decompress_{bitrate}_{vtype}"
        exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed: torch.Tensor,
    codebook: torch.Tensor,
    m: int,
    k: int) -> torch.Tensor:
    # BUG FIX: was `compresed.device` (NameError whenever this fake kernel
    # ran under meta/FakeTensor tracing).
    return torch.zeros(m, k, dtype=torch.float16, device=compressed.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed: torch.Tensor,
    codebook: torch.Tensor,
    m: int,
    k: int) -> torch.Tensor:
    out = torch.zeros((m, k), dtype=torch.float16, device=compressed.device)
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), codebook.reshape(-1))
    return out
""")
118
+
119
+ import tcq_kernels
120
+ import torch
121
# (m, k) = (output, input) GEMM shapes for which prebuilt kernels exist in
# the `tcq_kernels` CUDA extension.
MKSHAPE = [
    (53248, 16384),
    (16384, 53248),
    (1024, 16384),
    (16384, 16384),
    (4096, 14336),
    (14336, 4096),
    (28672, 4096),
    (5120, 4096),
    (6144, 4096),
    (512, 4096),
    (1024, 4096),
    (2048, 4096),
    (4096, 4096),
    (2048, 11008),
    (4096, 11008),
    (5504, 4096),
    (11008, 4096),
    (22016, 4096),
    (12288, 4096),
    (8192, 4096),
    (8192, 8192),
    (10240, 8192),
    (57344, 8192),
    (8192, 1024),
    (8192, 28672),
    (28672, 8192),
    (1024, 8192),
    (5120, 5120),
    (10240, 5120),
    (15360, 5120),
    (13568, 5120),
    (27136, 5120),
    (5120, 13568),
    (3072, 3072),
    (1024, 3072),
    (4096, 3072),
    (2048, 3072),
    (5120, 3072),
    (8192, 3072),
    (16384, 3072),
    (3072, 8192),
]
164
+
165
# Register the fused decompress+GEMM ops for the TCQ CUDA kernels. For each
# (shape, batch size, S, bitrate) a plain op is defined, plus "comb" and
# "combt" variants that combine two compressed streams at bitrate and
# bitrate+1 (skipped for the last bitrate, which has no +1 partner).
# The supported bitrates depend on the trellis parameter S.
kdict = {}
for S in [9, 10, 11]:
    if S == 9:
        bitrate_list = [2,3,4,5,6,7,8,9,10]
    elif S == 10:
        bitrate_list = [8, 9, 10]
    elif S == 11:
        bitrate_list = [9, 10]
    for m, k in MKSHAPE:
        for n in [1,2,4,8]:
            for bitrate in bitrate_list:
                torch.library.define(
                    f"ours_lib::decompress_gemm_tcq_{m}_{n}_{k}_{S}_{bitrate}",
                    "(Tensor compressed, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_tcq_{m}_{n}_{k}_{S}_{bitrate}"
                kernel_name = f"tcq_kernels.decompress_gemm_16_{S}_{bitrate}_1_{m}_{n}_{k}"
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")
                # The combined-bitrate variants pair `bitrate` with
                # `bitrate + 1`; the largest bitrate has no partner.
                if bitrate == bitrate_list[-1]:
                    continue
                torch.library.define(
                    f"ours_lib::decompress_gemm_tcq_comb_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}",
                    "(Tensor compressed1, Tensor compressed2, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_tcq_comb_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}"
                kernel_name = f"tcq_kernels.decompress_gemm_comb_16_{S}_{bitrate}_{int(bitrate+1)}_1_{m}_{n}_{k}"
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")
                torch.library.define(
                    f"ours_lib::decompress_gemm_tcq_combt_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}",
                    "(Tensor compressed1, Tensor compressed2, Tensor x, Tensor codebook) -> Tensor")

                name = f"decompress_gemm_tcq_combt_{m}_{n}_{k}_{S}_{bitrate}_{int(bitrate+1)}"
                kernel_name = f"tcq_kernels.decompress_gemm_combt_16_{S}_{bitrate}_{int(bitrate+1)}_1_{m}_{n}_{k}"
                exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    return torch.zeros({n}, {m}, dtype=torch.float32, device=x.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    x: torch.Tensor,
    codebook: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(({n}, {m}), dtype=torch.float32, device=x.device)
    {kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), x.to(torch.float16), codebook.reshape(-1))
    return out
""")
251
+
252
# Register the plain TCQ decompression ops (no GEMM; output shape is passed
# at call time, so these are defined once per (S, bitrate)). The "comb" and
# "combt" variants combine two compressed streams at bitrate and bitrate+1;
# the last supported bitrate has no +1 partner, hence the `break`.
for S in [9, 10, 11]:
    if S == 9:
        bitrate_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
    elif S == 10:
        bitrate_list = [8, 9, 10]
    elif S == 11:
        bitrate_list = [9, 10]
    for bitrate in bitrate_list:
        torch.library.define(
            f"ours_lib::decompress_tcq_{S}_{bitrate}",
            "(Tensor compressed, Tensor codebook, int m, int k) -> Tensor")

        name = f"decompress_tcq_{S}_{bitrate}"
        kernel_name = f"tcq_kernels.decompress_16_{S}_{bitrate}"
        exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed: torch.Tensor,
    codebook: torch.Tensor,
    m: int,
    k: int) -> torch.Tensor:
    # BUG FIX: was `compresed.device` (NameError whenever this fake kernel
    # ran under meta/FakeTensor tracing).
    return torch.zeros(m, k, dtype=torch.float16, device=compressed.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed: torch.Tensor,
    codebook: torch.Tensor,
    m: int,
    k: int) -> torch.Tensor:
    out = torch.zeros((m, k), dtype=torch.float16, device=compressed.device)
    {kernel_name}(out, compressed.reshape(-1).view(torch.int32), codebook.reshape(-1))
    return out
""")
        if bitrate == bitrate_list[-1]:
            break

        torch.library.define(
            f"ours_lib::decompress_tcq_comb_{S}_{bitrate}_{int(bitrate+1)}",
            "(Tensor compressed1, Tensor compressed2, Tensor codebook, int m, int k) -> Tensor")

        name = f"decompress_tcq_comb_{S}_{bitrate}_{int(bitrate+1)}"
        kernel_name = f"tcq_kernels.decompress_comb_16_{S}_{bitrate}_{int(bitrate+1)}"
        exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    codebook: torch.Tensor,
    m: int, k: int) -> torch.Tensor:
    return torch.zeros(m, k, dtype=torch.float16, device=compressed1.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    codebook: torch.Tensor,
    m: int, k: int) -> torch.Tensor:
    out = torch.zeros((m, k), dtype=torch.float16, device=compressed1.device)
    {kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), codebook.reshape(-1))
    return out
""")
        torch.library.define(
            f"ours_lib::decompress_tcq_combt_{S}_{bitrate}_{int(bitrate+1)}",
            "(Tensor compressed1, Tensor compressed2, Tensor codebook, int m, int k) -> Tensor")

        name = f"decompress_tcq_combt_{S}_{bitrate}_{int(bitrate+1)}"
        kernel_name = f"tcq_kernels.decompress_combt_16_{S}_{bitrate}_{int(bitrate+1)}"
        exec(f"""\
@torch.library.register_fake("ours_lib::{name}")
def {name}_abstract(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    codebook: torch.Tensor,
    m: int, k: int) -> torch.Tensor:
    return torch.zeros(m, k, dtype=torch.float16, device=compressed1.device)

@torch.library.impl("ours_lib::{name}", "cuda")
def {name}_cuda(
    compressed1: torch.Tensor,
    compressed2: torch.Tensor,
    codebook: torch.Tensor,
    m: int, k: int) -> torch.Tensor:
    out = torch.zeros((m, k), dtype=torch.float16, device=compressed1.device)
    {kernel_name}(out, compressed1.reshape(-1).view(torch.int32), compressed2.reshape(-1).view(torch.int32), codebook.reshape(-1))
    return out
""")
338
+
339
+
340
+
341
+
342
+
343
+
344
# CUDA extension modules that provide the packed SIMT GEMM / dequant kernels.
import sq_pack_gemm
import vq_pack_gemm
"""
SQ Pack SIMT
"""
# Scalar-quantized (SQ) packed kernels, registered as torch custom ops under
# the "ours_lib" namespace so they compose with torch.compile / fake tensors.
torch.library.define("ours_lib::sq_pack_gemm_simt", "(Tensor x, Tensor q_weight, Tensor lut, int bitwidth) -> Tensor")

@torch.library.register_fake("ours_lib::sq_pack_gemm_simt")
def sq_pack_gemm_simt_abstract(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, bitwidth: int) -> torch.Tensor:
    # Meta/fake implementation: only shape, dtype and device matter here.
    # Output is (batch, 1, out_features) with out_features == q_weight.shape[0].
    return torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)

@torch.library.impl("ours_lib::sq_pack_gemm_simt", "cuda")
def sq_pack_gemm_simt_cuda(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, bitwidth: int) -> torch.Tensor:
    # The kernel accumulates into a pre-zeroed output buffer in place.
    output = torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)
    sq_pack_gemm.pack_gemm(x, output, q_weight, lut.view(-1), bitwidth)
    return output

# Weight-only dequantization: reconstruct the (m, k) fp16 weight matrix from
# the packed codes and the lookup table.
torch.library.define("ours_lib::sq_pack_dequant_simt", "(Tensor q_weight, Tensor lut, int bitwidth, int m, int k) -> Tensor")

@torch.library.register_fake("ours_lib::sq_pack_dequant_simt")
def sq_pack_dequant_simt_abstract(q_weight: torch.Tensor, lut: torch.Tensor, bitwidth: int, m: int, k: int) -> torch.Tensor:
    return torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)

@torch.library.impl("ours_lib::sq_pack_dequant_simt", "cuda")
def sq_pack_dequant_simt_cuda(q_weight: torch.Tensor, lut: torch.Tensor, bitwidth: int, m: int, k: int) -> torch.Tensor:
    output = torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)
    sq_pack_gemm.pack_dequant(output, q_weight, lut.view(-1), bitwidth)
    return output


# In-place GEMM variant: caller supplies (and owns) the output buffer, which
# the op mutates; declared via custom_op so autograd/compile know about the
# mutation through mutates_args.
@torch.library.custom_op("ours_lib::sq_pack_gemm_inplace_simt", mutates_args={"output"})
def sq_pack_gemm_inplace_simt(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor, output: torch.Tensor, bitwidth: int) -> None:
    sq_pack_gemm.pack_gemm(x, output, q_weight, lut, bitwidth)

@sq_pack_gemm_inplace_simt.register_fake
def _(x, q_weight, lut, output, bitwidth):
    # Pure mutation op: nothing to return for the fake implementation.
    return None
379
+
380
+ """
381
+ VQ Pack SIMT
382
+ """
383
+ codeT_sz = 32
384
+ for vec_sz in [2,4]:
385
+ if vec_sz == 2:
386
+ lut_bits_list = [3,4,5,6,7,8,9,10,11,12]
387
+ elif vec_sz == 4:
388
+ lut_bits_list = [6,7,8,9,10,11,12]
389
+ for lut_bits in lut_bits_list:
390
+ code_n = lut_bits
391
+ recons_n = int(vec_sz * 16)
392
+ for maxm in [1,2,4,8]:
393
+ name = f"vq_pack_gemm_simt_{maxm}_{vec_sz}_{lut_bits}"
394
+ kernel_name = f"vq_pack_gemm.vq_pack_gemm_{maxm}_{lut_bits}_{vec_sz}_{code_n}_{codeT_sz}_{recons_n}"
395
+ torch.library.define(f"ours_lib::{name}", "(Tensor x, Tensor q_weight, Tensor lut) -> Tensor")
396
+ exec(f"""\
397
+ @torch.library.register_fake("ours_lib::{name}")
398
+ def {name}_abstract(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
399
+ return torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)
400
+
401
+ @torch.library.impl("ours_lib::{name}", "cuda")
402
+ def {name}_cuda(x: torch.Tensor, q_weight: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
403
+ output = torch.zeros(x.shape[0], 1, q_weight.shape[0], dtype=torch.float16, device=x.device)
404
+ {kernel_name}(x, output, q_weight.view(torch.uint32), lut)
405
+ return output
406
+ """)
407
+ name = f"vq_pack_dequant_simt_{vec_sz}_{lut_bits}"
408
+ kernel_name = f"vq_pack_gemm.vq_pack_dequant_{lut_bits}_{vec_sz}_{code_n}_{codeT_sz}_{recons_n}"
409
+ torch.library.define(f"ours_lib::{name}", "(Tensor q_weight, Tensor lut, int m, int k) -> Tensor")
410
+ exec(f"""\
411
+ @torch.library.register_fake("ours_lib::{name}")
412
+ def {name}_abstract(q_weight: torch.Tensor, lut: torch.Tensor, m: int, k: int) -> torch.Tensor:
413
+ return torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)
414
+
415
+ @torch.library.impl("ours_lib::{name}", "cuda")
416
+ def {name}_cuda(q_weight: torch.Tensor, lut: torch.Tensor, m: int, k: int) -> torch.Tensor:
417
+ output = torch.zeros(m, k, dtype=torch.float16, device=q_weight.device)
418
+ {kernel_name}(output, q_weight.view(torch.uint32), lut)
419
+ return output
420
+ """)
421
+
422
+
423
if __name__ == "__main__":
    # Smoke test.  The decompress/GEMM custom ops are CUDA-only kernels, so
    # the layer's buffers (trellis codes, tlut) must be moved to the GPU to
    # match the CUDA input — previously the layer stayed on the CPU and the
    # forward call failed with a device mismatch.
    layer = CombLinearTCQ(4096, 4096, 16, 16, (2048, 2048), 16, (3, 4), 2, 9, False).cuda()
    print(layer._info())
    layer.forward(torch.randn(1, 4096).cuda())
lib/linear/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
lib/linear/__pycache__/comb_linear.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
lib/linear/__pycache__/incoherent_linear.cpython-311.pyc ADDED
Binary file (42.7 kB). View file
 
lib/linear/__pycache__/quantized_linear.cpython-311.pyc ADDED
Binary file (6.42 kB). View file
 
lib/linear/__pycache__/tcq_linear.cpython-311.pyc ADDED
Binary file (6.88 kB). View file
 
lib/linear/__pycache__/vq_linear.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
lib/linear/comb_linear.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
class CombLinearTCQ(nn.Module):
    """Linear layer whose weight is split along the output (row) dimension
    into two partitions, each compressed with trellis-coded quantization at
    its own bitrate KV[i].

    When the two partitions have equal size a single fused "comb" kernel
    handles both; otherwise each partition is processed separately and the
    partial outputs are concatenated.
    """

    def __init__(
        self,
        in_features,
        out_features,
        td_x,          # tile height (output dim)
        td_y,          # tile width (input dim)
        out_part,      # 2-tuple of output partition sizes; must sum to out_features
        L,             # trellis window
        KV,            # 2-tuple of bits-per-weight, one per partition
        V,             # vq dim
        tlut_bits,     # log2 of the lookup-table size
        bias=False,
        dtype=torch.float16,
    ):
        super().__init__()
        assert len(out_part) == 2 and len(KV) == 2
        assert out_part[0] + out_part[1] == out_features

        self.in_features = in_features
        self.out_features = out_features
        self.out_part = out_part
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.dtype = dtype

        # One packed int16 code buffer per output partition:
        # rows = number of (td_x x td_y) weight tiles, cols = int16 words/tile.
        for part in (0, 1):
            n_tiles = (out_part[part] // td_x) * (in_features // td_y)
            words_per_tile = math.ceil((td_x * td_y) * KV[part] / 16 / V)
            self.register_buffer(
                f'trellis{part + 1}',
                torch.zeros(n_tiles, words_per_tile, dtype=torch.int16))

        # Shared lookup table of 2**tlut_bits V-dimensional fp16 codewords.
        self.tlut = nn.Parameter(
            torch.zeros(2 ** tlut_bits, V, dtype=torch.float16),
            requires_grad=False)

        if bias:
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # The fused kernel only supports equally sized partitions.
        self.use_comb_kernel = out_part[0] == out_part[1]

    def _info(self):
        """Return a serializable description: metadata plus CPU tensor copies."""
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "td_x": self.td_x,
            "td_y": self.td_y,
            "out_part": self.out_part,
            "L": self.L,
            "KV": self.KV,
            "V": self.V,
            'tlut_bits': self.tlut_bits,
            "dtype": self.dtype,
            "trellis1": self.trellis1.detach().cpu(),
            "trellis2": self.trellis2.detach().cpu(),
            "tlut": self.tlut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }

    def forward(self, inp, **kwargs):
        """Compute inp @ W.T with the quantized weight.

        Small batches (<= 8 rows) use fused decompress+GEMM kernels; larger
        batches dequantize the weight once and fall back to a dense matmul.
        """
        flat = inp.view(-1, self.in_features)
        bs = flat.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            out = self._fused_gemm(flat, bs, k)
        else:
            out = self._dequant_matmul(flat, k)
        return out.view(*inp.shape[:-1], m).to(inp.dtype)

    def _fused_gemm(self, x, bs, k):
        # Fused decompress+GEMM path; kernels are specialized per batch size.
        if self.use_comb_kernel:
            op = getattr(
                torch.ops.ours_lib,
                f"decompress_gemm_tcq_comb_{self.out_features}_{bs}_{k}_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
            )
            return op(self.trellis1, self.trellis2, x, self.tlut)
        op1 = getattr(
            torch.ops.ours_lib,
            f"decompress_gemm_tcq_{self.out_part[0]}_{bs}_{k}_{self.tlut_bits}_{self.KV[0]}"
        )
        op2 = getattr(
            torch.ops.ours_lib,
            f"decompress_gemm_tcq_{self.out_part[1]}_{bs}_{k}_{self.tlut_bits}_{self.KV[1]}"
        )
        # Each partition produces its own slice of the output columns.
        return torch.cat(
            [op1(self.trellis1, x, self.tlut), op2(self.trellis2, x, self.tlut)],
            dim=1)

    def _dequant_matmul(self, x, k):
        # Large-batch path: reconstruct the fp16 weight, then dense matmul.
        if self.use_comb_kernel:
            op = getattr(
                torch.ops.ours_lib,
                f"decompress_tcq_comb_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
            )
            with torch.no_grad():
                dq = op(self.trellis1, self.trellis2, self.tlut, self.out_features, k)
                return x.to(dq.dtype) @ dq.T
        op1 = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_{self.tlut_bits}_{self.KV[0]}"
        )
        op2 = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_{self.tlut_bits}_{self.KV[1]}"
        )
        with torch.no_grad():
            dq1 = op1(self.trellis1, self.tlut, self.out_part[0], k)
            dq2 = op2(self.trellis2, self.tlut, self.out_part[1], k)
            return torch.cat(
                [x.to(dq1.dtype) @ dq1.T, x.to(dq2.dtype) @ dq2.T], dim=1)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a CombLinearTCQ from a dict produced by _info()."""
        layer = CombLinearTCQ(
            info["in_features"], info["out_features"], info["td_x"],
            info["td_y"], info["out_part"], info["L"], info["KV"], info["V"],
            info["tlut_bits"], info["bias"] is not None, info["dtype"])
        layer.trellis1.data.copy_(info["trellis1"])
        layer.trellis2.data.copy_(info["trellis2"])
        layer.tlut.data.copy_(info["tlut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    def get_weight(self):
        """Dequantize and return the full fp16 weight matrix (CUDA only)."""
        op = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_comb_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
        )
        return op(self.trellis1, self.trellis2, self.tlut,
                  self.out_features, self.in_features)
146
+
147
+
148
class CombtLinearTCQ(nn.Module):
    """Linear layer whose weight is split along the *input* (column)
    dimension into two partitions, each trellis-coded at bitrate KV[i].

    Counterpart of CombLinearTCQ (which splits along the output dimension):
    here the two partial products are summed rather than concatenated.
    """

    def __init__(
        self,
        in_features,
        out_features,
        td_x,          # tile height (output dim)
        td_y,          # tile width (input dim)
        in_part,       # 2-tuple of input partition sizes; must sum to in_features
        L,             # trellis window
        KV,            # 2-tuple of bits-per-weight, one per partition
        V,             # vq dim
        tlut_bits,     # log2 of the lookup-table size
        bias=False,
        dtype=torch.float16,
    ):
        super().__init__()
        assert len(in_part) == 2 and len(KV) == 2
        assert in_part[0] + in_part[1] == in_features

        self.in_features = in_features
        self.out_features = out_features
        self.in_part = in_part
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.dtype = dtype

        # One packed int16 code buffer per input partition.
        for part in (0, 1):
            n_tiles = (out_features // td_x) * (in_part[part] // td_y)
            words_per_tile = math.ceil((td_x * td_y) * KV[part] / 16 / V)
            self.register_buffer(
                f'trellis{part + 1}',
                torch.zeros(n_tiles, words_per_tile, dtype=torch.int16))

        # Shared lookup table of 2**tlut_bits V-dimensional fp16 codewords.
        self.tlut = nn.Parameter(
            torch.zeros(2 ** tlut_bits, V, dtype=torch.float16),
            requires_grad=False)

        if bias:
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # The fused "combt" kernel only supports equally sized partitions.
        self.use_comb_kernel = in_part[0] == in_part[1]

    def _info(self):
        """Return a serializable description: metadata plus CPU tensor copies."""
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "td_x": self.td_x,
            "td_y": self.td_y,
            "in_part": self.in_part,
            "L": self.L,
            "KV": self.KV,
            "V": self.V,
            'tlut_bits': self.tlut_bits,
            "dtype": self.dtype,
            "trellis1": self.trellis1.detach().cpu(),
            "trellis2": self.trellis2.detach().cpu(),
            "tlut": self.tlut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }

    def forward(self, inp, **kwargs):
        """Compute inp @ W.T with the quantized weight.

        Small batches (<= 8 rows) use fused decompress+GEMM kernels; larger
        batches dequantize each partition once and sum dense matmuls.
        """
        flat = inp.view(-1, self.in_features)
        bs = flat.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            out = self._fused_gemm(flat, bs, k)
        else:
            out = self._dequant_matmul(flat, m, k)
        return out.view(*inp.shape[:-1], m).to(inp.dtype)

    def _fused_gemm(self, x, bs, k):
        # Fused decompress+GEMM path; kernels are specialized per batch size.
        if self.use_comb_kernel:
            op = getattr(
                torch.ops.ours_lib,
                f"decompress_gemm_tcq_combt_{self.out_features}_{bs}_{k}_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
            )
            return op(self.trellis1, self.trellis2, x, self.tlut)
        m = self.out_features
        op1 = getattr(
            torch.ops.ours_lib,
            f"decompress_gemm_tcq_{m}_{bs}_{self.in_part[0]}_{self.tlut_bits}_{self.KV[0]}"
        )
        op2 = getattr(
            torch.ops.ours_lib,
            f"decompress_gemm_tcq_{m}_{bs}_{self.in_part[1]}_{self.tlut_bits}_{self.KV[1]}"
        )
        # Each partition consumes its own slice of the input columns; partial
        # products over the two slices are summed.
        return (op1(self.trellis1, x[:, :self.in_part[0]], self.tlut)
                + op2(self.trellis2, x[:, self.in_part[0]:], self.tlut))

    def _dequant_matmul(self, x, m, k):
        # Large-batch path: reconstruct fp16 weights, then dense matmuls.
        if self.use_comb_kernel:
            op = getattr(
                torch.ops.ours_lib,
                f"decompress_tcq_combt_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
            )
            with torch.no_grad():
                dq = op(self.trellis1, self.trellis2, self.tlut, m, k)
                return x.to(dq.dtype) @ dq.T
        op1 = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_{self.tlut_bits}_{self.KV[0]}"
        )
        op2 = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_{self.tlut_bits}_{self.KV[1]}"
        )
        with torch.no_grad():
            dq1 = op1(self.trellis1, self.tlut, m, self.in_part[0])
            dq2 = op2(self.trellis2, self.tlut, m, self.in_part[1])
            return (x[:, :self.in_part[0]].to(dq1.dtype) @ dq1.T
                    + x[:, self.in_part[0]:].to(dq2.dtype) @ dq2.T)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a CombtLinearTCQ from a dict produced by _info()."""
        layer = CombtLinearTCQ(
            info["in_features"], info["out_features"], info["td_x"],
            info["td_y"], info["in_part"], info["L"], info["KV"], info["V"],
            info["tlut_bits"], info["bias"] is not None, info["dtype"])
        layer.trellis1.data.copy_(info["trellis1"])
        layer.trellis2.data.copy_(info["trellis2"])
        layer.tlut.data.copy_(info["tlut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    def get_weight(self):
        """Dequantize and return the full fp16 weight matrix (CUDA only)."""
        op = getattr(
            torch.ops.ours_lib,
            f"decompress_tcq_combt_{self.tlut_bits}_{self.KV[0]}_{self.KV[1]}"
        )
        return op(self.trellis1, self.trellis2, self.tlut,
                  self.out_features, self.in_features)

    @staticmethod
    def merge_infos(info1, info2):
        """Stack two _info() dicts along the output dimension into one.

        Both layers must share every quantization hyperparameter and use the
        same lookup table; biases are not supported for merging.
        """
        assert info1["in_features"] == info2["in_features"]
        assert info1["td_x"] == info2["td_x"]
        assert info1["td_y"] == info2["td_y"]
        assert info1["L"] == info2["L"]
        assert info1["KV"] == info2["KV"]
        assert info1["V"] == info2["V"]
        assert info1["tlut_bits"] == info2["tlut_bits"]
        if not torch.allclose(info1["tlut"], info2["tlut"], atol=1e-4):
            print("warning: tlut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        return {
            "in_features": info1["in_features"],
            "out_features": info1["out_features"] + info2["out_features"],
            "td_x": info1["td_x"],
            "td_y": info1["td_y"],
            "L": info1["L"],
            "KV": info1["KV"],
            "V": info1["V"],
            "tlut_bits": info1["tlut_bits"],
            "bias": None,
            "dtype": info1["dtype"],
            # Tile rows are ordered output-major, so stacking the code
            # buffers matches stacking the weights along the output dim.
            "trellis1": torch.cat([info1["trellis1"], info2["trellis1"]], dim=0),
            "trellis2": torch.cat([info1["trellis2"], info2["trellis2"]], dim=0),
            "tlut": info1["tlut"],
            "in_part": info1["in_part"],
        }
321
+
322
if __name__ == "__main__":
    # Smoke test.  The decompress/GEMM custom ops are CUDA-only kernels, so
    # the layer's buffers (trellis codes, tlut) must be moved to the GPU to
    # match the CUDA input — previously the layer stayed on the CPU and the
    # forward call failed with a device mismatch.
    layer = CombLinearTCQ(4096, 4096, 16, 16, (2048, 2048), 16, (3, 4), 2, 9, False).cuda()
    print(layer._info())
    layer.forward(torch.randn(1, 4096).cuda())
lib/linear/incoherent_linear.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from lib.utils import (get_hadK, matmul_hadU_cuda, matmul_hadUt_cuda, matmul_hadUt, matmul_hadU, matmul_hadUt_head, matmul_hadU_head, matmul_hadU_head_cuda, matmul_hadUt_head_cuda)
5
+ from lib.linear.tcq_linear import QTIPLinearTCQ
6
+ from lib.linear.vq_linear import VQLinearPackTensorCore, VQLinearPackSIMT
7
+ from lib.linear.comb_linear import CombLinearTCQ, CombtLinearTCQ
8
+ from transformers.activations import ACT2FN
9
+ from transformers.models.llama.configuration_llama import LlamaConfig
10
+ from typing import Optional, Tuple
11
+ from model.llama import LlamaRotaryEmbedding, repeat_kv, apply_rotary_pos_emb, Cache
12
+
13
def make_linear(info, use_simt=False):
    """Instantiate the linear implementation matching a quantization result.

    The quantizer string in info["quant_info"] selects the layer class; each
    quantized class is rebuilt from info["linear_info"].  Unrecognized
    quantizers fall back to an unquantized nn.Linear.
    """
    qstr = info["quant_info"]["quantizer_str"]
    if "tcq" in qstr:
        return QTIPLinearTCQ.gen_layer_from_info(info["linear_info"])
    if "sq" in qstr or "vq" in qstr or "ldlq" in qstr:
        # SIMT vs tensor-core kernels share the same packed format.
        vq_cls = VQLinearPackSIMT if use_simt else VQLinearPackTensorCore
        return vq_cls.gen_layer_from_info(info["linear_info"])
    if "tcomb" in qstr:
        return CombtLinearTCQ.gen_layer_from_info(info["linear_info"])
    if "comb" in qstr:
        return CombLinearTCQ.gen_layer_from_info(info["linear_info"])
    return nn.Linear(info["in_features"], info["out_features"], bias=False)
27
+
28
class IncoherentSdpaAttention(nn.Module):
    """Llama-style SDPA attention over incoherence-processed activations.

    Inputs are sign-flipped (SU_*) and Hadamard-rotated before each quantized
    projection; outputs are rescaled by per-channel Wscale_* factors.  The
    q/k/v projections can optionally be fused into one merged quantized
    linear layer (exactly one of merge_qk / merge_kv / merge_qv / merge_qkv).
    The projection modules themselves are populated by gen_layer_from_info.
    """

    def __init__(self, config, merge_qk=False, merge_kv=False, merge_qv=False, merge_qkv=False, layer_idx=None, dtype=torch.float16):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        # Total output width of the k/v projections (GQA-aware).
        self.kv_out = self.hidden_size * self.num_key_value_heads // self.num_heads
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # Projection modules are filled in later by gen_layer_from_info;
        # which ones are used depends on the merge_* flags.
        # (NOTE(review): a merged qv_proj is assigned via setattr only —
        # there is no qv_proj placeholder here.)
        self.q_proj = None
        self.k_proj = None
        self.v_proj = None
        self.o_proj = None
        self.qk_proj = None
        self.qkv_proj = None
        self.kv_proj = None
        self.dtype=dtype
        self.layer_idx = layer_idx
        # Per-input-channel sign vectors applied before the Hadamard rotation.
        self.register_buffer("SU_qkv", torch.ones(config.hidden_size, dtype=self.dtype))
        self.register_buffer("SU_o", torch.ones(config.hidden_size, dtype=self.dtype))

        hidden_had, hidden_K = get_hadK(config.hidden_size)

        hidden_had_T = hidden_had.T.contiguous().cuda() if hidden_had is not None else None

        # Per-output-channel weight scales for the concatenated [q | k | v]
        # projection outputs (order becomes [q | v | k] under merge_qv), and
        # for the output projection.  Non-persistent: restored at load time.
        self.register_buffer('Wscale_qkv', torch.ones(config.hidden_size + 2 * self.kv_out, dtype=self.dtype), persistent=False)
        self.register_buffer('Wscale_o', torch.ones(config.hidden_size, dtype=self.dtype), persistent=False)
        self.register_buffer('had_left_qkv_T', hidden_had_T, persistent=False)
        self.register_buffer('had_left_o_T', hidden_had_T, persistent=False)

        self.hidden_K = hidden_K
        # Fixed pre/post scaling factor applied around each quantized matmul
        # (divided out before the projection, multiplied back after).
        self.scale = 64.0
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

        self.merge_qk = merge_qk
        self.merge_kv = merge_kv
        self.merge_qv = merge_qv
        self.merge_qkv = merge_qkv

        assert sum([self.merge_qk, self.merge_kv, self.merge_qv, self.merge_qkv]) <= 1, "Only one of merge_qk, merge_kv, merge_qv, merge_qkv can be True"

    def compute_qkv(self, input):
        """Apply the incoherence transform once, then the q/k/v projections.

        Returns (q, k, v) with q of width hidden_size and k, v of width
        kv_out, reshaped back to the input's leading dimensions.
        """
        n = len(self.SU_qkv)
        x = input.view(-1, n).half()

        # Sign-flip + Hadamard rotation, shared by all three projections.
        x = matmul_hadU_cuda(x * self.SU_qkv, self.had_left_qkv_T, self.hidden_K) / self.scale
        if self.merge_qkv:
            qkv = self.qkv_proj(x.half()) * self.Wscale_qkv * self.scale
            q, k, v = qkv.split([self.hidden_size, self.kv_out, self.kv_out], dim=-1)
        elif self.merge_qk:
            qk = self.qk_proj(x.half()) * self.Wscale_qkv[:self.hidden_size + self.kv_out] * self.scale
            q, k = qk.split([self.hidden_size, self.kv_out], dim=-1)
            v = self.v_proj(x.half()) * self.Wscale_qkv[self.hidden_size + self.kv_out:] * self.scale
        elif self.merge_kv:
            kv = self.kv_proj(x.half()) * self.Wscale_qkv[self.hidden_size:] * self.scale
            k, v = kv.split([self.kv_out, self.kv_out], dim=-1)
            q = self.q_proj(x.half()) * self.Wscale_qkv[:self.hidden_size] * self.scale
        elif self.merge_qv:
            # Wscale_qkv was packed [q | v | k] for this mode
            # (see gen_layer_from_info), so the slices line up.
            qv = self.qv_proj(x.half()) * self.Wscale_qkv[:self.hidden_size + self.kv_out] * self.scale
            q, v = qv.split([self.hidden_size, self.kv_out], dim=-1)
            k = self.k_proj(x.half()) * self.Wscale_qkv[self.hidden_size + self.kv_out:] * self.scale
        else:
            q = self.q_proj(x.half()) * self.Wscale_qkv[:self.hidden_size] * self.scale
            k = self.k_proj(x.half()) * self.Wscale_qkv[self.hidden_size:self.hidden_size + self.kv_out] * self.scale
            v = self.v_proj(x.half()) * self.Wscale_qkv[self.hidden_size + self.kv_out:] * self.scale
        return q.view(*input.shape[:-1], n), k.view(*input.shape[:-1], self.kv_out), v.view(*input.shape[:-1], self.kv_out)

    def compute_o(self, input):
        """Incoherence transform + quantized output projection."""
        n = len(self.SU_o)
        x = input.view(-1, n).half()
        x = matmul_hadU_cuda(x * self.SU_o, self.had_left_o_T, self.hidden_K) / self.scale
        x = self.o_proj(x.half()) * self.Wscale_o * self.scale
        return x.view(*input.shape[:-1], n)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor,
                  torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # NOTE(review): nn.Module defines no forward with this signature,
            # so this fallback raises unless a subclass/mixin supplies one —
            # confirm before relying on output_attentions=True.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        # query_states = self.q_proj(hidden_states)
        # key_states = self.k_proj(hidden_states)
        # value_states = self.v_proj(hidden_states)
        query_states, key_states, value_states = self.compute_qkv(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads,
                                         self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                         self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin)


        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position
            }
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped KV heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states.to(query_states.device),
            value_states.to(query_states.device),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        # attn_output = self.o_proj(attn_output)
        attn_output = self.compute_o(attn_output)
        return attn_output, None, past_key_value

    @staticmethod
    def gen_layer_from_info(config, layer_idx, info_q, info_k, info_v, info_o, merge_qk=False, merge_qv=False, merge_kv=False, merge_qkv=False, dummy=False, use_simt=False, use_simt_q=None, use_simt_k=None, use_simt_v=None, use_simt_o=None):
        """Build an attention layer from per-projection quantization results.

        Each info_* dict carries the sign vector ("SU"), per-channel weight
        scales ("Wscale") and the packed layer data ("linear_info").  When a
        merge flag is set, the corresponding linear_infos are merged via the
        layer class's merge_infos before instantiation.
        """
        attn = IncoherentSdpaAttention(config, merge_qk=merge_qk, merge_qv=merge_qv, merge_kv=merge_kv, merge_qkv=merge_qkv, layer_idx=layer_idx)
        if not dummy:
            attn.SU_qkv.data.copy_(info_q["SU"])
            attn.SU_o.data.copy_(info_o["SU"])
            if not merge_qv:
                attn.Wscale_qkv.data.copy_(torch.cat([info_q["Wscale"], info_k["Wscale"], info_v["Wscale"]], dim=-1))
            else:
                # merge_qv packs the scales in [q | v | k] order to match the
                # qv_proj output split in compute_qkv.
                attn.Wscale_qkv.data.copy_(torch.cat([info_q["Wscale"], info_v["Wscale"], info_k["Wscale"]], dim=-1))
            attn.Wscale_o.data.copy_(info_o["Wscale"])

        # Per-projection SIMT overrides fall back to the global use_simt flag.
        use_simt_q = use_simt if use_simt_q is None else use_simt_q
        use_simt_k = use_simt if use_simt_k is None else use_simt_k
        use_simt_v = use_simt if use_simt_v is None else use_simt_v
        use_simt_o = use_simt if use_simt_o is None else use_simt_o

        # Select which infos get merged into a single projection and which
        # stay as separate per-projection layers.
        if merge_qk:
            to_merged, rest, target_proj, rest_proj = [info_q, info_k], [info_v], "qk_proj", ["v_proj"]
        elif merge_kv:
            to_merged, rest, target_proj, rest_proj = [info_k, info_v], [info_q], "kv_proj", ["q_proj"]
        elif merge_qv:
            to_merged, rest, target_proj, rest_proj = [info_q, info_v], [info_k], "qv_proj", ["k_proj"]
        elif merge_qkv:
            to_merged, rest, target_proj, rest_proj = [info_q, info_k, info_v], [], "qkv_proj", []
        else:
            to_merged, rest, target_proj, rest_proj = [], [info_q, info_k, info_v], "", ["q_proj", "k_proj", "v_proj"]

        if merge_qk or merge_kv or merge_qv or merge_qkv:
            # The merged layer's SIMT choice follows its first constituent.
            if merge_kv: use_simt_merge = use_simt_k
            elif merge_qk or merge_qv or merge_qkv: use_simt_merge = use_simt_q
            else: raise ValueError

            # NOTE(review): if the quantizer string matches none of these
            # branches, merged_linear stays unbound and the call below raises
            # NameError — confirm all expected quantizer strings are covered.
            if "tcq" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = QTIPLinearTCQ
            elif "sq" in to_merged[0]["quant_info"]["quantizer_str"] or "vq" in to_merged[0]["quant_info"]["quantizer_str"] or "ldlq" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = VQLinearPackTensorCore if not use_simt_merge else VQLinearPackSIMT
            elif "tcomb" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = CombtLinearTCQ
            elif "comb" in to_merged[0]["quant_info"]["quantizer_str"]:
                merged_linear = CombLinearTCQ
            merged = to_merged[0]['linear_info']
            for info in to_merged[1:]:
                merged = merged_linear.merge_infos(merged, info['linear_info'])
            setattr(attn, target_proj, merged_linear.gen_layer_from_info(merged))
        for info, proj in zip(rest, rest_proj):
            if proj == "q_proj": cur_use_simt = use_simt_q
            elif proj == "k_proj": cur_use_simt = use_simt_k
            elif proj == "v_proj": cur_use_simt = use_simt_v
            else: raise ValueError
            setattr(attn, proj, make_linear(info, use_simt=cur_use_simt))
        attn.o_proj = make_linear(info_o, use_simt=use_simt_o)
        return attn

    @staticmethod
    def gen_layer_from_quantizer_str_and_key(config, layer_idx, quant_dir, quantizer_str_q, quantizer_str_k, quantizer_str_v, quantizer_str_o, key_q, key_k, key_v, key_o, merge_qk=False, merge_qv=False, merge_kv=False, merge_qkv=False, dummy=False, use_simt=False, use_simt_q=None, use_simt_k=None, use_simt_v=None, use_simt_o=None):
        """Load per-projection quant results from disk (or synthesize dummy
        ones) and delegate to gen_layer_from_info."""
        if not dummy:
            info_q = torch.load(f"{quant_dir}/{quantizer_str_q}/{key_q}.pt")
            info_k = torch.load(f"{quant_dir}/{quantizer_str_k}/{key_k}.pt")
            info_v = torch.load(f"{quant_dir}/{quantizer_str_v}/{key_v}.pt")
            info_o = torch.load(f"{quant_dir}/{quantizer_str_o}/{key_o}.pt")
        else:
            # Dummy path: fabricate quant results (e.g. for memory profiling)
            # instead of reading checkpoints.
            from lib.utils.mem_op import get_dummy_quant_results
            from lib.config import MODEL_KEYS
            model_key = MODEL_KEYS[config._name_or_path]
            info_q = get_dummy_quant_results(model_key, f"self_attn.q_proj", quantizer_str_q)
            info_k = get_dummy_quant_results(model_key, f"self_attn.k_proj", quantizer_str_k)
            info_v = get_dummy_quant_results(model_key, f"self_attn.v_proj", quantizer_str_v)
            info_o = get_dummy_quant_results(model_key, f"self_attn.o_proj", quantizer_str_o)

        return IncoherentSdpaAttention.gen_layer_from_info(config, layer_idx, info_q, info_k, info_v, info_o, merge_qk=merge_qk, merge_qv=merge_qv, merge_kv=merge_kv, merge_qkv=merge_qkv, dummy=dummy, use_simt=use_simt, use_simt_q=use_simt_q, use_simt_k=use_simt_k, use_simt_v=use_simt_v, use_simt_o=use_simt_o)
275
+
276
+
277
+
278
+
279
class IncoherentMLP(nn.Module):
    """
    Gated MLP (up/gate/down) whose projections are quantized linear layers
    evaluated in an incoherent basis: activations are multiplied by a random
    sign vector (SU) and passed through a Hadamard transform before each
    quantized matmul, then rescaled by per-output-channel weight scales.

    only support left only and unified SU for upgates.
    """
    def __init__(self, hidden_size, intermediate_size, hidden_act, merge_ug=False, bias=False, dtype=torch.float16):
        super().__init__()
        assert bias is False, "bias is not supported"
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dtype = dtype

        # Exactly one of {up_proj+gate_proj, ug_proj} is populated later by
        # gen_layer_from_info, depending on merge_ug.
        self.up_proj = None
        self.gate_proj = None
        self.ug_proj = None
        self.down_proj = None

        # Sign vectors applied before the Hadamard transform:
        # SU_ug acts on the hidden-dim input of up/gate, SU_dp on the
        # intermediate-dim input of down_proj.
        self.register_buffer("SU_ug", torch.ones(hidden_size, dtype=self.dtype))
        self.register_buffer("SU_dp", torch.ones(intermediate_size, dtype=self.dtype))

        hidden_had, hidden_K = get_hadK(hidden_size)
        inter_had, inter_K = get_hadK(intermediate_size)

        # get_hadK may return None when the dimension is a power of two and no
        # auxiliary Hadamard factor is needed — TODO confirm against get_hadK.
        inter_had_T = inter_had.T.contiguous().cuda() if inter_had is not None else None
        hidden_had_T = hidden_had.T.contiguous().cuda() if hidden_had is not None else None

        # Per-output-channel weight scales; Wscale_ug holds up (first half)
        # then gate (second half), hence intermediate_size * 2.
        # persistent=False: these are restored from the saved quant info, not
        # from the state_dict.
        self.register_buffer('Wscale_ug', torch.ones(intermediate_size * 2, dtype=self.dtype), persistent=False)
        self.register_buffer('Wscale_dp', torch.ones(hidden_size, dtype=self.dtype), persistent=False)
        self.register_buffer('had_left_ug_T', hidden_had_T, persistent=False)
        self.register_buffer('had_left_dp_T', inter_had_T, persistent=False)

        self.hidden_K = hidden_K
        self.inter_K = inter_K

        # Activations are divided by this before the quantized matmul and
        # multiplied back afterwards, keeping half-precision values in range.
        self.scale = 64.0

        self.act_fn = ACT2FN[hidden_act]
        self.merge_ug = merge_ug

    def forward(self, input):
        # Flatten to (tokens, hidden) and run in fp16; down_proj maps back to
        # hidden_size, so the same n is valid for the output reshape.
        n = len(self.SU_ug)
        x = input.view(-1, n).half()
        x = self.compute_ug(x)
        x = self.compute_dp(x)
        return x.view(*input.shape[:-1], n).to(input.dtype)

    def compute_ug(self, x):
        """Up/gate half of the MLP: rotate input, apply quantized up & gate
        projections, then the gated activation act_fn(gate) * up."""
        x = matmul_hadU_cuda(x * self.SU_ug, self.had_left_ug_T, self.hidden_K) / self.scale
        if self.merge_ug:
            # Single fused projection producing [up | gate] concatenated on
            # the last dim.
            x = self.ug_proj(x.half()) * self.Wscale_ug * self.scale
            x_up, x_gate = x.split(self.intermediate_size, dim=-1)
        else:
            # Separate projections; Wscale_ug stores up scales first, then
            # gate scales.
            x_up = self.up_proj(x.half()) * self.Wscale_ug[:self.intermediate_size] * self.scale
            x_gate = self.gate_proj(x.half()) * self.Wscale_ug[self.intermediate_size:] * self.scale
        x = self.act_fn(x_gate) * x_up
        return x

    def compute_dp(self, x):
        """Down-projection half: rotate the intermediate activations and apply
        the quantized down_proj."""
        x = matmul_hadU_cuda(x * self.SU_dp, self.had_left_dp_T, self.inter_K) / self.scale
        x = self.down_proj(x.half()) * self.Wscale_dp * self.scale
        return x

    @staticmethod
    def gen_layer_from_info(config, info_up, info_gate, info_down, merge_ug=False, dummy=False, use_simt=False, use_simt_u=None, use_simt_g=None, use_simt_d=None):
        """Build an IncoherentMLP from saved quantization info dicts.

        The quantizer backend is chosen by substring match on
        info_up["quant_info"]["quantizer_str"]; branch order matters because
        e.g. "tcomb" contains "comb". use_simt_u/g/d override the global
        use_simt per projection when not None. With dummy=True the scale/sign
        buffers keep their default values.
        """
        mlp = IncoherentMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            merge_ug=merge_ug
        )
        if not dummy:
            # up and gate share a single SU (see class docstring), so only
            # info_up's SU is used.
            mlp.SU_ug.data.copy_(info_up["SU"])
            mlp.SU_dp.data.copy_(info_down["SU"])
            mlp.Wscale_ug.data.copy_(torch.cat([info_up["Wscale"], info_gate["Wscale"]], dim=-1))
            mlp.Wscale_dp.data.copy_(info_down["Wscale"])

        use_simt_u = use_simt if use_simt_u is None else use_simt_u
        use_simt_g = use_simt if use_simt_g is None else use_simt_g
        use_simt_d = use_simt if use_simt_d is None else use_simt_d

        if merge_ug:
            if "tcq" in info_up["quant_info"]["quantizer_str"]:
                linear_info_ug = QTIPLinearTCQ.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                mlp.ug_proj = QTIPLinearTCQ.gen_layer_from_info(linear_info_ug)
            elif "vq" in info_up["quant_info"]["quantizer_str"] or "sq" in info_up["quant_info"]["quantizer_str"] or "ldlq" in info_up["quant_info"]["quantizer_str"]:
                if use_simt_u:
                    linear_info_ug = VQLinearPackSIMT.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                    mlp.ug_proj = VQLinearPackSIMT.gen_layer_from_info(linear_info_ug)
                else:
                    linear_info_ug = VQLinearPackTensorCore.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                    mlp.ug_proj = VQLinearPackTensorCore.gen_layer_from_info(linear_info_ug)
            elif "tcomb" in info_up["quant_info"]["quantizer_str"]:
                linear_info_ug = CombtLinearTCQ.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                mlp.ug_proj = CombtLinearTCQ.gen_layer_from_info(linear_info_ug)
            elif "comb" in info_up["quant_info"]["quantizer_str"]:
                linear_info_ug = CombLinearTCQ.merge_infos(info_up['linear_info'], info_gate['linear_info'])
                mlp.ug_proj = CombLinearTCQ.gen_layer_from_info(linear_info_ug)
        else:
            mlp.up_proj = make_linear(info_up, use_simt=use_simt_u)
            mlp.gate_proj = make_linear(info_gate, use_simt=use_simt_g)
        mlp.down_proj = make_linear(info_down, use_simt=use_simt_d)
        return mlp

    @staticmethod
    def gen_layer_from_quantizer_str_and_key(config, quant_dir, quantizer_str_up, quantizer_str_gate, quantizer_str_down, key_up, key_gate, key_down, merge_ug=False, dummy=False, use_simt=False, use_simt_u=None, use_simt_g=None, use_simt_d=None):
        """Load (or fabricate, when dummy=True) the three per-projection quant
        info dicts from disk and delegate to gen_layer_from_info."""
        if not dummy:
            info_up = torch.load(f"{quant_dir}/{quantizer_str_up}/{key_up}.pt")
            info_gate = torch.load(f"{quant_dir}/{quantizer_str_gate}/{key_gate}.pt")
            info_down = torch.load(f"{quant_dir}/{quantizer_str_down}/{key_down}.pt")
        else:
            # Imported lazily so the dummy path does not pull these modules in
            # during normal inference.
            from lib.utils.mem_op import get_dummy_quant_results
            from lib.config import MODEL_KEYS
            model_key = MODEL_KEYS[config._name_or_path]
            info_up = get_dummy_quant_results(model_key, f"mlp.up_proj", quantizer_str_up)
            info_gate = get_dummy_quant_results(model_key, f"mlp.gate_proj", quantizer_str_gate)
            info_down = get_dummy_quant_results(model_key, f"mlp.down_proj", quantizer_str_down)
        return IncoherentMLP.gen_layer_from_info(config, info_up, info_gate, info_down, merge_ug, dummy=dummy, use_simt=use_simt, use_simt_u=use_simt_u, use_simt_g=use_simt_g, use_simt_d=use_simt_d)
395
+
396
+
397
+
398
class IncoherentLinear(nn.Module):
    """
    A linear layer evaluated in an incoherent basis: the input is multiplied
    by a sign vector SU and Hadamard-rotated, the (typically quantized) inner
    linear is applied, and the output is Hadamard-rotated back and multiplied
    by SV. rot_info can disable the left and/or right rotation (used when
    adjacent layers' rotations have been merged away).
    """
    def __init__(
        self,
        in_features,
        out_features,
        hadU,
        hadV,
        bias=False,
        dtype=torch.float16,
        use_linear=True,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype

        # use_linear=False leaves a placeholder; gen_layer_from_info installs
        # the quantized linear afterwards.
        if use_linear:
            self.linear = nn.Linear(in_features, out_features, bias=False, dtype=dtype)
        else:
            self.linear = None

        if bias:
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # Sign vectors applied on the input (SU) and output (SV) sides.
        self.register_buffer("SU", torch.ones(in_features, dtype=self.dtype))
        self.register_buffer("SV", torch.ones(out_features, dtype=self.dtype))

        # hadU / hadV are the Hadamard transform sizes; they may differ from
        # in/out_features (head-wise transforms).
        self.hadU = hadU
        self.hadV = hadV
        had_left, K_left = get_hadK(hadU)
        had_right, K_right = get_hadK(hadV)
        if had_left is not None:
            had_left_T = had_left.T.contiguous().cuda()
        else:
            had_left_T = None
        if had_right is not None:
            had_right = had_right.cuda()
        # Wscale is persistent (part of the checkpoint); the Hadamard factors
        # are deterministic from hadU/hadV and are rebuilt on load.
        self.register_buffer('Wscale', torch.ones(out_features, dtype=self.dtype), persistent=True)
        self.register_buffer('had_right', had_right, persistent=False)
        self.register_buffer('had_left_T', had_left_T, persistent=False)
        self.K_left = K_left
        self.K_right = K_right
        # Activation pre-scale to keep fp16 values in range around the
        # quantized matmul.
        self.scale = 32.0

        self.rot_info = "all"

        self.skip_l = False
        self.skip_r = False

    def apply_rot_info(self):
        """Translate the rot_info string into the skip_l/skip_r flags used by
        forward(). Must be called after rot_info is changed."""
        if self.rot_info == "all":
            self.skip_l = False
            self.skip_r = False
        elif self.rot_info == "skip_l":
            self.skip_l = True
            self.skip_r = False
        elif self.rot_info == "skip_r":
            self.skip_l = False
            self.skip_r = True
        elif self.rot_info == "skip_lr":
            self.skip_l = True
            self.skip_r = True
        else:
            raise ValueError(f"Invalid rot_info: {self.rot_info}")


    def save_info(self, path, quant_info=None):
        """Serialize everything needed to rebuild this layer (see
        gen_layer_from_info) to `path` via torch.save. Requires the inner
        linear to implement _info()."""
        linear_info = self.linear._info()
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "hadU": self.hadU,
            "hadV": self.hadV,
            "dtype": self.dtype,
            "scale": self.scale,
            "Wscale": self.Wscale.detach().cpu(),
            "rot_info": self.rot_info,
            "linear_info": linear_info,
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
            "SU": self.SU.detach().cpu(),
            "SV": self.SV.detach().cpu(),
            "quant_info": quant_info,
        }
        torch.save(info, path)

    def forward(self, input):
        n, m = len(self.SU), len(self.SV)
        x = input.view(-1, n).half()#.to(torch.float32)
        if not self.skip_l:
            # Sign flip + left Hadamard rotation, pre-scaled down.
            x = x * self.SU
            x = matmul_hadU_head_cuda(x, self.had_left_T, self.K_left, self.hadU) / self.scale
        else:
            # Left rotation merged into a neighboring op; only the scale
            # remains.
            # x = x * self.SU
            x = x / self.scale
        x = self.linear(x.half()) * self.Wscale#.float()
        if not self.skip_r:
            # Right Hadamard rotation, then sign flip + undo the pre-scale.
            x = matmul_hadU_head_cuda(x, self.had_right, self.K_right, self.hadV)
            x = x.to(self.SV.device) * (self.SV * self.scale)
        else:
            # x = x.to(self.SV.device) * (self.SV * self.scale)
            x = x * self.scale

        x = x.view(*input.shape[:-1], m).to(input.dtype)
        if self.bias is not None:
            x = x + self.bias
        return x

    @staticmethod
    def gen_layer_from_info(info, merge_layers=False, dummy=False, use_simt=False):
        """Rebuild an IncoherentLinear from a dict produced by save_info.

        The inner quantized linear class is selected by substring match on
        info["quant_info"]["quantizer_str"] (branch order matters: "tcomb"
        contains "comb"). merge_layers=True applies rot_info immediately so
        skipped rotations take effect. dummy=True skips copying the saved
        SU/SV/Wscale values.
        """
        layer = IncoherentLinear(
            in_features=info["in_features"],
            out_features=info["out_features"],
            hadU=info["hadU"] if "hadU" in info else info["in_features"],
            hadV=info["hadV"] if "hadV" in info else info["out_features"],
            bias=info["bias"] is not None,
            dtype=info["dtype"],
            use_linear=False,
        )
        if not dummy:
            if info["bias"] is not None:
                layer.bias.data.copy_(info["bias"])
            layer.SU.data.copy_(info["SU"])
            layer.SV.data.copy_(info["SV"])
            layer.Wscale.data.copy_(info["Wscale"])
        if info["quant_info"] is not None:
            if "tcq" in info["quant_info"]["quantizer_str"]:
                layer.linear = QTIPLinearTCQ.gen_layer_from_info(info["linear_info"])
            elif "sq" in info["quant_info"]["quantizer_str"] or "vq" in info["quant_info"]["quantizer_str"] or "ldlq" in info["quant_info"]["quantizer_str"]:
                if use_simt:
                    layer.linear = VQLinearPackSIMT.gen_layer_from_info(info["linear_info"])
                else:
                    layer.linear = VQLinearPackTensorCore.gen_layer_from_info(info["linear_info"])
            elif "tcomb" in info["quant_info"]["quantizer_str"]:
                layer.linear = CombtLinearTCQ.gen_layer_from_info(info["linear_info"])
            elif "comb" in info["quant_info"]["quantizer_str"]:
                layer.linear = CombLinearTCQ.gen_layer_from_info(info["linear_info"])
        # rot_info may live in quant_info (newer saves) or at top level
        # (older saves); default to "all".
        if "rot_info" in info["quant_info"]:
            layer.rot_info = info["quant_info"]["rot_info"]
        elif "rot_info" in info:
            layer.rot_info = info["rot_info"]
        else:
            layer.rot_info = "all"
        if merge_layers:
            layer.apply_rot_info()
        return layer

    @staticmethod
    def gen_layer_from_quantizer_str_and_key(config, quant_dir, quantizer_str, key, merge_layers=False, dummy=False, use_simt=False):
        """Load the saved info dict for `key` (or fabricate one with dummy=True)
        and delegate to gen_layer_from_info."""
        if not dummy:
            info = torch.load(f"{quant_dir}/{quantizer_str}/{key}.pt")
        else:
            from lib.utils.mem_op import get_dummy_quant_results
            from lib.config import MODEL_KEYS
            model_key = MODEL_KEYS[config._name_or_path]
            # key is "<layer_id>_<layer_key>"; strip the layer id prefix.
            layer_id = key.split("_")[0]
            layer_key = key.replace(f"{layer_id}_", "")
            info = get_dummy_quant_results(model_key, f"{layer_key}", quantizer_str)
        return IncoherentLinear.gen_layer_from_info(info, merge_layers=merge_layers, dummy=dummy, use_simt=use_simt)
559
+
560
+
561
def calc_kurtosis(W):
    """Per-row excess-kurtosis statistic of W over its last dimension.

    Returns E[w^4] - 3 computed in float64. The caller is expected to pass
    rows normalized to unit norm (||W[i]|| = 1) — TODO confirm at call sites.
    """
    # Accumulate moments in double precision to avoid fp16/fp32 rounding.
    w64 = W.double()
    fourth_moment = (w64 ** 4).mean(dim=-1)
    return fourth_moment - 3.0
565
+
566
def calc_skewness(W):
    """Per-row skewness statistic of W over its last dimension.

    Returns E[w^3] computed in float64. The caller is expected to pass rows
    normalized to unit norm (||W[i]|| = 1) — TODO confirm at call sites.
    """
    # Accumulate the third moment in double precision for stability.
    w64 = W.double()
    return (w64 ** 3).mean(dim=-1)
570
+
571
def linear_to_incoherent(linear, hadU, hadV, SU=None, SV=None, lnorm=None, rot_info="all"):
    """Convert an nn.Linear into an IncoherentLinear.

    The weight is conjugated by random sign vectors (SV on rows, SU on
    columns) and two-sided Hadamard transforms of sizes hadV/hadU, then each
    row is normalized by its RMS (stored as Wscale). Optionally folds a
    layernorm weight `lnorm` into the matrix first.

    Returns (inc_linear, kurt_stats) where kurt_stats holds kurtosis/skewness
    summary statistics of the rotated, normalized weight.
    """
    dtype_ = torch.float32
    dtype = linear.weight.data.dtype
    device = linear.weight.device
    inc_linear = IncoherentLinear(linear.in_features, linear.out_features, hadU, hadV, linear.bias is not None, dtype)
    # Random +/-1 sign vectors when not supplied by the caller.
    if SU is None:
        SU = ((torch.randn(linear.in_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device)
    if SV is None:
        SV = ((torch.randn(linear.out_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device)
    if lnorm is not None:
        lnorm = lnorm.to(device).to(dtype_)

    if linear.bias is not None:
        inc_linear.bias.data.copy_(linear.bias)

    W = linear.weight.data.to(dtype_)
    # Fold a preceding layernorm scale into the weight columns (float64 for
    # the matmul to limit rounding).
    Wr = (W.to(torch.float64).to(device) @ torch.diag(lnorm).to(torch.float64)).to(dtype_).to(device) if lnorm is not None else W
    # Two-sided incoherence transform: rows signed by SV + hadV-Hadamard,
    # columns signed by SU + hadU-Hadamard. The head-wise variant is only
    # needed when the transform size differs from the matrix dimension.
    if hadU != linear.in_features or hadV != linear.out_features:
        Wr = matmul_hadUt_head(matmul_hadUt_head(Wr.T.to(device) * SV, hadV).T * SU, hadU)
    else:
        Wr = matmul_hadUt(matmul_hadUt(Wr.T.to(device) * SV).T * SU)
    # Wscale = Wr.square().mean().sqrt()
    # Per-row RMS normalization (float64 accumulation), stored separately so
    # the quantizer sees unit-scale rows.
    Wscale = Wr.to(torch.float64).square().mean(-1).sqrt().view(-1, 1).to(dtype_)

    Wr = Wr / Wscale

    inc_linear.SU.data.copy_(SU.to(inc_linear.SU.dtype))
    # inc_linear.SV.data.copy_((SV * Wscale).to(inc_linear.SV.dtype))
    inc_linear.SV.data.copy_((SV).to(inc_linear.SV.dtype))
    inc_linear.Wscale.data.copy_(Wscale.view(-1))
    inc_linear.linear.weight.data.copy_(Wr.to(inc_linear.linear.weight.dtype))
    inc_linear = inc_linear.to(dtype).to(device)
    inc_linear.rot_info = rot_info
    inc_linear.apply_rot_info()


    # anal weight
    # Diagnostic statistics of the incoherence-processed weight; these do not
    # affect the returned layer.
    kurt = calc_kurtosis(inc_linear.linear.weight.data)
    skew = calc_skewness(inc_linear.linear.weight.data)
    # print(kurt.pow(2).mean(), kurt.mean(), kurt.std(), kurt.max(), kurt.min())
    # print pretty
    print(f"E[kurt^2]: {kurt.pow(2).mean():.4f}, E[kurt]: {kurt.mean():.4f}, std[kurt]: {kurt.std():.4f}, max[kurt]: {kurt.max():.4f}, min[kurt]: {kurt.min():.4f}")
    print(f"E[skew^2]: {skew.pow(2).mean():.4f}, E[skew]: {skew.mean():.4f}, std[skew]: {skew.std():.4f}, max[skew]: {skew.max():.4f}, min[skew]: {skew.min():.4f}")
    kurt_stats = {
        "kurt_pow2_mean": kurt.pow(2).mean(),
        "kurt_mean": kurt.mean(),
        "kurt_std": kurt.std(),
        "kurt_max": kurt.max(),
        "kurt_min": kurt.min(),
        "skew_pow2_mean": skew.pow(2).mean(),
        "skew_mean": skew.mean(),
        "skew_std": skew.std(),
        "skew_max": skew.max(),
        "skew_min": skew.min(),
    }
    return inc_linear, kurt_stats
627
+
628
+ if __name__ == "__main__":
629
+ linear = nn.Linear(4096, 4096, bias=True, dtype=torch.float16).cuda()
630
+ # linear.weight.data = linear.weight.data * 5 + 4
631
+ inc_linear = linear_to_incoherent(linear)
632
+
633
+ ran = torch.randn(4096, 4096, dtype=torch.float16).cuda()
634
+ orig = linear(ran)
635
+ inc = inc_linear(ran)
636
+
637
+ print((orig - inc).pow(2).mean() / orig.pow(2).mean())
638
+
639
+
lib/linear/quantized_linear.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import time
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from lib.codebook import bitshift
8
+ from lib.utils import (clean, dtype_from_str, get_hadK, has_kernel,
9
+ matmul_hadU_cuda)
10
+
11
+
12
class QuantizedLinear(nn.Module):
    """Trellis-coded (bitshift/QTIP-style) quantized linear layer.

    Weights are stored as packed int16 trellis codes plus an optional tunable
    LUT; decoding and the matmul are delegated to bitshift.BitshiftLinear,
    which is built lazily on the first forward call. Supports three modes:
    'eval', 'train-recons' (trellis unpacked for training), and 'train-fixW'
    (reconstructed weight cached, Hadamard factors released).
    """

    def __init__(
        self,
        in_features,
        out_features,
        td_x,
        td_y,
        L, # trellis window
        K, # bpw
        V, # vq dim
        tlut_bits, # tunable LUT bits
        decode_mode,
        bias=False,
        dtype=torch.float16,
        mode='eval',
        use_prev_kernel=True,
        grad_ckpt=False,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        # td_x/td_y: tile dimensions of the trellis blocks.
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.K = K
        self.V = V
        self.tlut_bits = tlut_bits
        self.decode_mode = decode_mode
        self.register_buffer('rcp', torch.tensor(0))
        # TP rank, not used unless rcp != 0
        self.register_buffer('tp_rank', torch.tensor(8))
        self.dtype = dtype
        # packed into int16
        # One row of K*(td_x*td_y) bits per (td_x, td_y) weight tile.
        self.register_buffer(
            'trellis',
            torch.zeros((out_features // td_x) * (in_features // td_y),
                        math.ceil((td_x * td_y) * K / 16),
                        dtype=torch.int16))

        # Tunable LUT is only materialized for the LUT-based decode modes.
        if decode_mode in ['lut', 'quantlut', 'quantlut_sym']:
            self.tlut = nn.Parameter(torch.zeros(2**tlut_bits,
                                                 V,
                                                 dtype=torch.float16),
                                     requires_grad=False)
        else:
            self.tlut = None

        if bias:
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

        # Incoherence sign vectors; SV is kept in fp32.
        self.register_buffer("SU", torch.ones(in_features, dtype=self.dtype))
        self.register_buffer("SV", torch.ones(out_features,
                                              dtype=torch.float32))

        self.built_codebook_class = False
        self.built_graph = False

        had_left, K_left = get_hadK(in_features)
        had_right, K_right = get_hadK(out_features)
        self.register_buffer('had_left', had_left, persistent=False)
        self.register_buffer('had_right', had_right, persistent=False)
        self.K_left = K_left
        self.K_right = K_right
        self.mode = mode
        self.use_prev_kernel = use_prev_kernel
        self.grad_ckpt = grad_ckpt
        # Whether a fused decode kernel exists for this configuration.
        self.has_kernel = has_kernel(decode_mode, L, K, V, tlut_bits, td_x,
                                     td_y)

    def forward(self, input):
        if self.grad_ckpt:
            return self.ckpt_forward(input)
        return self.no_ckpt_forward(input)

    def ckpt_forward(self, input):
        # Gradient checkpointing wrapper; recomputes the forward in backward
        # to save activation memory.
        return torch.utils.checkpoint.checkpoint(self.no_ckpt_forward,
                                                 input,
                                                 use_reentrant=True)

    def no_ckpt_forward(self, input):
        # Lazy one-time setup: build the codebook decoder and apply the
        # mode-specific trellis/weight preparation.
        if not self.built_codebook_class:
            self.codebook_class = bitshift.BitshiftLinear(
                self.td_x,
                self.td_y,
                self.L,
                self.K,
                self.V,
                self.tlut_bits,
                self.decode_mode,
                dtype=self.dtype,
                tlut=self.tlut,
                has_kernel=self.has_kernel)

            # Replace the rcp buffer with a plain Python int so it is no
            # longer tracked by the state_dict after this point.
            rcp = self.rcp.item()
            del self.rcp
            self.rcp = rcp

            if self.mode == 'eval':
                pass
            elif self.mode == 'train-recons':
                # Without a fused kernel the trellis must be unpacked for
                # training; keep the packed copy on CPU.
                if not self.has_kernel:
                    self.packed_trellis = self.trellis.cpu()
                    unpacked_trellis = self.codebook_class.cb.unpack_trellis(
                        self.trellis, self.td_x * self.td_y)
                    self.trellis = unpacked_trellis
                    clean()
            elif self.mode == 'train-fixW':
                # Cache the reconstructed weight once, then drop the Hadamard
                # factors (set to None so later code can still reference them).
                self.codebook_class.cache_hatW(self.trellis, self.had_left,
                                               self.had_right, self.K_left,
                                               self.K_right, len(self.SV),
                                               len(self.SU), self.rcp,
                                               self.tp_rank)
                self.trellis = self.trellis.cpu()
                del self.had_left, self.had_right, self.K_left, self.K_right
                clean()
                self.had_left = None
                self.had_right = None
                self.K_left = None
                self.K_right = None
            else:
                raise Exception

            self.built_codebook_class = True

        # `+ 0` forces a copy of the kernel output — presumably to detach it
        # from kernel-owned memory; TODO confirm.
        result = self.codebook_class(input,
                                     self.trellis,
                                     self.SU,
                                     self.SV,
                                     self.had_left,
                                     self.had_right,
                                     self.K_left,
                                     self.K_right,
                                     self.rcp,
                                     self.tp_rank,
                                     mode=self.mode,
                                     use_prev_kernel=self.use_prev_kernel) + 0
        if self.bias is not None:
            return result + self.bias
        return result
+ return result
lib/linear/rotation.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib.utils import matmul_hadU_cuda, matmul_hadUt_cuda, matmul_hadUt_head, matmul_hadU_head, get_hadK
2
+ import torch.nn as nn
3
+
4
class RotateWeights(nn.Module):
    """Helper that applies two-sided Hadamard (de-)rotations to a weight
    matrix using transform sizes had_dim_U / had_dim_V.

    NOTE(review): SU and SV are stored but never used by apply_weights, and
    apply_weights transforms `weights.T` with the U-side factor first —
    verify this ordering against the callers before relying on it.
    """
    def __init__(self, had_dim_U, had_dim_V, SU=None, SV=None):
        super().__init__()
        self.had_dim_U = had_dim_U
        self.had_dim_V = had_dim_V
        # Optional sign vectors; currently unused in apply_weights.
        self.SU = SU
        self.SV = SV

        # Hadamard factor and block count for each side (get_hadK).
        self.had_left_U, self.K_left_U = get_hadK(had_dim_U)
        self.had_left_V, self.K_left_V = get_hadK(had_dim_V)

    def apply_weights(self, weights):
        # Apply the U-side inverse-Hadamard to the transposed weight, then the
        # V-side inverse-Hadamard to the result.
        return matmul_hadUt_head(matmul_hadUt_head(weights.T, self.had_left_U, self.K_left_U), self.had_left_V, self.K_left_V)
lib/linear/tcq_linear.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
class QTIPLinearTCQ(nn.Module):
    """Trellis-coded quantized linear layer backed by custom CUDA kernels
    (torch.ops.ours_lib).

    Weights are stored as int16-packed trellis codes plus a per-layer LUT.
    Small batches (<= 8 rows) use a fused decompress+GEMM kernel; larger
    batches dequantize the full weight once and fall back to a dense matmul.
    """
    def __init__(
        self,
        in_features,
        out_features,
        td_x,
        td_y,
        L, # trellis window
        KV, # bpw
        V, # vq dim
        tlut_bits,
        bias=False,
        dtype=torch.float16,
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        # td_x/td_y: tile dimensions of the trellis blocks.
        self.td_x = td_x
        self.td_y = td_y
        self.L = L
        self.KV = KV
        self.V = V
        self.tlut_bits = tlut_bits
        self.dtype = dtype
        # packed into int16
        # One row of KV*(td_x*td_y)/V bits per (td_x, td_y) weight tile.
        self.register_buffer(
            'trellis',
            torch.zeros((out_features // td_x) * (in_features // td_y),
                        math.ceil((td_x * td_y) * KV / 16 / V),
                        dtype=torch.int16))

        # Lookup table mapping each of 2**tlut_bits codes to a V-dim vector.
        self.tlut = nn.Parameter(torch.zeros(2**tlut_bits,
                                             V,
                                             dtype=torch.float16),
                                 requires_grad=False)

        if bias:
            self.register_buffer('bias', torch.ones(out_features))
        else:
            self.bias = None

    def _info(self):
        """Return a plain-dict snapshot (CPU tensors) sufficient to rebuild
        this layer via gen_layer_from_info."""
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "td_x": self.td_x,
            "td_y": self.td_y,
            "L": self.L,
            "KV": self.KV,
            "V": self.V,
            'tlut_bits': self.tlut_bits,
            "dtype": self.dtype,
            "trellis": self.trellis.detach().cpu(),
            "tlut": self.tlut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }
        return info

    def forward(self, inp, **kwargs):
        x = inp.view(-1, self.in_features)#.to(torch.float32)
        bs = x.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            # Fused decompress+GEMM kernel specialized per (m, bs, k, lut
            # bits, bpw) combination.
            wrapper = getattr(
                torch.ops.ours_lib,
                f"decompress_gemm_tcq_{m}_{bs}_{k}_{self.tlut_bits}_{self.KV}")

            x = wrapper(self.trellis, x, self.tlut)

        else:
            # Large batch: dequantize the whole weight once and use a dense
            # matmul.
            wrapper = getattr(
                torch.ops.ours_lib,
                f"decompress_tcq_{self.tlut_bits}_{self.KV}"
            )
            # dq = wrapper(self.trellis, self.tlut).to(x.dtype)
            # x = x @ dq.T
            with torch.no_grad():
                dq = wrapper(self.trellis, self.tlut, m, k) #.to(x.dtype)
            x = (x.to(dq.dtype) @ dq.T)#.to(x.dtype)
        return x.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a QTIPLinearTCQ from a dict produced by _info()."""
        layer = QTIPLinearTCQ(info["in_features"], info["out_features"], info["td_x"], info["td_y"], info["L"], info["KV"], info["V"], info["tlut_bits"], info["bias"] is not None, info["dtype"])
        layer.trellis.data.copy_(info["trellis"])
        layer.tlut.data.copy_(info["tlut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    @staticmethod
    def merge_infos(info1, info2):
        """Fuse two _info() dicts into one describing a layer whose output is
        the concatenation of both (used to merge e.g. up+gate projections).
        Requires identical quantization configs and a shared LUT; biases are
        not supported."""
        assert info1["in_features"] == info2["in_features"]
        assert info1["td_x"] == info2["td_x"]
        assert info1["td_y"] == info2["td_y"]
        assert info1["L"] == info2["L"]
        assert info1["KV"] == info2["KV"]
        assert info1["V"] == info2["V"]
        assert info1["tlut_bits"] == info2["tlut_bits"]
        if not torch.allclose(info1["tlut"], info2["tlut"], atol=1e-4):
            print("warning: tlut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        info = {}
        info["in_features"] = info1["in_features"]
        info["out_features"] = info1["out_features"] + info2["out_features"]
        info["td_x"] = info1["td_x"]
        info["td_y"] = info1["td_y"]
        info["L"] = info1["L"]
        info["KV"] = info1["KV"]
        info["V"] = info1["V"]
        info["tlut_bits"] = info1["tlut_bits"]
        info["bias"] = None
        info["dtype"] = info1["dtype"]
        # Trellis rows are tile-major over output rows, so concatenating on
        # dim 0 stacks the two weight matrices vertically.
        info["trellis"] = torch.cat([info1["trellis"], info2["trellis"]], dim=0)
        info["tlut"] = info1["tlut"]
        return info
lib/linear/vq_linear.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class VQLinearPackTensorCore(nn.Module):
    """Vector-quantized linear layer using tensor-core CUDA kernels
    (torch.ops.ours_lib).

    Weights are stored as int32-packed LUT indices (`qweight`) plus a
    codebook (`lut`) of 2**lut_bits vectors of size vec_sz. Small batches
    (<= 8 rows) use a fused decompress+GEMM kernel; larger batches dequantize
    once and use a dense matmul.
    """
    def __init__(self, in_features, out_features, lut_bits, vec_sz=2, bias=False, dtype=torch.half):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.lut_bits = lut_bits
        self.dtype = dtype
        self.vec_sz = vec_sz

        # lut_bits bits per vec_sz weights, packed into 32-bit words.
        # Initialized with random data; real codes are copied in by
        # gen_layer_from_info.
        self.register_buffer(
            'qweight',
            torch.randint(0, 4, (out_features, lut_bits*in_features // 32 // vec_sz), dtype=torch.int32, device='cuda')
        )

        self.register_buffer(
            'lut',
            torch.randn((2 ** lut_bits, vec_sz), dtype=self.dtype, device='cuda')
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.randn((out_features,), dtype=self.dtype, device='cuda')
            )
        else:
            self.bias = None

        # Kernel-name suffix: vector quantization for vec_sz > 1, otherwise
        # scalar quantization ("sq_dup" duplicated variant for <= 4 bits).
        self.vq_type = f"vq{self.vec_sz}" if self.vec_sz > 1 else "sq_dup" if lut_bits <= 4 else "sq"

    def _info(self):
        """Return a plain-dict snapshot (CPU tensors) sufficient to rebuild
        this layer via gen_layer_from_info."""
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "lut_bits": self.lut_bits,
            "dtype": self.dtype,
            "vec_sz": self.vec_sz,
            "qweight": self.qweight.detach().cpu(),
            "lut": self.lut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }
        return info

    def forward(self, inp, **kwargs):
        x = inp.view(-1, self.in_features)
        bs = x.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            # Fused decompress+GEMM kernel specialized per shape/config.
            wrapper = getattr(
                torch.ops.ours_lib,
                f"decompress_gemm_{m}_{bs}_{k}_{self.lut_bits}_{self.vq_type}"
            )

            x = wrapper(self.qweight, x, self.lut)
        else:
            # Large batch: dequantize the whole weight once, then dense
            # matmul.
            wrapper = getattr(
                torch.ops.ours_lib,
                f"decompress_{self.lut_bits}_{self.vq_type}"
            )
            with torch.no_grad():
                dq = wrapper(self.qweight, self.lut, m, k)
            x = (x.to(dq.dtype) @ dq.T)

        return x.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a VQLinearPackTensorCore from a dict produced by _info()."""
        layer = VQLinearPackTensorCore(info["in_features"], info["out_features"], info["lut_bits"], info["vec_sz"], info["bias"] is not None, info["dtype"])
        layer.qweight.data.copy_(info["qweight"])
        layer.lut.data.copy_(info["lut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    @staticmethod
    def merge_infos(info1, info2):
        """Fuse two _info() dicts into one describing a layer whose output is
        the concatenation of both. Requires identical quantization configs
        and a shared LUT; biases are not supported."""
        assert info1["in_features"] == info2["in_features"]
        assert info1["lut_bits"] == info2["lut_bits"]
        assert info1["vec_sz"] == info2["vec_sz"]
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        if not torch.allclose(info1["lut"], info2["lut"], atol=1e-4):
            print("warning: lut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        info = {}
        info["in_features"] = info1["in_features"]
        info["out_features"] = info1["out_features"] + info2["out_features"]
        info["lut_bits"] = info1["lut_bits"]
        info["vec_sz"] = info1["vec_sz"]
        info["bias"] = None
        info["dtype"] = info1["dtype"]
        # qweight rows correspond to output rows, so dim-0 concat stacks the
        # two weight matrices vertically.
        info["qweight"] = torch.cat([info1["qweight"], info2["qweight"]], dim=0)
        info["lut"] = info1["lut"]
        return info
99
class VQLinearPackSIMT(nn.Module):
    """Vector-quantized linear layer using SIMT (non-tensor-core) CUDA kernels
    (torch.ops.ours_lib).

    Same storage scheme as VQLinearPackTensorCore (int32-packed LUT indices +
    codebook), but the qweight bit layout differs: gen_layer_from_info
    converts from the default tensor-core packing to the SIMT packing for
    vec_sz <= 2. Small batches (<= 8 rows) use a fused GEMM kernel, larger
    batches dequantize then dense-matmul; scalar (vec_sz == 1) and vector
    cases dispatch to different kernels.
    """
    def __init__(self, in_features, out_features, lut_bits, vec_sz=1, bias=False, dtype=torch.half):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.lut_bits = lut_bits
        self.dtype = dtype
        self.vec_sz = vec_sz

        # lut_bits bits per vec_sz weights, packed into 32-bit words.
        # Initialized with random data; real codes are copied in by
        # gen_layer_from_info.
        self.register_buffer(
            'qweight',
            torch.randint(0, 4, (out_features, lut_bits*in_features // 32 // vec_sz), dtype=torch.int32, device='cuda')
        )

        self.register_buffer(
            'lut',
            torch.randn((2 ** lut_bits, vec_sz), dtype=self.dtype, device='cuda')
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.randn((out_features,), dtype=self.dtype, device='cuda')
            )
        else:
            self.bias = None

    def _info(self):
        """Return a plain-dict snapshot (CPU tensors) sufficient to rebuild
        this layer via gen_layer_from_info."""
        info = {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "lut_bits": self.lut_bits,
            "dtype": self.dtype,
            "vec_sz": self.vec_sz,
            "qweight": self.qweight.detach().cpu(),
            "lut": self.lut.detach().cpu().half(),
            "bias": self.bias.detach().cpu() if self.bias is not None else None,
        }
        return info

    def forward(self, inp, **kwargs):
        # SIMT GEMM kernels expect a (bs, 1, k) input layout.
        x = inp.view(-1, 1, self.in_features)
        bs = x.shape[0]
        m, k = self.out_features, self.in_features
        if bs <= 8:
            if self.vec_sz == 1:
                # Scalar-quantized kernel takes lut_bits as a runtime arg.
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"sq_pack_gemm_simt"
                )
                x = wrapper(x, self.qweight, self.lut, self.lut_bits)
            else:
                # Vector-quantized kernel specialized per (bs, vec_sz,
                # lut_bits).
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"vq_pack_gemm_simt_{bs}_{self.vec_sz}_{self.lut_bits}"
                )
                x = wrapper(x, self.qweight, self.lut)
        else:
            # Large batch: dequantize the whole weight once, then dense
            # matmul.
            if self.vec_sz == 1:
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"sq_pack_dequant_simt"
                )
                with torch.no_grad():
                    dq = wrapper(self.qweight, self.lut, self.lut_bits, m, k)
            else:
                wrapper = getattr(
                    torch.ops.ours_lib,
                    f"vq_pack_dequant_simt_{self.vec_sz}_{self.lut_bits}"
                )
                with torch.no_grad():
                    dq = wrapper(self.qweight, self.lut, m, k)
            x = (x.to(dq.dtype) @ dq.T)
        return x.view(*inp.shape[:-1], m).to(inp.dtype)

    @staticmethod
    def gen_layer_from_info(info):
        """Rebuild a VQLinearPackSIMT from a dict produced by _info(),
        re-packing the qweight layout when necessary."""
        layer = VQLinearPackSIMT(info["in_features"], info["out_features"], info["lut_bits"], info["vec_sz"], info["bias"] is not None, info["dtype"])
        if info["vec_sz"] <= 2:
            from lib.quantizer.quant_op import convert_tensor_core_to_simt
            # qweight is stored in tensor core format in default.
            # we should convert it to simt format.
            converted_qweight = convert_tensor_core_to_simt(info["qweight"], info["out_features"], info["in_features"], info["vec_sz"], info["lut_bits"], code_n=info["lut_bits"])
            layer.qweight.data.copy_(converted_qweight)
        else:
            layer.qweight.data.copy_(info["qweight"])
        layer.lut.data.copy_(info["lut"])
        if info["bias"] is not None:
            layer.bias.data.copy_(info["bias"])
        return layer

    @staticmethod
    def merge_infos(info1, info2):
        """Fuse two _info() dicts into one describing a layer whose output is
        the concatenation of both. Requires identical quantization configs
        and a shared LUT; biases are not supported."""
        assert info1["in_features"] == info2["in_features"]
        assert info1["lut_bits"] == info2["lut_bits"]
        assert info1["vec_sz"] == info2["vec_sz"]
        assert info1["bias"] is None and info2["bias"] is None
        assert info1["dtype"] == info2["dtype"]
        if not torch.allclose(info1["lut"], info2["lut"], atol=1e-4):
            print("warning: lut is not close. it is unexpected behavior if you do not use dummy quantizers.")
        info = {}
        info["in_features"] = info1["in_features"]
        info["out_features"] = info1["out_features"] + info2["out_features"]
        info["lut_bits"] = info1["lut_bits"]
        info["vec_sz"] = info1["vec_sz"]
        info["bias"] = None
        info["dtype"] = info1["dtype"]
        # qweight rows correspond to output rows, so dim-0 concat stacks the
        # two weight matrices vertically.
        info["qweight"] = torch.cat([info1["qweight"], info2["qweight"]], dim=0)
        info["lut"] = info1["lut"]
        return info
lib/quantizer/__pycache__/comb_quant.cpython-311.pyc ADDED
Binary file (15 kB). View file
 
lib/quantizer/__pycache__/nuq_op.cpython-311.pyc ADDED
Binary file (23.7 kB). View file
 
lib/quantizer/__pycache__/pack_op.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.1.nbc ADDED
Binary file (43.8 kB). View file
 
lib/quantizer/__pycache__/pack_op.general_pack_32-88.py311.nbi ADDED
Binary file (1.59 kB). View file
 
lib/quantizer/__pycache__/pack_op.pack_32-242.py311.1.nbc ADDED
Binary file (39.4 kB). View file
 
lib/quantizer/__pycache__/pack_op.pack_32-242.py311.nbi ADDED
Binary file (1.62 kB). View file
 
lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.1.nbc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b3053fa289922289be50c3c5f21b9101f8426050f437286e34141275fa858e8
3
+ size 116854
lib/quantizer/__pycache__/pack_op.pack_codes_32-186.py311.nbi ADDED
Binary file (1.59 kB). View file
 
lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.1.nbc ADDED
Binary file (67.1 kB). View file
 
lib/quantizer/__pycache__/pack_op.pack_for_sq_pack_kernel-287.py311.nbi ADDED
Binary file (1.71 kB). View file
 
lib/quantizer/__pycache__/quant_op.cpython-311.pyc ADDED
Binary file (23.4 kB). View file
 
lib/quantizer/__pycache__/tcq_quant.cpython-311.pyc ADDED
Binary file (13.2 kB). View file
 
lib/quantizer/__pycache__/vq_quant.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
lib/quantizer/__pycache__/vq_quant_ldlq.cpython-311.pyc ADDED
Binary file (6.71 kB). View file
 
lib/quantizer/comb_quant.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from lib.quantizer.quant_op import load_hessian, load_group_hessian
3
+ from lib.quantizer.tcq_quant import qtip_quantize_mat, linear_to_incoherent_for_tcq, Args
4
+ from lib.linear import CombLinearTCQ, CombtLinearTCQ
5
+ from lib.codebook.bitshift import bitshift_codebook
6
+ from lib.utils import clean
7
+ import time
8
+ from lib.utils import block_LDL
9
+ from lib.algo.ldlq import LDLQ_combt
10
+
11
def pack_trellis(Qidxs, td_x, td_y, cb, m, n, KV, V):
    """Pack trellis code indices into the interleaved int16 layout the kernel expects.

    The indices are first tiled into (td_x, td_y) blocks and packed by the
    codebook, then the resulting bit stream is re-shuffled at nibble (4-bit)
    granularity into a 2x2-tile-interleaved order, and finally re-fused into
    16-bit words on the GPU.

    Args (roles inferred from the reshapes below — confirm against callers):
        Qidxs: index tensor holding m * n // V entries.
        td_x, td_y: trellis tile dimensions.
        cb: codebook object providing pack_trellis().
        m, n: dimensions of the weight slab covered by these indices.
        KV: packed words per tile group (bits-per-weight dependent).
        V: vector size per trellis index.

    Returns:
        CUDA int16 tensor with the same shape as cb.pack_trellis's output.
    """
    Qidxs = Qidxs.cpu()
    # Reorder into (m/td_x) x (n/td_y) tiles before codebook bit-packing.
    packed = cb.pack_trellis(
        Qidxs.reshape(m // td_x, td_x, n // td_y,
                      td_y // V).transpose(1, 2).reshape(
                          -1, td_x * td_y // V))

    # Split every 16-bit word into 4 nibbles (low nibble of each byte first,
    # then the whole quartet flipped).
    packed_8 = packed.view(torch.uint8).view(-1, 2)
    packed_4 = torch.cat([packed_8.unsqueeze(-1) & (2 ** 4 - 1), (packed_8.unsqueeze(-1) & (2 ** 8 - 2 ** 4)) >> 4], dim=-1).view(-1, 4).flip(
        (-1, ))

    # Interleave nibbles across 2x2 groups of 16x16 tiles — presumably to match
    # the CUDA kernel's access pattern (TODO confirm against the kernel).
    packed_4 = packed_4.reshape(m // 16 // 2, 2, n // 16 // 2, 2, 16 * 16 // 8,
                                KV).permute(0, 2, 4, 3, 1, 5).flip(
                                    (-1, )).contiguous().flatten()
    # Re-fuse nibble pairs into bytes (low nibble + high nibble * 16), then
    # reinterpret as int16 in the original packed shape.
    packed_8 = torch.sum(packed_4.view(-1, 2) * torch.Tensor([[1, 2 ** 4]]).to(torch.uint8), dim=-1).to(torch.uint8).contiguous()
    packed = packed_8.view(torch.int16).reshape(packed.shape).cuda()
    return packed
28
+
29
def combt_quantize_mat(Wr, HRr, Wscale, cb1, cb2, td_x=16, td_y=16, KV=(4,5), V=2, use_hess=True):
    """LDLQ-quantize a weight matrix with two trellis codebooks split along the input dim.

    Computes a block-LDL factor per Hessian group (or of the identity when
    use_hess=False), runs LDLQ_combt on each horizontal slab of rows, and
    packs the first / second half of the input dimension with cb1 / cb2.

    Args:
        Wr: (m, n) normalized weight matrix.
        HRr: (n, n, gs) grouped Hessian stack.
        Wscale: per-row scale undone before error measurement.
        cb1, cb2: trellis codebooks for the two input halves.
        KV: per-codebook packed word counts; V: vector size per index.

    Returns:
        (packed1, packed2, hatWr, quant_info) — two packed trellis buffers,
        the rescaled reconstruction, and a stats dict with the relative error.
    """
    (m, n) = Wr.shape
    Wr = Wr.to(torch.float64)
    HRr_orig = HRr.clone()  # NOTE(review): unused after this clone
    gs = HRr.shape[-1]  # number of Hessian groups
    LRrs = []
    diag = torch.arange(n, device=HRr.device)
    if not use_hess:
        # Identity Hessian: no error-feedback weighting.
        # NOTE(review): only one L is appended here, but the loop below indexes
        # LRrs[i] for i in range(gs) — confirm gs == 1 whenever use_hess=False.
        eye = torch.eye(n, device=Wr.device, dtype=torch.float64)
        LRr, D = block_LDL(eye, td_y)
        LRr[diag, diag] = 0
        LRrs.append(LRr)
    else:
        for i in range(gs):
            LRr, D = block_LDL(HRr[:,:,i], td_y)
            LRr[diag, diag] = 0  # zero the unit diagonal; LDLQ uses the strictly-lower part
            LRrs.append(LRr)

    args = Args(td_x, td_y, V)

    Qidxs_list = []
    hatWr_list = []
    for i in range(gs):
        # Each Hessian group quantizes its own horizontal slab of m // gs rows.
        cur_Wr = Wr[m // gs * i:m // gs * (i+1)]
        hatWr, Qidxs = LDLQ_combt(cur_Wr, LRrs[i], cb1.cuda(), cb2.cuda(), args, for_kernel=True)
        hatWr_list.append(hatWr)
        Qidxs_list.append(Qidxs)
        torch._dynamo.reset()  # drop torch.compile state between groups

    hatWr = torch.cat(hatWr_list, dim=0)
    Qidxs = torch.cat(Qidxs_list, dim=0)
    assert hatWr.shape == Wr.shape, f"hatWr.shape {hatWr.shape} != Wr.shape {Wr.shape}"

    # First half of the input dim packs with cb1/KV[0], second half with cb2/KV[1].
    packed1 = pack_trellis(Qidxs[:, :n//2//V].contiguous(), td_x, td_y, cb1, m, n//2, KV[0], V)
    packed2 = pack_trellis(Qidxs[:, n//2//V:].contiguous(), td_x, td_y, cb2, m, n//2, KV[1], V)

    # Undo the per-row normalization before measuring reconstruction error.
    Wr *= Wscale.reshape(-1, 1)
    hatWr *= Wscale.reshape(-1, 1)

    orig_err = (Wr - hatWr).pow(2).mean()
    err = orig_err / Wr.pow(2).mean()  # relative MSE
    print(
        f'err {err.item()} orig_err {orig_err.item()}'
    )
    quant_info = {
        "quantizer": "combt_ldlq",
        "td_x": td_x,
        "td_y": td_y,
        "KV": KV,
        "V": V,
        "tlut_bits": cb1.tlut_bits,
        "use_hess": use_hess,
        "orig_err": orig_err.item(),
        "err": err.item(),
    }
    return packed1, packed2, hatWr, quant_info
85
+
86
def inc_linear_to_inc_combt_linear(inc_linear, HRr, cb1, cb2, td_x=16, td_y=16, in_part=(2048, 2048), KV=(3, 4), V=2, scale_override=0.9, use_hess=True):
    """Swap an incoherent layer's inner linear for a CombtLinearTCQ (input-split TCQ).

    The weight is pre-shrunk by scale_override (Wscale is inflated to
    compensate), quantized as one matrix whose input dimension is split across
    cb1/cb2 by combt_quantize_mat, and the packed trellises plus the shared
    tlut are copied into a freshly constructed CombtLinearTCQ that replaces
    inc_linear.linear in place.

    Returns:
        (inc_linear, quant_info) — the mutated wrapper and the stats dict
        produced by combt_quantize_mat.
    """
    Wr = (inc_linear.linear.weight.data * scale_override).to(HRr.dtype)
    Wscale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(Wscale)
    assert in_part[0] + in_part[1] == Wr.shape[1], "in_part is not correct"
    assert torch.allclose(cb1.tlut, cb2.tlut), "cb1 and cb2 must have the same tlut"

    packed1, packed2, hatWr, quant_info = combt_quantize_mat(Wr, HRr, Wscale, cb1, cb2, td_x=td_x, td_y=td_y, KV=KV, V=V, use_hess=use_hess)
    torch._dynamo.reset()  # clear any torch.compile state left by quantization
    out_features, in_features = Wr.shape
    comb_linear = CombtLinearTCQ(
        in_features,
        out_features,
        td_x=td_x,
        td_y=td_y,
        in_part=in_part,
        L=16,
        KV=KV,
        V=V,
        tlut_bits=cb1.tlut_bits,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )

    comb_linear.trellis1.data.copy_(packed1)
    comb_linear.trellis2.data.copy_(packed2)
    comb_linear.tlut.data.copy_(cb1.tlut)
    inc_linear.linear = comb_linear
    return inc_linear, quant_info
115
+
116
+
117
def inc_linear_to_inc_comb_linear(inc_linear, HRr, cb1, cb2, td_x=16, td_y=16, out_part=(2048, 2048), KV=(3, 4), V=2, scale_override=0.9, use_hess=True):
    """Quantize an incoherent layer into a CombLinearTCQ split along OUTPUT rows.

    The first out_part[0] rows are quantized independently with cb1/KV[0] and
    the remaining rows with cb2/KV[1]; both use the same ungrouped Hessian.
    The packed trellises and shared tlut are copied into a new CombLinearTCQ
    that replaces inc_linear.linear in place.

    Returns:
        (inc_linear, quant_info) — combined stats plus the per-part
        quant_info1/quant_info2 dicts from qtip_quantize_mat.
    """
    # scale_override shrinks weights before rounding; Wscale is inflated to compensate.
    Wr = (inc_linear.linear.weight.data * scale_override).to(HRr.dtype)
    Wscale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(Wscale)
    assert out_part[0] + out_part[1] == Wr.shape[0], "out_part is not correct"
    assert len(HRr.shape) == 3 and HRr.shape[0] == HRr.shape[1] and HRr.shape[-1] == 1, f"support only none-grouped hessian but shape: {HRr.shape}"

    packed1, hatWr1, quant_info1 = qtip_quantize_mat(Wr[:out_part[0]], HRr, Wscale[:out_part[0]], cb1, td_x=td_x, td_y=td_y, KV=KV[0], V=V, use_hess=use_hess)
    torch._dynamo.reset()  # clear torch.compile state between the two runs
    packed2, hatWr2, quant_info2 = qtip_quantize_mat(Wr[out_part[0]:], HRr, Wscale[out_part[0]:], cb2, td_x=td_x, td_y=td_y, KV=KV[1], V=V, use_hess=use_hess)
    torch._dynamo.reset()
    out_features, in_features = Wr.shape
    comb_linear = CombLinearTCQ(
        in_features,
        out_features,
        td_x=td_x,
        td_y=td_y,
        out_part=out_part,
        L=16,
        KV=KV,
        V=V,
        tlut_bits=cb1.tlut_bits,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )

    comb_linear.trellis1.data.copy_(packed1)
    comb_linear.trellis2.data.copy_(packed2)
    comb_linear.tlut.data.copy_(cb1.tlut)

    # NOTE(review): unlike combt_quantize_mat, Wr is not rescaled by Wscale
    # before this error computation — confirm qtip_quantize_mat returns hatWr
    # in Wr's (normalized) scale.
    hatWr = torch.cat([hatWr1, hatWr2], dim=0).to(HRr.dtype)
    orig_err = (Wr - hatWr).pow(2).mean()
    err = orig_err / Wr.pow(2).mean()

    quant_info = {
        "quantizer": "comb_ldlq",
        "td_x": td_x,
        "td_y": td_y,
        "KV": KV,
        "V": V,
        "use_hess": use_hess,
        "orig_err": orig_err.item(),
        "err": err.item(),
        "quant_info1": quant_info1,
        "quant_info2": quant_info2,
    }

    inc_linear.linear = comb_linear
    return inc_linear, quant_info
166
+
167
def linear_to_comb_linear(target_layer, hess_path, cb1, cb2, scale_override=0.9, out_part=(2048, 2048), KV=[3, 4], V=2, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    """End-to-end pipeline: nn.Linear -> incoherent rotation -> output-split TCQ.

    Loads the Hessian (grouped variant when ghess_key is given, identity when
    hess_path is None), rotates the layer into the incoherent basis, then
    quantizes via inc_linear_to_inc_comb_linear.

    Returns:
        (layer, quant_info) — the fp16 quantized layer and a stats dict
        augmented with scale_override, hess_path and wall-clock time.
    """
    assert torch.allclose(cb1.tlut, cb2.tlut), "cb1 and cb2 must have the same tlut"
    t0 = time.time()
    out_features, in_features = target_layer.weight.shape
    if ghess_key == "":
        # Fall back to an identity Hessian when no Hessian file is given.
        HR = load_hessian(hess_path).cuda() if hess_path is not None else torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)
    else:
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()
    layer, HRr = linear_to_incoherent_for_tcq(target_layer, cb1, HR, scale_override, SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV, rot_info=rot_info, left_only=left_only)
    HRr = HRr.cuda()
    layer = layer.cuda()
    # scale_override=1.0 here: the incoherence step above already received it.
    layer, quant_info = inc_linear_to_inc_comb_linear(layer, HRr, cb1, cb2, scale_override=1.0, td_x=16, td_y=16, out_part=out_part, KV=KV, V=V, use_hess=use_hess)
    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    quant_info["time"] = time.time() - t0

    return layer.to(torch.float16), quant_info
184
+
185
def linear_to_combt_linear(target_layer, hess_path, cb1, cb2, scale_override=0.9, in_part=(2048, 2048), KV=[3, 4], V=2, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    """End-to-end pipeline: nn.Linear -> incoherent rotation -> input-split TCQ.

    Mirror of linear_to_comb_linear, but the quantization splits along the
    INPUT dimension (in_part) via inc_linear_to_inc_combt_linear.

    Returns:
        (layer, quant_info) — the fp16 quantized layer and a stats dict
        augmented with scale_override, hess_path and wall-clock time.
    """
    assert torch.allclose(cb1.tlut, cb2.tlut), "cb1 and cb2 must have the same tlut"
    t0 = time.time()
    out_features, in_features = target_layer.weight.shape
    if ghess_key == "":
        # Fall back to an identity Hessian when no Hessian file is given.
        HR = load_hessian(hess_path).cuda() if hess_path is not None else torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)
    else:
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()
    layer, HRr = linear_to_incoherent_for_tcq(target_layer, cb1, HR, scale_override, SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV, rot_info=rot_info, left_only=left_only)
    HRr = HRr.cuda()
    layer = layer.cuda()
    # scale_override=1.0 here: the incoherence step above already received it.
    layer, quant_info = inc_linear_to_inc_combt_linear(layer, HRr, cb1, cb2, scale_override=1.0, td_x=16, td_y=16, in_part=in_part, KV=KV, V=V, use_hess=use_hess)
    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    quant_info["time"] = time.time() - t0

    return layer.to(torch.float16), quant_info
lib/quantizer/nuq_op.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ import numpy as np
4
+ import time
5
+ from tqdm import tqdm
6
+ from typing import Tuple
7
+
8
def get_progress_bar(total: int, desc: str):
    """Build a tqdm progress bar with the compact format used by this module."""
    bar_format = '{l_bar}{bar:10}{r_bar}{bar:-10b}'
    return tqdm(total=total, desc=desc, bar_format=bar_format)
14
+
15
@torch.no_grad()
def objective_function(
    W: torch.Tensor,
    H: torch.Tensor,
    P: torch.Tensor,
    C: torch.Tensor,
) -> torch.Tensor:
    """
    Calculate the quantization error (objective value).

    Reconstructs W_hat from the one-hot assignments P and centroids C, then
    evaluates the H-weighted squared error per row and returns the mean.

    Args:
        W: Weight matrix (row_count * group_count, group_size)
        H: Hessian matrix (blk_num, group_size, group_size)
        P: Assignment matrix (row_count * group_count, group_size//vec_sz, n_cluster)
        C: Centroid matrix (n_cluster, vec_sz)

    Returns:
        Objective value (scalar)
    """

    device = torch.device("cuda")
    P, C = P.to(device), C.to(device)
    W_hat = torch.einsum('ijc,ck->ijk', P, C)  # Shape: (row_count * group_count, group_size//vec_sz, vec_sz)
    W_hat = W_hat.view(W_hat.shape[0], -1)  # Shape: (row_count * group_count, group_size)
    delta_w = W_hat - W

    blk_num = H.shape[0]
    blk_size = W.shape[0] // blk_num

    delta_w = delta_w.reshape(blk_num, blk_size, delta_w.shape[-1])
    # Per intra-block row i: sum over blocks n of delta_w[n,i] @ H[n] @ delta_w[n,i]^T.
    objective_value = torch.einsum('nij,njk,nik->i', delta_w, H, delta_w)
    total_error = objective_value.mean()

    return total_error
49
+
50
+
51
@torch.no_grad()
def parallel_objective_function_sub(
    W: torch.Tensor,  # Shape: (b, g_cd)
    quadratic: torch.Tensor,  # Shape: (g_cd, g_cd)
    linear: torch.Tensor,  # Shape: (b, g_cd)
    W_hat_options: torch.Tensor,  # Shape: (b, g_cd, num_options)
) -> torch.Tensor:
    """
    Calculate the quantization error (objective value), and return the list of errors for each options.
    W_hat is a tensor with possible options concatenated along the last dimension.

    The option-independent constant term is omitted, so the returned values
    are only meaningful for comparing options against each other.

    Args:
        W: Weight matrix (b, g_cd)
        quadratic: torch.Tensor, # Shape: (g_cd, g_cd)
        linear: torch.Tensor, # Shape: (b, g_cd)
        W_hat_options: torch.Tensor, # Shape: (b, g_cd, num_options)

    Returns:
        Possible objective values for each options (b, num_options)
    """
    device = torch.device("cuda")
    W_hat_options = W_hat_options.to(device)
    b, g_cd, num_options = W_hat_options.shape

    # Deviation of every candidate option from the target weights.
    delta_w_g = W_hat_options - W.unsqueeze(2).expand(-1, -1, num_options)

    quadratic_term = torch.einsum('jk,ijp,ikp->ip', quadratic, delta_w_g, delta_w_g)
    linear_term = torch.einsum('ij,ijp->ip', linear, delta_w_g)
    total_error_quad = quadratic_term + linear_term
    return total_error_quad
81
+
82
+
83
+
84
def update_batch_P(
    W: torch.Tensor,  # Shape: (b, group_size)
    H: torch.Tensor,  # Shape: (blk_num, group_size, group_size)
    P: torch.Tensor,  # Shape: (b, group_size // vec_sz, n_cluster)
    C: torch.Tensor,  # Shape: (n_cluster, vec_sz)
    iteration: int,
    g_cd: int,  # Number of weights to update at a time
    cd_cycles: int,
    verbose: bool = False,
):
    """Coordinate-descent reassignment for one batch of rows.

    For sliding groups of g_cd vector slots, enumerates all n_cluster**g_cd
    candidate assignments, scores each with the local quadratic + linear
    objective (parallel_objective_function_sub), and keeps the argmin.
    Sweeps the d slots cd_cycles times and returns a fresh one-hot P.
    """
    device = torch.device("cuda")
    C = C.to(device)
    C_ = C.unsqueeze(0).expand(P.shape[0], -1, -1)
    assignments_prev = P.argmax(dim=-1).to(device)  # Shape: (b, group_size // vec_sz)
    b, d = assignments_prev.shape
    n_cluster, vec_sz = C_.size(1), C_.size(2)
    assert H.shape[0] == 1  # only a single Hessian block is supported here
    H_ = H[0]

    assignments = assignments_prev.clone()
    update_size = cd_cycles * d  # total slot-updates across all sweeps

    for update_start_idx in range(0, update_size, g_cd):
        start_idx = update_start_idx % d  # wrap around for repeated sweeps
        end_idx = min(start_idx + g_cd, d)
        indices = torch.arange(start_idx * vec_sz, end_idx * vec_sz, device=device)
        indices_assignments = torch.arange(start_idx, end_idx, device=device)
        # Generate all possible assignments for the group
        num_options = n_cluster ** g_cd
        if num_options > 1e6:
            print(f"Skipping group starting at index {start_idx} due to large number of assignments ({num_options}).")
            continue

        # Create all possible assignments for the group
        from itertools import product
        assignments_list = list(product(range(n_cluster), repeat=g_cd))
        assignments_array = torch.tensor(assignments_list, device=device).T  # Shape: (g_cd, num_options)
        assignments_array = assignments_array.unsqueeze(0).expand(b, -1, -1)  # Shape: (b, g_cd, num_options, vec_sz)

        # Creating options for g_cd weights
        C_expanded = C_.unsqueeze(1).expand(-1, g_cd, -1, -1)  # Shape: (b, g_cd, n_cluster, vec_sz)
        W_g_hat_options = torch.gather(C_expanded, dim=2, index=assignments_array.unsqueeze(-1).expand(-1, -1, -1, vec_sz))  # Shape: (b, g_cd, num_options, vec_sz)

        # Gathering original quantized weights and compute linear & quadratic terms
        # Expand C and gather original weights
        C_expanded_org = C_.unsqueeze(1).expand(-1, d, -1, -1)  # Shape: (b, d, n_cluster, vec_sz)
        W_hat_org = torch.gather(C_expanded_org, dim=2, index=assignments.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, vec_sz)).squeeze(2)  # Shape: (b, d, vec_sz)

        # Compute deltas
        delta_w_org = W_hat_org.view(b, -1) - W  # Shape: (b, group_size)

        # Get indices and slices: everything outside the group being updated.
        notg_indices = torch.cat([
            torch.arange(0, start_idx * vec_sz, device=device),
            torch.arange(end_idx * vec_sz, d * vec_sz, device=device)
        ])

        H_g_notg = H_[indices, :][:, notg_indices]  # Shape: (g_cd * vec_sz, group_size - g_cd * vec_sz)

        delta_w_org_notg = delta_w_org[:, notg_indices].to(device)  # Shape: (b, group_size - g_cd * vec_sz)

        # Compute quadratic and linear terms
        quadratic = H_[indices, :][:, indices]  # Shape: (g_cd * vec_sz, g_cd * vec_sz)
        linear = 2 * torch.einsum('gd,id->ig', H_g_notg, delta_w_org_notg)  # Shape: (b, g_cd * vec_sz)
        W_g = W[:, indices]  # Shape: (b, g_cd * vec_sz)

        W_g_hat_options = W_g_hat_options.permute(0, 1, 3, 2).view(b, g_cd * vec_sz, num_options)  # Shape: (b, g_cd * vec_sz, num_options)
        # Objective function computation
        cur_obj_value = parallel_objective_function_sub(W_g, quadratic, linear, W_g_hat_options)  # Shape: (b, num_options)

        # Update assignments with the best-scoring option per row.
        min_obj, argmin_obj = cur_obj_value.min(dim=1, keepdim=True)
        expanded_argmin_obj = argmin_obj.unsqueeze(1).expand(-1, g_cd, -1).to(device)
        assignments[:, indices_assignments] = assignments_array.gather(dim=2, index=expanded_argmin_obj).squeeze(-1)  # Shape: (row_count * group_count, g_cd)

    num_changed = (assignments_prev != assignments).sum().item()
    total_assignments = assignments_prev.numel()
    percentage_changed = num_changed / total_assignments * 100
    if verbose:
        logging.info(f"Percentage of assignments changed: {percentage_changed:.2f}%%")

    # Convert assignments to one-hot encoding to create new P
    P = torch.zeros((b, d, n_cluster), dtype=torch.float32, device=assignments.device)
    P.scatter_(2, assignments.long().unsqueeze(-1), 1.0)

    return P
170
+
171
+
172
def update_P(
    W: torch.Tensor,  # Shape: (row_count * group_count, group_size)
    H: torch.Tensor,  # Shape: (blk_num, group_size, group_size)
    P: torch.Tensor,  # Shape: (row_count * group_count, group_size//vec_sz, n_cluster)
    C: torch.Tensor,  # Shape: (n_cluster, vec_sz)
    iteration: int,
    g_cd: int = 1,
    cd_cycles: int = 4,
):
    """Update the assignment matrix P for all rows via batched coordinate descent.

    Rows are processed in fixed-size chunks (moved to GPU per chunk to bound
    memory); each chunk is delegated to update_batch_P and the updated
    one-hot assignments are concatenated back on CPU.
    """
    n_cluster = C.shape[0]
    batch_output_size = 4096  # * 32 // max(32, n_cluster)
    device = torch.device("cuda")
    updated_P_list = []

    pb = get_progress_bar((W.size(0) - 1) // batch_output_size + 1, f"Updating P (cd_cycles={cd_cycles})")
    for out_idx in range(0, W.size(0), batch_output_size):
        torch.cuda.reset_peak_memory_stats()  # Reset memory stats at start of iteration

        W_batch = W[out_idx:out_idx+batch_output_size].to(device)
        P_batch = P[out_idx:out_idx+batch_output_size].to(device)
        C_batch = C.to(device)

        verbose = False  # (out_idx == 0)

        updated_P_batch = update_batch_P(W_batch, H, P_batch, C_batch, iteration, g_cd=g_cd, cd_cycles=cd_cycles, verbose=verbose).cpu()
        updated_P_list.append(updated_P_batch)
        pb.update(1)
    pb.close()

    # Log max CUDA memory usage
    P = torch.cat(updated_P_list, dim=0)
    return P
204
+
205
def project_to_pd(H, eps=1e-2):
    """Project a (near-)symmetric matrix onto the symmetric positive-definite cone.

    Symmetrizes H, clamps its eigenvalues from below at eps, rebuilds the
    matrix from the eigendecomposition, and re-symmetrizes once more to wash
    out floating-point round-off. The result is cast back to H's dtype.
    """
    sym = (H + H.T) / 2
    evals, evecs = torch.linalg.eigh(sym)
    clamped = evals.clamp(min=eps)
    rebuilt = evecs @ torch.diag(clamped) @ evecs.T
    rebuilt = (rebuilt + rebuilt.T) / 2
    return rebuilt.to(H.dtype)
213
+
214
+ import torch
215
def kron_with_identity_vec(P: torch.Tensor, vec_sz: int) -> torch.Tensor:
    """Expand each scalar of P with a vec_sz x vec_sz identity block.

    Maps P of shape (B, d, c) to shape (B, d*vec_sz, c*vec_sz), where every
    scalar P[b, i, j] becomes the block P[b, i, j] * I_{vec_sz} — i.e. a
    batched Kronecker product with the identity.
    """
    batch, rows, cols = P.shape
    identity = torch.eye(vec_sz, device=P.device)
    # Outer product per entry, laid out so rows/cols interleave with the
    # identity axes: (B, rows, vec_sz, cols, vec_sz).
    blocks = torch.einsum('brc,uv->brucv', P, identity)
    return blocks.reshape(batch, rows * vec_sz, cols * vec_sz)
225
+
226
def update_C(
    W: torch.Tensor,  # (row, gs)
    H: torch.Tensor,  # (1, gs, gs)
    P: torch.Tensor,  # (row, gs//vec_sz, n_cluster)
    C: torch.Tensor,  # (n_cluster, vec_sz)
    batch_size: int = 256
):
    """Solve the H-weighted least squares for the centroids with P held fixed.

    Whitens each row batch by the transposed Cholesky factor of H[0],
    accumulates the normal equations A c = b over all batches, and returns
    the solved centroids reshaped to (n_cluster, vec_sz).
    """
    device = W.device
    dtype = W.dtype

    chol = torch.linalg.cholesky(H[0])  # (gs, gs) lower-triangular factor
    cholT = chol.transpose(0, 1)        # whitening transform

    row, gs = W.shape
    n_cluster, vec_sz = C.shape
    dim = n_cluster * vec_sz

    normal_A = torch.zeros(dim, dim, device=device, dtype=dtype)
    normal_b = torch.zeros(dim, device=device, dtype=dtype)

    for lo in range(0, row, batch_size):
        hi = min(lo + batch_size, row)
        P_blk = P[lo:hi].to(device)  # (B, gs//vec_sz, n_cluster)
        W_blk = W[lo:hi].to(device)  # (B, gs)

        # Kronecker-expand assignments with I_{vec_sz}: (B, gs, dim).
        design = kron_with_identity_vec(P_blk, vec_sz)

        whitened_X = torch.einsum('ij,bjk->bik', cholT, design)  # (B, gs, dim)
        whitened_y = torch.einsum('ij,bj->bi', cholT, W_blk)     # (B, gs)

        normal_A += torch.einsum('bik,bil->kl', whitened_X, whitened_X)
        normal_b += torch.einsum('bik,bi->k', whitened_X, whitened_y)

    return torch.linalg.solve(normal_A, normal_b).view(n_cluster, vec_sz)
266
+
267
def train_least_squares(
    W: np.ndarray,  # Shape: (row_count * group_count, group_size)
    init_P: np.ndarray,  # Shape: (row_count * group_count, group_size//vec_sz, n_cluster)
    init_centroids: np.ndarray,  # Shape: (n_cluster, vec_sz)
    H: np.ndarray,  # Shape: (blk_num, group_size, group_size)
    num_iterations: int = 3,
    cd_cycles: int = 4,
    eig_threshold: float = 1e-3,
) -> Tuple[np.ndarray, np.ndarray]:
    """Alternating least squares over assignments P and centroids C.

    First dampens each Hessian block's diagonal until Cholesky succeeds
    (the process exits if the damping factor exceeds 1.0). Then alternates
    update_P (coordinate descent) and update_C (closed-form solve), tracking
    the best (P, C) by the H-weighted objective and early-stopping as soon
    as a C update fails to improve.

    Returns:
        (P, C, log_dict) on CPU; log_dict records the objective after every
        P and C update, keyed by iteration.
    """
    device = torch.device("cuda")

    P = torch.tensor(init_P, dtype=torch.float32, device="cpu")
    C = torch.tensor(init_centroids, dtype=torch.float32, device="cpu")
    W = torch.tensor(W, dtype=torch.float32).to(device)
    H = torch.tensor(H, dtype=torch.float32).to(device)

    # eigenvalues = torch.linalg.eigvalsh(H)
    # for i in range(eigenvalues.shape[0]):
    # top_3_and_bottom_3 = [round(eig.item(), 2) for eig in torch.cat([eigenvalues[i][:3], eigenvalues[i][-3:]])]
    # logging.info(f"{i+1}-th H has Eigenvalues (top 3 and bottom 3): {top_3_and_bottom_3}, Projecting to PD with eps=1e-6 for numerical stability")
    # H[i] = project_to_pd(H[i], eps=1e-6)

    # eps = eig_threshold * 10
    # while not torch.all(eigenvalues[i] > eig_threshold):
    # top_3_and_bottom_3 = [round(eig.item(), 2) for eig in torch.cat([eigenvalues[i][:3], eigenvalues[i][-3:]])]
    # logging.info(f"{i+1}-th H not PD, Eigenvalues (top 3 and bottom 3): {top_3_and_bottom_3}, Projecting to PD with eps={eps}")
    # H[i] = project_to_pd(H[i], eps=eps)
    # eigenvalues = torch.linalg.eigvalsh(H)
    # eps *= 10
    # top_3_and_bottom_3 = [round(eig.item(), 2) for eig in torch.cat([eigenvalues[i][:3], eigenvalues[i][-3:]])]
    # logging.info(f"{i+1}-th H PD, Eigenvalues (top 3 and bottom 3): {top_3_and_bottom_3}")
    diag = torch.arange(H.shape[1], device=device)
    for i in range(H.shape[0]):
        # Increase diagonal damping (x10 each try) until Cholesky succeeds.
        avg_diag = torch.mean(torch.diag(H[i]))
        damp, prev_damp = 1e-7, 0.
        while True:
            try:
                torch.linalg.cholesky(H[i])
                logging.info(f"{i+1}-th H is PD, dampening factor={prev_damp:.2e}")
                break
            except Exception as e:
                print(e)
                logging.info(f"{i+1}-th H is not PD, try dampening with factor={damp:.2e}")
                H[i, diag, diag] += (damp - prev_damp) * avg_diag
                prev_damp = damp
                damp *= 10
                if damp > 1e0:
                    exit()

    best_obj_value = objective_function(W, H, P, C).item()
    best_P, best_C = P.detach().cpu().clone(), C.detach().cpu().clone()
    logging.info(f"Initial objective: {best_obj_value:.6f}")

    log_dict = {"objective": [], "iteration": []}
    log_dict["objective"].append(best_obj_value)
    log_dict["iteration"].append(0)

    for iteration in range(num_iterations):
        start_time = time.time()

        ######### Update P #########
        # Skip the P update on the very first iteration (use the k-means init).
        if iteration > 0:
            P = update_P(W, H, P, C, iteration, cd_cycles=cd_cycles)

        # Compute objective value for logging
        obj_value = objective_function(W, H, P, C).item()
        logging.info(f"Iteration {iteration + 1} (P update): Objective: {obj_value:.4f}")
        log_dict["objective"].append(obj_value)
        log_dict["iteration"].append(iteration + 1)


        ######### Update C #########
        C = update_C(W, H, P, C)

        # Check if the objective value improved
        current_obj_value = objective_function(W, H, P, C).item()
        log_dict["objective"].append(current_obj_value)
        log_dict["iteration"].append(iteration + 1)
        if current_obj_value < best_obj_value:
            best_obj_value = current_obj_value
            best_P, best_C = P.detach().cpu().clone(), C.detach().cpu().clone()
            logging.info(f"Iteration {iteration + 1} (C update): Objective: {current_obj_value:.4f} | Improved and using this one.")
        else:
            logging.info(f"Iteration {iteration + 1} (C update): Objective: {current_obj_value:.4f} | Not improved. Using previous best values.")
            P, C = best_P, best_C
            break  # Early stopping

        end_time = time.time()

        logging.info(f"Iteration {iteration + 1} / {num_iterations} completed. "
                     f"Update time: {end_time - start_time:.2f} sec")

    end_time = time.time()
    # NOTE(review): start_time is reset every iteration, so this logs the last
    # iteration's duration, not the full training time.
    logging.info(f"Least squares training time: {end_time - start_time:.2f} seconds")

    P = P.detach().cpu()
    C = C.detach().cpu().to(torch.float32)

    return P, C, log_dict
366
+
367
+
368
def test():
    """Smoke-test train_least_squares on random data with a k-means init."""
    from lib.utils.kmeans import fit_kmeans
    # set seed
    torch.manual_seed(0)
    np.random.seed(0)
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    vec_sz = 4
    lut_size = (1 << (2 * vec_sz))  # 2 bits per weight -> 256 centroids
    W = torch.randn(4096, 4096)
    H = torch.randn(4096, 4096)

    rand_data = torch.randn(10000, vec_sz)
    C = fit_kmeans(rand_data, lut_size)[0]
    print("C", C.shape)

    # Nearest-centroid initial assignment (shapes below are for vec_sz=4, lut_size=256).
    W_vec = W.view(4096, 4096 // vec_sz, vec_sz)
    W_vec = W_vec.unsqueeze(0)  # Shape: (1, 4096, 1024, 4)
    C_ = C.unsqueeze(1).unsqueeze(1)  # Shape: (256, 1, 1, 4)
    diff = W_vec - C_  # Shape: (256, 4096, 1024, 4)
    dist_sq = diff.pow(2).sum(-1)  # Shape: (256, 4096, 1024)
    idx = dist_sq.argmin(dim=0)  # Shape: (4096, 1024)
    init_P = torch.zeros(4096, 4096 // vec_sz, lut_size)
    init_P.scatter_(2, idx.unsqueeze(-1), 1)

    # Make H symmetric positive definite.
    H = H @ H.T
    H = H + 1e-6 * torch.eye(4096, 4096)
    H = H.unsqueeze(0)

    P, C, log_dict = train_least_squares(
        W=W.numpy(),
        init_P=init_P.numpy(),
        init_centroids=C.numpy(),
        H=H.numpy(),
        num_iterations=10,
        cd_cycles=4,
        eig_threshold=1e-3,
    )

    for i in range(len(log_dict["objective"])):
        logging.info(f"Iteration {log_dict['iteration'][i]}: Objective: {log_dict['objective'][i]:.4f}")
    print("P", P.shape)
    print("C", C.shape)

    # recons: plain MSE of the reconstruction
    W_hat = torch.einsum('ijc,ck->ijk', P, C)
    W_hat = W_hat.view(W_hat.shape[0], -1)
    err = (W - W_hat).pow(2).mean()
    print("err", err.item())

    # Hessian-weighted error (trace form), normalized by the feature dim.
    dWHdW = (W - W_hat) @ H[0] @ (W - W_hat).T
    err_tr = torch.trace(dWHdW) / H.shape[1]
    print("err_tr", err_tr.item())
424
if __name__ == "__main__":
    # Run the random-data smoke test with INFO-level logging.
    logging.basicConfig(
        level=logging.INFO,
        format='[%(levelname)s] %(asctime)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    test()
431
+
lib/quantizer/pack_op.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numba
2
+ import numpy as np
3
@numba.njit(cache=True)
def general_pack(unpacked, nbits, codeT, code_n):
    '''
    Sequentially pack nbits-wide codes into an array of codeT words, LSB-first.
    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        codeT: NumPy unsigned scalar type of the output words (e.g. np.uint32)
        code_n: int, output length; must equal n_unpacked * nbits / word-bits

    return: out_code (code_n,) dtype: codeT
    '''
    n_unpacked = unpacked.shape[0]
    codeT_sz = codeT.itemsize * 8  # bits per output word
    assert n_unpacked * nbits / codeT_sz == code_n, "code_n must be equal to n_unpacked * nbits / codeT_sz"
    out_code = np.zeros(code_n, dtype=codeT)
    for i in range(n_unpacked):
        val = codeT(unpacked[i])  # cast so shifts happen at output word width
        offset = i * nbits  # absolute bit offset of this code in the stream
        wIndex = offset // codeT_sz  # destination word
        bIndex = offset % codeT_sz   # bit position inside that word
        out_code[wIndex] |= (val << bIndex) & np.iinfo(codeT).max

        # Spill the high bits into the next word when the code straddles a boundary.
        bits_in_word = codeT_sz - bIndex
        if bits_in_word < nbits:
            upper = val >> bits_in_word
            out_code[wIndex + 1] |= upper & np.iinfo(codeT).max

    return out_code
31
+
32
@numba.njit(cache=True)
def general_pack_8(unpacked, nbits, code_n):
    '''
    Sequentially pack nbits-wide codes into a uint8 array, LSB-first.
    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        code_n: int, output length; must equal n_unpacked * nbits / 8

    return: out_code (code_n,) dtype: np.uint8
    '''
    n_unpacked = unpacked.shape[0]
    assert n_unpacked * nbits / 8 == code_n, "code_n must be equal to n_unpacked * nbits / 8"
    out_code = np.zeros(code_n, dtype=np.uint8)
    for i in range(n_unpacked):
        val = unpacked[i]
        offset = i * nbits  # absolute bit offset of this code
        wIndex = offset // 8  # destination byte
        bIndex = offset % 8   # bit position inside that byte
        out_code[wIndex] |= (val << bIndex) & np.iinfo(np.uint8).max

        # Spill the high bits into the next byte when the code straddles a boundary.
        bits_in_word = 8 - bIndex
        if bits_in_word < nbits:
            upper = val >> bits_in_word
            out_code[wIndex + 1] |= upper & np.iinfo(np.uint8).max
    return out_code
58
+
59
@numba.njit(cache=True)
def general_pack_16(unpacked, nbits, code_n):
    '''
    Sequentially pack nbits-wide codes into a uint16 array, LSB-first.
    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        code_n: int, output length; must equal n_unpacked * nbits / 16

    return: out_code (code_n,) dtype: np.uint16
    '''
    n_unpacked = unpacked.shape[0]
    assert n_unpacked * nbits / 16 == code_n, "code_n must be equal to n_unpacked * nbits / 16"
    out_code = np.zeros(code_n, dtype=np.uint16)
    for i in range(n_unpacked):
        val = unpacked[i]
        offset = i * nbits  # absolute bit offset of this code
        wIndex = offset // 16  # destination word
        bIndex = offset % 16   # bit position inside that word
        out_code[wIndex] |= (val << bIndex) & np.iinfo(np.uint16).max

        # Spill the high bits into the next word when the code straddles a boundary.
        bits_in_word = 16 - bIndex
        if bits_in_word < nbits:
            upper = val >> bits_in_word
            out_code[wIndex + 1] |= upper & np.iinfo(np.uint16).max

    return out_code
86
+
87
+
88
@numba.njit(cache=True)
def general_pack_32(unpacked, nbits, code_n):
    '''
    Sequentially pack nbits-wide codes into a uint32 array, LSB-first.
    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int, bit width of each code
        code_n: int, output length; must equal n_unpacked * nbits / 32

    return: out_code (code_n,) dtype: np.uint32
    '''
    n_unpacked = unpacked.shape[0]
    assert n_unpacked * nbits / 32 == code_n, "code_n must be equal to n_unpacked * nbits / 32"
    out_code = np.zeros(code_n, dtype=np.uint32)
    for i in range(n_unpacked):
        val = unpacked[i]
        offset = i * nbits  # absolute bit offset of this code
        wIndex = offset // 32  # destination word
        bIndex = offset % 32   # bit position inside that word
        out_code[wIndex] |= (val << bIndex) & np.iinfo(np.uint32).max

        # Spill the high bits into the next word when the code straddles a boundary.
        bits_in_word = 32 - bIndex
        if bits_in_word < nbits:
            upper = val >> bits_in_word
            out_code[wIndex + 1] |= upper & np.iinfo(np.uint32).max

    return out_code
115
+
116
@numba.njit(cache=True)
def general_pack_64(unpacked, nbits, code_n):
    '''
    Sequentially pack codes into a uint64 array (LSB-first bit stream).

    args:
        unpacked: np.int (n_unpacked,) each entry is 0 .. 2 ** nbits - 1
        nbits: int  bits stored per entry
        code_n: int  number of uint64 words in the output

    return: out_code (code_n,) dtype: np.uint64
    '''
    n_unpacked = unpacked.shape[0]
    assert n_unpacked * nbits / 64 == code_n, "code_n must be equal to n_unpacked * nbits / 64"
    out_code = np.zeros(code_n, dtype=np.uint64)
    for i in range(n_unpacked):
        val = unpacked[i]
        offset = i * nbits      # absolute bit position of this entry
        wIndex = offset // 64   # word holding the low bits
        bIndex = offset % 64    # bit offset inside that word
        out_code[wIndex] |= (val << bIndex) & np.iinfo(np.uint64).max

        bits_in_word = 64 - bIndex
        if bits_in_word < nbits:
            # entry straddles a word boundary: carry the high bits into the next word
            upper = val >> bits_in_word
            out_code[wIndex + 1] |= upper & np.iinfo(np.uint64).max

    return out_code
143
+
144
+ @numba.njit(cache=True)
145
+ def pack_codes_8(codes, nbits, code_n):
146
+ '''
147
+ sequentially packing codes into codeT type array.
148
+ args:
149
+ codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
150
+ nbits: int
151
+ codeT: np.dtype (uint64 or uint32 or uint16 or uint8)
152
+ code_n: int
153
+
154
+ return:
155
+ packed_codes (-1, code_n) dtype: codeT
156
+ '''
157
+ n_samples = codes.shape[0]
158
+ n_unpacked = code_n * 8 // nbits
159
+ packed_codes = np.zeros((n_samples // n_unpacked, code_n), dtype=np.uint8)
160
+ for i in range(n_samples // n_unpacked):
161
+ unpacked = codes[i * n_unpacked: (i + 1) * n_unpacked]
162
+ packed_codes[i] = general_pack_8(unpacked, nbits, code_n)
163
+ return packed_codes
164
+
165
+ @numba.njit(cache=True)
166
+ def pack_codes_16(codes, nbits, code_n):
167
+ '''
168
+ sequentially packing codes into codeT type array.
169
+ args:
170
+ codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
171
+ nbits: int
172
+ codeT: np.dtype (uint64 or uint32 or uint16 or uint8)
173
+ code_n: int
174
+
175
+ return:
176
+ packed_codes (-1, code_n) dtype: codeT
177
+ '''
178
+ n_samples = codes.shape[0]
179
+ n_unpacked = code_n * 16 // nbits
180
+ packed_codes = np.zeros((n_samples // n_unpacked, code_n), dtype=np.uint16)
181
+ for i in range(n_samples // n_unpacked):
182
+ unpacked = codes[i * n_unpacked: (i + 1) * n_unpacked]
183
+ packed_codes[i] = general_pack_16(unpacked, nbits, code_n)
184
+ return packed_codes
185
+
186
+ @numba.njit(cache=True)
187
+ def pack_codes_32(codes, nbits, code_n):
188
+ '''
189
+ sequentially packing codes into codeT type array.
190
+ args:
191
+ codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
192
+ nbits: int
193
+ codeT: np.dtype (uint64 or uint32 or uint16 or uint8)
194
+ code_n: int
195
+
196
+ return:
197
+ packed_codes (-1, code_n) dtype: codeT
198
+ '''
199
+ n_samples = codes.shape[0]
200
+ n_unpacked = code_n * 32 // nbits
201
+ packed_codes = np.zeros((n_samples // n_unpacked, code_n), dtype=np.uint32)
202
+ for i in range(n_samples // n_unpacked):
203
+ unpacked = codes[i * n_unpacked: (i + 1) * n_unpacked]
204
+ packed_codes[i] = general_pack_32(unpacked, nbits, code_n)
205
+ return packed_codes
206
+
207
+ @numba.njit(cache=True)
208
+ def pack_codes_64(codes, nbits, code_n):
209
+ '''
210
+ sequentially packing codes into codeT type array.
211
+ args:
212
+ codes: np.int (n_samples,) each entry is 0 .. 2 ** nbits - 1
213
+ nbits: int
214
+ codeT: np.dtype (uint64 or uint32 or uint16 or uint8)
215
+ code_n: int
216
+
217
+ return:
218
+ packed_codes (-1, code_n) dtype: codeT
219
+ '''
220
+ n_samples = codes.shape[0]
221
+ n_unpacked = code_n * 64 // nbits
222
+ packed_codes = np.zeros((int(n_samples // n_unpacked), code_n), dtype=np.uint64)
223
+ for i in range(n_samples // n_unpacked):
224
+ unpacked = codes[i * n_unpacked: (i + 1) * n_unpacked]
225
+ packed_codes[i] = general_pack_64(unpacked, nbits, code_n)
226
+ return packed_codes
227
+
228
+ def pack_codes(codes, nbits, code_n, codeT_sz):
229
+ if codeT_sz == 8:
230
+ return pack_codes_8(codes, nbits, code_n)
231
+ elif codeT_sz == 16:
232
+ return pack_codes_16(codes, nbits, code_n)
233
+ elif codeT_sz == 32:
234
+ return pack_codes_32(codes, nbits, code_n)
235
+ elif codeT_sz == 64:
236
+ return pack_codes_64(codes, nbits, code_n)
237
+ else:
238
+ raise ValueError(f"Unsupported codeT_sz: {codeT_sz}")
239
+
240
+
241
+
242
+ @numba.njit(cache=True)
243
+ def pack_32(cluster_idx: np.ndarray, nbits: int) -> np.ndarray:
244
+ """
245
+ NumPy 버전의 pack_32 함수.
246
+
247
+ Parameters
248
+ ----------
249
+ cluster_idx : np.ndarray of shape (32,), dtype=int
250
+ 길이 32의 정수 배열. (C 코드에서 const int* cluster_idx와 동일 역할)
251
+ nbits : int
252
+ 각 정수를 몇 비트로 저장할지.
253
+
254
+ Returns
255
+ -------
256
+ out_code : np.ndarray of shape (out_size,), dtype=np.uint32
257
+ 32개의 값(각각 nbits 비트)으로 구성된 연속 비트열을
258
+ 32비트 워드(uint32) 단위로 나눈 결과.
259
+ """
260
+
261
+ # 32개의 값을 nbits비트씩 사용하면 총 32*nbits 비트가 필요.
262
+ # 이를 32비트 단위로 나누면 아래처럼 워드 수가 결정됨.
263
+ out_size = (32 * nbits + 31) // 32 # 올림
264
+
265
+ # 결과 버퍼 (np.uint32로)
266
+ out_code = np.zeros(out_size, dtype=np.uint32)
267
+
268
+ for i in range(32):
269
+ # cluster_idx[i]를 unsigned 처리
270
+ val = np.uint32(cluster_idx[i])
271
+
272
+ offset = i * nbits
273
+ wIndex = offset // 32 # 몇 번째 워드인지
274
+ bIndex = offset % 32 # 그 워드 내에서 몇 번째 비트부터 시작?
275
+
276
+ # 첫 번째 워드에 bIndex부터 nbits비트 중 일부 혹은 전부를 저장
277
+ out_code[wIndex] |= (val << bIndex) & np.uint32(0xFFFFFFFF)
278
+
279
+ # 현재 워드에 다 못 들어가는 나머지 비트가 있으면, 다음 워드에 저장
280
+ bits_in_word = 32 - bIndex
281
+ if bits_in_word < nbits:
282
+ upper = val >> bits_in_word
283
+ out_code[wIndex + 1] |= upper & np.uint32(0xFFFFFFFF)
284
+
285
+ return out_code
286
+
287
+ @numba.njit(cache=True)
288
+ def pack_for_sq_pack_kernel(unpacked_code: np.ndarray,
289
+ nbits: int,
290
+ blockDimX: int=32) -> np.ndarray:
291
+ """
292
+ Python 버전의 pack_for_sq_pack_kernel.
293
+ unpacked_code: shape = (N, K), dtype=uint32
294
+ nbits: int
295
+ blockDimX: int
296
+ return: Bcode (1D np.uint32 배열, 길이 = N * (K*nbits//32))
297
+ """
298
+
299
+ N, K = unpacked_code.shape
300
+ out_size = N * (K * nbits // 32)
301
+ Bcode = np.zeros(out_size, dtype=np.uint32)
302
+
303
+ K_iter = int(np.ceil(K / (32 * blockDimX)))
304
+
305
+ # 임시 버퍼
306
+ unpacked_Bcode_row = np.zeros(32, dtype=np.uint32)
307
+
308
+ for n_ in range(N):
309
+ eff_warp_size = blockDimX
310
+ for k_ in range(K_iter):
311
+ for thx in range(blockDimX):
312
+ if k_ == K // (32 * blockDimX):
313
+ eff_warp_size = (K % (32 * blockDimX)) // 32
314
+ if thx >= eff_warp_size:
315
+ break
316
+ k_val = k_ * 32 * blockDimX + 8 * thx
317
+ k_code = k_ * nbits * blockDimX + thx
318
+
319
+ # unpacked_Bcode_row에 32개 로드
320
+ idx_out = 0
321
+ for j in range(4):
322
+ k_val_idx = k_val + 8 * j * eff_warp_size
323
+ for i in range(8):
324
+ unpacked_Bcode_row[idx_out] = unpacked_code[n_, k_val_idx + i]
325
+ idx_out += 1
326
+
327
+ # pack_32 호출
328
+ Bcode_row = pack_32(unpacked_Bcode_row, nbits)
329
+
330
+ # Bcode에 저장
331
+ for j in range(nbits):
332
+ k_code_idx = k_code + j * eff_warp_size
333
+ Bcode[n_ * (K * nbits // 32) + k_code_idx] = Bcode_row[j]
334
+
335
+ return Bcode
lib/quantizer/quant_op.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import lib.utils as utils
from lib.quantizer.pack_op import pack_codes, pack_for_sq_pack_kernel
import torch
# Byte permutation applied inside each 256-entry (16x16) tile before packing.
# NOTE(review): presumably this matches the tensor-core fragment ordering used
# by the CUDA kernels — confirm against the kernel source.
_PERMUTE = torch.arange(256).reshape(2, 8, 2, 4, 2).permute(1, 3, 2, 0,
                                                            4).flatten()
# Inverse permutation: _INV_PERMUTE[_PERMUTE] == arange(256).
_INV_PERMUTE = torch.zeros(256, dtype=torch.int64)
_INV_PERMUTE[_PERMUTE] = torch.arange(256)
8
+
9
def random_mat(N, K):
    '''Return a random (N, K) fp16 matrix on the GPU.'''
    mat = torch.randn(N, K, dtype=torch.float16)
    return mat.cuda()
11
+
12
def random_lut(nbits, vec_sz):
    '''Return a random fp16 LUT with 2**nbits rows of vec_sz entries, on the GPU.'''
    lut = torch.randn(2 ** nbits, vec_sz, dtype=torch.float16)
    return lut.cuda()
14
+
15
def vq_pack_reshape_pack_routine(packed_codes, warp_size, code_n, N):
    '''
    Interleave packed code rows for warp-strided access.

    packed_codes: torch.Tensor, reshaped internally to (N, -1, code_n)
    warp_size: int
    code_n: int  packed words per row
    N: int  number of output rows

    return: (N, -1) tensor where, within every group of warp_size rows,
    word j of all rows becomes contiguous (rows/words transposed).
    '''
    def _interleave(block, lanes):
        # (N, groups * lanes, code_n) -> transpose lanes/words inside each group
        return block.reshape(N, -1, lanes, code_n).permute(0, 1, 3, 2).reshape(N, -1)

    packed_codes = packed_codes.reshape(N, -1, code_n)
    n_rows = packed_codes.shape[1]
    remainder = n_rows % warp_size
    if remainder == 0:
        return _interleave(packed_codes, warp_size)
    # trailing rows form a single partial warp; interleave it at its own width
    split = n_rows - remainder
    full = _interleave(packed_codes[:, :split], warp_size)
    tail = _interleave(packed_codes[:, split:], remainder)
    return torch.cat([full, tail], dim=1)
32
+
33
def reshape_mat(mat, chunk_size, warp_size=32):
    '''
    Reorder an (N, K) matrix into warp-major chunks.

    mat: torch.Tensor, shape = (N, K)
    chunk_size: int  must be a multiple of 8
    warp_size: int

    return: (N, K // (chunk_size * warp_size), warp_size, chunk_size) tensor
    where consecutive 8-element groups are distributed round-robin across
    the warp lanes.
    '''
    N, K = mat.shape
    assert K % 8 == 0 and chunk_size % 8 == 0, "K and chunk_size must be divisible by 8"
    assert K % (chunk_size * warp_size) == 0, "K must be divisible by chunk_size * warp_size"
    groups = chunk_size // 8
    iters = K // (chunk_size * warp_size)
    lanes_major = mat.reshape(N, iters, groups, warp_size, 8).permute(0, 1, 3, 2, 4)
    return lanes_major.reshape(N, iters, warp_size, chunk_size)
46
+
47
def vq_pack_reshape_mat_routine(mat, chunk_size, vec_sz, warp_size=32):
    '''
    Split an (N, K) matrix into warp-reordered vec_sz-dim vectors.

    mat: torch.Tensor, shape = (N, K)
    chunk_size: int
    vec_sz: int  vector length of each output row
    warp_size: int

    return: vecs (-1, vec_sz)
    '''
    N, K = mat.shape
    assert K % chunk_size == 0, "K must be divisible by chunk_size"
    if K % (chunk_size * warp_size) == 0:
        mat = reshape_mat(mat, chunk_size, warp_size)
        vecs = mat.reshape(-1, vec_sz)
    else:
        # trailing columns form a partial warp; reorder them at their own width
        full_warp_part = K - K % (chunk_size * warp_size)
        mat_full = reshape_mat(mat[:, :full_warp_part], chunk_size, warp_size)
        effective_warp_size = (K % (chunk_size * warp_size)) // chunk_size
        mat_partial = reshape_mat(mat[:, full_warp_part:], chunk_size, effective_warp_size)

        vecs = torch.cat([mat_full.reshape(N, -1, vec_sz), mat_partial.reshape(N, -1, vec_sz)], dim=1).reshape(-1, vec_sz)

    return vecs
68
+
69
def pack_qweight_vq_simt(P, lut_bits, vec_sz, code_n, codeT_sz=32):
    '''
    Pack a VQ index matrix into the SIMT kernel layout.

    P: (N, K // vec_sz) integer codes
    lut_bits: bits per code
    vec_sz: VQ vector size
    code_n: packed words per group
    codeT_sz: packed word width in bits (8/16/32/64)

    return: (N, -1) packed tensor on the GPU
    '''
    N = P.shape[0]
    # duplicate each code vec_sz times so the warp reordering sees one code
    # per scalar column; the duplicates collapse back via [:, 0] below
    expanded_P = P.unsqueeze(-1).expand(-1, -1, vec_sz).reshape(N, -1)
    reshaped_P = vq_pack_reshape_mat_routine(expanded_P, chunk_size=int(32*vec_sz), vec_sz=vec_sz)[:, 0].contiguous()
    packed_codes = torch.from_numpy(pack_codes(reshaped_P.view(-1).cpu().numpy(), lut_bits, code_n, codeT_sz)).reshape(N, -1).cuda()
    packed_codes = vq_pack_reshape_pack_routine(packed_codes, 32, code_n, N)
    return packed_codes
79
+
80
def pack_qweight_sq_simt(P, lut_bits):
    '''
    Pack a scalar-quantized index matrix into the SIMT kernel layout.

    P: (N, K) integer codes
    lut_bits: bits per code

    return: (N, K * lut_bits // 32) uint32 tensor on the GPU
    '''
    N = P.shape[0]
    packed_codes = torch.from_numpy(pack_for_sq_pack_kernel(P.cpu().numpy(), lut_bits)).reshape(N, -1).cuda()
    return packed_codes
87
+
88
# for tensor core
def pack_qweight(P, vec_sz, lut_bits, td_x=16, td_y=16, batch_size=1024):
    '''
    Pack quantized weights for the tensor-core kernel, in row batches.

    P: (N, K // vec_sz, 2 ** lut_bits) one-hot, or (N, K // vec_sz) indices
    vec_sz: int  VQ width (1 or 2)
    lut_bits: int  bits per code
    td_x, td_y: tile dimensions forwarded to pack_qweight_routine
    batch_size: rows packed per call, to bound the intermediate bit tensors

    return: (N, -1) packed tensor (layout from pack_qweight_routine)
    '''
    N = P.shape[0]
    mat_packed = []
    # process in row batches to keep the per-batch one-hot/bit tensors small
    # (removed the unused cur_size local from the original loop)
    for i in range(0, N, batch_size):
        sidx, eidx = i, min(i + batch_size, N)
        mat_packed.append(pack_qweight_routine(P[sidx:eidx], vec_sz, lut_bits, td_x, td_y))
    return torch.cat(mat_packed, dim=0)
100
+
101
+ def pack_qweight_routine(P, vec_sz, lut_bits, td_x=16, td_y=16):
102
+ '''
103
+ P: (N, K // vec_sz, 2 ** lut_bits) 0, 1
104
+ '''
105
+ if vec_sz == 1:
106
+ if len(P.shape) == 3:
107
+ P_ind = P.argmax(dim=-1) # (N, K)
108
+ elif len(P.shape) == 2:
109
+ P_ind = P
110
+ N, K = P_ind.shape
111
+ P_tiled = P_ind.reshape(N // td_x, td_x, K // td_y, td_y) \
112
+ .permute(0, 2, 1, 3) \
113
+ .reshape(-1, td_x * td_y)
114
+ P_tiled_permuted = P_tiled[..., _PERMUTE]
115
+ elif vec_sz == 2:
116
+ if len(P.shape) == 3:
117
+ P_ind = P.argmax(dim=-1) # (N, K // vec_sz)
118
+ else:
119
+ P_ind = P
120
+ P_ind = P_ind.unsqueeze(-1).expand(-1, -1, vec_sz)
121
+ P_ind = P_ind.reshape(P_ind.shape[0], -1).contiguous()
122
+ N, K = P_ind.shape
123
+ P_tiled = P_ind.reshape(N // td_x, td_x, K // td_y, td_y) \
124
+ .permute(0, 2, 1, 3) \
125
+ .reshape(-1, td_x * td_y)
126
+ # permute and flatten
127
+ P_tiled_permuted = P_tiled[..., _PERMUTE].contiguous().view(-1, vec_sz)
128
+ assert torch.allclose(P_tiled_permuted[:, 0], P_tiled_permuted[:, 1]), \
129
+ "P_tiled_permuted[:, 0] and P_tiled_permuted[:, 1] are not the same"
130
+ P_tiled_permuted = P_tiled_permuted[:, 0].contiguous()
131
+ P_tiled_permuted = P_tiled_permuted.reshape((N * K) // (td_x * td_y), td_x * td_y // vec_sz)
132
+
133
+ m = P_tiled_permuted.shape[0]
134
+ c = (td_x * td_y) // vec_sz
135
+
136
+ K_mask = 2 ** torch.arange(lut_bits, device=P.device).view(1, 1, -1) # => [1,1,lut_bits]
137
+ bits_bool = (P_tiled_permuted.unsqueeze(-1) & K_mask) > 0 # => [m, c, lut_bits]
138
+ if vec_sz == 1:
139
+
140
+ # group 4 bytes => 1 uint32
141
+ # group 8 bits => 1 byte
142
+ bits_bool_8 = bits_bool.reshape(m, (c * lut_bits) // 8, 8) # => [m, c*lut_bits/8, 8]
143
+ uint_mask = (2 ** torch.arange(8, device=P.device, dtype=torch.int16)).view(1, 1, 8)
144
+ packed_8 = (bits_bool_8.to(torch.int16) * uint_mask).sum(dim=-1).to(torch.uint8) # => [m, (c*lut_bits)//8]
145
+
146
+ mat_packed = packed_8.reshape(N // td_x // 2, 2, K // td_y // 2, 2, td_x * td_y // 8, lut_bits) \
147
+ .permute(0, 2, 4, 3, 1, 5).contiguous().flatten().view(torch.uint32)\
148
+ .reshape((N * K) // (td_x * td_y), (td_x * td_y * lut_bits) // (32 * vec_sz))
149
+ elif vec_sz == 2:
150
+ # group 8 bits => 1 byte
151
+ bits_bool_4 = bits_bool.reshape(m, (c * lut_bits) // 4, 4) # => [m, c*nbits/8, 8]
152
+ uint_mask = (2 ** torch.arange(4, device=bits_bool_4.device, dtype=torch.int16)).view(1, 1, 4)
153
+ packed_4 = (bits_bool_4.to(torch.int16) * uint_mask).sum(dim=-1).to(torch.uint8) # => [m, (c*nbits)//8]
154
+
155
+ mat_packed_48 = packed_4.reshape(N // td_x // 2, 2, K // td_y // 2, 2, td_x * td_y // 8, lut_bits) \
156
+ .permute(0, 2, 4, 3, 1, 5).contiguous().flatten()
157
+ # uint 4 packed in uint 8 to uint 32
158
+ packing_mask = torch.Tensor([1, 2 ** 4]).to(torch.int8).view(1,2).cuda()
159
+ mat_packed8 = (mat_packed_48.reshape(-1, 2) * packing_mask).sum(dim=-1).to(torch.uint8).contiguous().flatten()
160
+
161
+ mat_packed = mat_packed8.view(torch.uint32).reshape((N * K) // (td_x * td_y), (td_x * td_y * lut_bits) // (32 * vec_sz))
162
+ return mat_packed.view(N, -1)
163
+
164
def load_hessian(in_hess_path, sigma_reg=0.01):
    '''
    Load a flattened Hessian checkpoint, fold in the mean term, regularize.

    in_hess_path: path to a torch checkpoint with keys 'flatH', 'n',
        and optionally 'mu'
    sigma_reg: regularization strength passed to utils.regularize_H

    return: (n, n, 1) float64 tensor
    '''
    # NOTE(review): torch.load unpickles arbitrary objects; only use with
    # trusted checkpoint files.
    H_data = torch.load(in_hess_path, map_location=torch.device('cpu'))
    H = utils.flat_to_sym(H_data['flatH'], H_data['n'])
    if 'mu' in H_data:
        mu = H_data['mu']
        # add back the outer product of the stored mean (presumably restoring
        # the uncentered second moment — confirm against the Hessian collector)
        H += mu[None, :] * mu[:, None]
        del mu
    del H_data
    H = utils.regularize_H(H, sigma_reg)
    assert len(H.shape) == 2 and H.shape[0] == H.shape[1], "H must be a square matrix"
    return H.to(torch.float64).unsqueeze(-1)
175
+
176
def load_group_hessian(in_hess_path, sigma_reg=0.01, layer_key=None):
    '''
    Load a grouped Hessian checkpoint and regularize every slice in place.

    in_hess_path: checkpoint holding an (n, n, groups) tensor under layer_key
    sigma_reg: regularization strength passed to utils.regularize_H
    layer_key: key selecting this layer's Hessian within the checkpoint

    return: (n, n, groups) float64 tensor
    '''
    H_data = torch.load(in_hess_path, map_location=torch.device('cpu'))
    H = H_data[layer_key]
    for i in range(H.shape[-1]):
        H[:, :, i] = utils.regularize_H(H[:, :, i], sigma_reg)
    assert len(H.shape) == 3 and H.shape[0] == H.shape[1], "H must be a square matrix"
    return H.to(torch.float64)
183
+
184
# deprecated func for dequantization
@torch.compile
def dequantize_mat_sq(mat_packed, lut, N, K, nbits, td_x=16, td_y=16):
    '''
    Inverse of the vec_sz=1 tensor-core packing: unpack bit-planes back to
    indices, undo the tile permutation, and look values up in `lut`.

    mat_packed: packed weights from pack_qweight (vec_sz=1)
    lut: lookup table indexed by the nbits codes
    N, K: original matrix shape
    nbits: bits per code
    return: (N, K) reconstructed matrix
    '''
    # undo the (2, 2, ...) tile interleaving applied at pack time
    packed = mat_packed.flatten().view(torch.uint8).reshape(N // td_x // 2,
                                                            K // td_y // 2,
                                                            td_x * td_y // 8,
                                                            2, 2, nbits)
    packed_8 = packed.permute(0, 4, 1, 3, 2, 5).contiguous().reshape(N * K // (td_x * td_y), (td_x * td_y) * nbits // 8)
    # expand every byte into its 8 bits, then regroup the bit-planes per code
    bits_mask = (2 ** torch.arange(8, device=mat_packed.device, dtype=torch.int16)).view(1, 1, 8)
    bits_bool_8 = (packed_8.unsqueeze(-1) & bits_mask) > 0
    bits_bool = bits_bool_8.reshape(N * K // (td_x * td_y), (td_x * td_y), nbits)
    K_mask = 2 ** torch.arange(nbits, device=mat_packed.device).view(1, 1, -1)
    indices = (bits_bool * K_mask).sum(dim=-1)
    recon = lut[indices.long()].reshape(N * K // (td_x * td_y), td_x * td_y)
    # undo the intra-tile permutation, then the tiling itself
    recon = recon.index_select(dim=1, index=_INV_PERMUTE.to(mat_packed.device))
    recon = recon.reshape(N // td_x, K // td_y, td_x, td_y)
    recon = recon.permute(0, 2, 1, 3).reshape(N, K)
    return recon
202
+
203
@torch.compile
def dequantize_mat_sq_inds(mat_packed, N, K, nbits, td_x=16, td_y=16):
    '''
    Like dequantize_mat_sq, but return the raw (N, K) code indices instead
    of reconstructed values (no LUT lookup).
    '''
    # undo the (2, 2, ...) tile interleaving applied at pack time
    packed = mat_packed.flatten().view(torch.uint8).reshape(N // td_x // 2,
                                                            K // td_y // 2,
                                                            td_x * td_y // 8,
                                                            2, 2, nbits)
    packed_8 = packed.permute(0, 4, 1, 3, 2, 5).contiguous().reshape(N * K // (td_x * td_y), (td_x * td_y) * nbits // 8)
    # bytes -> bits -> per-code bit-planes -> integer codes
    bits_mask = (2 ** torch.arange(8, device=mat_packed.device, dtype=torch.int16)).view(1, 1, 8)
    bits_bool_8 = (packed_8.unsqueeze(-1) & bits_mask) > 0
    bits_bool = bits_bool_8.reshape(N * K // (td_x * td_y), (td_x * td_y), nbits)
    K_mask = 2 ** torch.arange(nbits, device=mat_packed.device).view(1, 1, -1)
    indices = (bits_bool * K_mask).sum(dim=-1)
    # undo the intra-tile permutation, then the tiling itself
    indices = indices.reshape(N * K // (td_x * td_y), td_x * td_y).index_select(dim=1, index=_INV_PERMUTE.to(mat_packed.device))
    indices = indices.reshape(N // td_x, K // td_y, td_x, td_y)
    indices = indices.permute(0, 2, 1, 3).reshape(N, K).contiguous()
    return indices
219
+
220
@torch.compile
def dequantize_mat_sq_inds_vec2(mat_packed: torch.Tensor,
                                N: int,
                                K: int,
                                lut_bits: int,
                                td_x: int = 16,
                                td_y: int = 16) -> torch.Tensor:
    '''
    Inverse of the vec_sz=2 tensor-core packing: recover the per-vector
    code indices. Returns an (N, K // 2) index tensor (one code per
    2-column vector).
    '''
    # split each byte into its two nibbles (low nibble, high nibble)
    mat_packed8 = mat_packed.view(torch.uint8).flatten().unsqueeze(-1).expand(-1, 2).contiguous()

    mat_packed48 = torch.zeros_like(mat_packed8)
    mat_packed48[:, 0] = mat_packed8[:, 0] & 0b1111
    mat_packed48[:, 1] = mat_packed8[:, 1] >> 4

    # undo the (2, 2, ...) tile interleaving applied at pack time
    packed_4 = mat_packed48.reshape(N // td_x // 2, K // td_y // 2, td_x * td_y // 8, 2, 2, lut_bits).permute(0,4,1,3,2,5).reshape(N * K // (td_x * td_y), -1).contiguous()
    # nibbles -> bits -> per-code bit-planes -> integer codes
    bits_mask = (2 ** torch.arange(4, device=mat_packed.device, dtype=torch.int16)).view(1, 1, 4)
    bits_bool_4 = (packed_4.unsqueeze(-1) & bits_mask) > 0
    bits_bool = bits_bool_4.reshape(N * K // (td_x * td_y), td_x * td_y // 2, lut_bits)
    K_mask = 2 ** torch.arange(lut_bits, device=mat_packed.device).view(1, 1, -1)
    indices = (bits_bool * K_mask).sum(dim=-1)
    # re-duplicate each code across its vector pair so the scalar inverse
    # permutation can be applied, then collapse the pair back to one code
    indices = indices.reshape(N * K // (td_x * td_y), td_x * td_y // 2, 1).expand(-1, -1, 2).reshape(N * K // (td_x * td_y), td_x * td_y).contiguous()
    indices = indices.index_select(dim=1, index=_INV_PERMUTE.to(mat_packed.device))
    indices = indices.reshape(N // td_x, K // td_y, td_x, td_y)
    indices = indices.permute(0, 2, 1, 3).reshape(N, K // 2, 2)
    indices = indices[:, :, 0].contiguous()
    return indices
245
+
246
def convert_tensor_core_to_simt(mat_packed, N, K, vec_sz, lut_bit, code_n, codeT_sz=32, td_x=16, td_y=16):
    '''
    Re-pack a tensor-core-layout weight tensor into the SIMT kernel layout.

    The unpack/repack runs on the GPU; the result is returned on the
    input tensor's original device.
    '''
    orig_device = mat_packed.device
    mat_packed = mat_packed.cuda()
    if vec_sz == 2:
        inds = dequantize_mat_sq_inds_vec2(mat_packed, N, K, lut_bit, td_x, td_y)
        repacked = pack_qweight_vq_simt(inds, lut_bit, vec_sz, code_n, codeT_sz)
    else:
        inds = dequantize_mat_sq_inds(mat_packed, N, K, lut_bit, td_x, td_y)
        repacked = pack_qweight_sq_simt(inds, lut_bit)
    return repacked.to(orig_device).contiguous()
258
+
259
+
260
+
261
if __name__ == "__main__":
    # Smoke test: round-trip random indices through the tensor-core packer,
    # the index dequantizers, and the SIMT conversion (requires a CUDA device).
    nbits = 5
    Qidxs = torch.randint(0, 2**nbits, (11008, 4096)).cuda()
    packed = pack_qweight(Qidxs, 1, nbits)

    indices = dequantize_mat_sq_inds(packed, 11008, 4096, nbits)

    converted = convert_tensor_core_to_simt(packed, 11008, 4096, 1, nbits, code_n=nbits)

    # vec_sz=2: each code covers 2 columns, so (11008, 2048) codes => K=4096
    Qidxs2 = torch.randint(0, 2**nbits, (11008, 2048)).cuda()
    packed2 = pack_qweight(Qidxs2, 2, nbits)

    indices2 = dequantize_mat_sq_inds_vec2(packed2, 11008, 4096, nbits)

    converted2 = convert_tensor_core_to_simt(packed2, 11008, 4096, 2, nbits, code_n=nbits)

    # NOTE(review): interactive debugger breakpoint left in; remove before
    # running unattended.
    import ipdb; ipdb.set_trace()
lib/quantizer/tcq_quant.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from lib.utils import block_LDL, matmul_hadUt, matmul_hadUt_head
3
+ from lib.algo.ldlq import LDLQ
4
+ from lib.quantizer.quant_op import load_hessian, load_group_hessian
5
+ from lib.linear import QTIPLinearTCQ, IncoherentLinear
6
+ from lib.codebook.bitshift import bitshift_codebook
7
+ import torch._dynamo
8
+ import time
9
class Args:
    '''Lightweight parameter container consumed by LDLQ (tile dims + vector size).'''

    def __init__(self, td_x, td_y, V):
        self.td_x, self.td_y, self.V = td_x, td_y, V
15
def qtip_quantize_mat(Wr, HRr, Wscale, cb, td_x=16, td_y=16, KV=4, V=2, use_hess=True):
    '''
    Quantize a (scaled, incoherent) weight matrix with LDLQ over a trellis
    codebook and pack the resulting codes for the TCQ kernel.

    Wr: (m, n) weight matrix (already divided by Wscale)
    HRr: (n, n, gs) Hessian slices; row-block i of Wr uses slice i
    Wscale: per-row scale, multiplied back in before the error report
    cb: trellis codebook (provides pack_trellis and runs on CUDA)
    td_x, td_y: tile dims; KV, V: trellis packing parameters
    use_hess: if False, quantize against the identity instead of HRr

    return: (packed int16 trellis tensor, hatWr with scales re-applied,
             quant_info dict)
    '''
    # NOTE(review): HRr_orig is captured but never used below.
    HRr_orig = HRr.clone()
    Wr = Wr.to(torch.float64)
    (m, n) = Wr.shape
    gs = HRr.shape[-1]
    LRrs = []
    diag = torch.arange(n, device=HRr.device)
    if not use_hess:
        # identity Hessian => plain rounding behavior inside LDLQ
        eye = torch.eye(n, device=Wr.device, dtype=torch.float64)
        LRr, D = block_LDL(eye, td_y)
        LRr[diag, diag] = 0
        LRrs.append(LRr)
    else:
        # one LDL factor per Hessian group; diagonal zeroed for the LDLQ update
        for i in range(gs):
            LRr, D = block_LDL(HRr[:,:,i], td_y)
            LRr[diag, diag] = 0
            LRrs.append(LRr)

    args = Args(td_x, td_y, V)

    # quantize each row block against its own LDL factor
    Qidxs_list = []
    hatWr_list = []
    for i in range(gs):
        cur_Wr = Wr[m // gs * i:m // gs * (i+1)]
        hatWr, Qidxs = LDLQ(cur_Wr, LRrs[i], cb.cuda(), args, for_kernel=True)
        hatWr_list.append(hatWr)
        Qidxs_list.append(Qidxs)
    hatWr = torch.cat(hatWr_list, dim=0)
    Qidxs = torch.cat(Qidxs_list, dim=0)
    assert hatWr.shape == Wr.shape, f"hatWr.shape {hatWr.shape} != Wr.shape {Wr.shape}"

    Qidxs = Qidxs.cpu()
    # tile the trellis indices and pack them into the codebook's bit stream
    packed = cb.pack_trellis(
        Qidxs.reshape(m // td_x, td_x, n // td_y,
                      td_y // V).transpose(1, 2).reshape(
                          -1, td_x * td_y // V))

    # repack: int16 words -> bytes -> nibbles, re-interleave per 2x2 tile
    # group, then fuse nibbles back into the kernel's int16 layout
    packed_8 = packed.view(torch.uint8).view(-1, 2)
    packed_4 = torch.cat([packed_8.unsqueeze(-1) & (2 ** 4 - 1), (packed_8.unsqueeze(-1) & (2 ** 8 - 2 ** 4)) >> 4], dim=-1).view(-1, 4).flip(
        (-1, ))

    packed_4 = packed_4.reshape(m // 16 // 2, 2, n // 16 // 2, 2, 16 * 16 // 8,
                                KV).permute(0, 2, 4, 3, 1, 5).flip(
                                    (-1, )).contiguous().flatten()
    packed_8 = torch.sum(packed_4.view(-1, 2) * torch.Tensor([[1, 2 ** 4]]).to(torch.uint8), dim=-1).to(torch.uint8).contiguous()
    packed = packed_8.view(torch.int16).reshape(packed.shape).cuda()

    # restore the original magnitudes before reporting the error
    Wr *= Wscale.reshape(-1, 1)
    hatWr *= Wscale.reshape(-1, 1)

    orig_err = (Wr - hatWr).pow(2).mean()
    err = orig_err / Wr.pow(2).mean()
    print(
        f'err {err.item()} orig_err {orig_err.item()}'
    )
    quant_info = {
        "quantizer": "tcq_ldlq",
        "td_x": td_x,
        "td_y": td_y,
        "KV": KV,
        "V": V,
        "use_hess": use_hess,
        "orig_err": orig_err.item(),
        "err": err.item(),
    }
    return packed, hatWr, quant_info
81
+
82
def inc_linear_to_inc_tcq_linear(inc_linear, HRr, cb, td_x=16, td_y=16, KV=4, V=2, scale_override=0.9, use_hess=True):
    '''
    Replace the dense inner linear of an IncoherentLinear with a quantized
    QTIPLinearTCQ built from its weights.

    inc_linear: IncoherentLinear whose .linear holds the weights to quantize
    HRr: rotated Hessian slices for qtip_quantize_mat
    cb: trellis codebook
    td_x, td_y, KV, V: tile / trellis parameters
    scale_override: extra scale folded into the weights and out of Wscale
    use_hess: forwarded to qtip_quantize_mat

    return: (inc_linear with .linear swapped, quant_info dict)
    '''
    # fold the override into the weights and compensate in Wscale
    Wr = inc_linear.linear.weight.data * scale_override
    Wscale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(Wscale)

    packed, hatWr, quant_info = qtip_quantize_mat(Wr, HRr, Wscale, cb, td_x=td_x, td_y=td_y, KV=KV, V=V, use_hess=use_hess)
    out_features, in_features = Wr.shape
    tcq_linear = QTIPLinearTCQ(
        in_features,
        out_features,
        # Fix: forward the td_x/td_y parameters instead of hard-coded 16s so
        # non-default tile sizes reach the kernel module (defaults unchanged).
        td_x=td_x,
        td_y=td_y,
        L=16,
        KV=KV,
        V=V,
        tlut_bits=cb.tlut_bits,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )

    tcq_linear.trellis.data.copy_(packed)
    tcq_linear.tlut.data.copy_(cb.tlut)

    inc_linear.linear = tcq_linear
    return inc_linear, quant_info
107
+
108
def linear_to_incoherent_for_tcq(linear, cb, HR, scale_override=0.9, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False):
    '''
    Wrap a dense nn.Linear in an IncoherentLinear: apply random sign flips
    plus Hadamard rotations to the weights and Hessian, and normalize the
    weights to the codebook's scale.

    linear: source nn.Linear
    cb: codebook (its lut sets the target RMS scale)
    HR: (n, n, groups) Hessian to rotate alongside the weights
    scale_override: shrinks the target scale (weights end up slightly larger)
    SU, SV: optional +/-1 sign vectors; drawn at random when None
    lnorm: accepted for interface compatibility  # NOTE(review): unused here
    hadU, hadV: Hadamard sizes/handles for the two sides
    rot_info: rotation mode stored on the IncoherentLinear
    left_only: rotate only the input side (SV forced to all-ones)

    return: (IncoherentLinear with rotated weights, rotated Hessian HRr)
    '''
    dtype_ = torch.float32
    device = linear.weight.device
    inc_linear = IncoherentLinear(linear.in_features, linear.out_features, hadU, hadV, linear.bias is not None, dtype_)
    # random +/-1 sign vectors for incoherence processing
    if SU is None:
        SU = ((torch.randn(linear.in_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)
    if SV is None:
        SV = ((torch.randn(linear.out_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)

    if left_only:
        SV = torch.ones_like(SV)

    if linear.bias is not None:
        inc_linear.bias.data.copy_(linear.bias)

    # rotate the weights: both sides, or only the input side for left_only
    W = linear.weight.data.clone().to(dtype_)
    Wr = matmul_hadUt_head(matmul_hadUt_head(W.T.to(device) * SV, hadV).T * SU, hadU) if not left_only else matmul_hadUt_head(W * SU, hadU)

    # scale so the weights match the codebook RMS (per-row when left_only,
    # a single global scale broadcast to rows otherwise)
    if left_only:
        Wscale = Wr.to(torch.float64).square().mean(-1).sqrt().view(-1, 1).to(dtype_) / (cb.lut.to(torch.float64).square().mean().sqrt().float() * scale_override) # (out_features, 1)
    else:
        Wscale = Wr.to(torch.float64).square().mean().sqrt().view(-1, 1).to(dtype_) / (cb.lut.to(torch.float64).square().mean().sqrt().float() * scale_override) # (1, 1)
        Wscale = Wscale.repeat(Wr.shape[0], 1) # (out_features, 1)

    Wr = Wr / Wscale
    # rotate every Hessian slice with the inverse sign vector on both sides
    HRr = torch.zeros_like(HR)
    for i in range(HR.shape[-1]):
        HRr[:,:,i] = matmul_hadUt_head(matmul_hadUt_head(HR[:,:,i].to(device).contiguous() * (1./ SU), hadU).T * (1./ SU), hadU)

    # the stored vectors are the inverses used at inference time
    inc_linear.SU.data.copy_(1./SU.to(dtype_))
    inc_linear.SV.data.copy_(1./SV.to(dtype_))
    inc_linear.Wscale.data.copy_(Wscale.view(-1))
    inc_linear.linear.weight.data.copy_(Wr.to(dtype_))
    inc_linear.rot_info = rot_info
    inc_linear.apply_rot_info()
    return inc_linear, HRr
144
+
145
def linear_to_tcq_linear(target_layer, hess_path, cb, scale_override=0.9, KV=4, V=2, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    '''
    End-to-end pipeline: load the Hessian, build an incoherent wrapper
    around target_layer, and quantize it into a TCQ linear.

    target_layer: dense nn.Linear to replace
    hess_path: Hessian checkpoint path (None => identity Hessian)
    cb: trellis codebook
    ghess_key: when non-empty, load a grouped Hessian under this key
    remaining args are forwarded to the incoherence / quantization steps

    return: (quantized layer cast to fp16, quant_info dict with timing)
    '''
    t0 = time.time()
    out_features, in_features = target_layer.weight.shape
    if ghess_key == "":
        # fall back to the identity when no Hessian file is given
        HR = load_hessian(hess_path).cuda() if hess_path is not None else torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)
    else:
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()
    layer, HRr = linear_to_incoherent_for_tcq(target_layer, cb, HR, scale_override, SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV, rot_info=rot_info, left_only=left_only)
    HRr = HRr.cuda()
    layer = layer.cuda()
    # scale_override was already applied above, so pass 1.0 here
    layer, quant_info = inc_linear_to_inc_tcq_linear(layer, HRr, cb, scale_override=1.0, td_x=16, td_y=16, KV=KV, V=V, use_hess=use_hess)
    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    quant_info["time"] = time.time() - t0

    return layer.to(torch.float16), quant_info
lib/quantizer/vq_quant.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from lib.quantizer.quant_op import _INV_PERMUTE, load_hessian, load_group_hessian
3
+ from lib.linear import IncoherentLinear
4
+ from lib.utils import matmul_hadUt, matmul_hadUt_head, clean
5
+ from lib.utils.kmeans import kmeans_sklearn, kmeans_flash1d
6
+ from lib.quantizer.nuq_op import train_least_squares
7
+ from lib.quantizer.quant_op import pack_qweight
8
+ from lib.linear import VQLinearPackTensorCore
9
+
10
+ import random
11
+ import time
12
def simple_vq(Wr, vec_sz, lut_bits, batch_size=256):
    """Build an initial VQ assignment for ``Wr`` via k-means.

    Args:
        Wr: (rows, cols) weight matrix; ``cols`` must be divisible by
            ``vec_sz``.
        vec_sz: dimensionality of each codebook vector.
        lut_bits: the codebook has ``2 ** lut_bits`` entries.
        batch_size: rows assigned per chunk; bounds the size of the
            pairwise-distance tensor.

    Returns:
        init_P: (rows, cols // vec_sz, 2**lut_bits) uint8 one-hot
            assignment tensor on ``Wr``'s device.
        init_centroids: (2**lut_bits, vec_sz) float32 codebook.
    """
    # Large codebooks blow up the (batch, cols, 2**lut_bits) distance
    # tensor below, so cap the batch size.
    if lut_bits >= 12:
        batch_size = 64
    num_centroids = 2 ** lut_bits

    # K-means over all vectors; the scalar case has a fast exact 1-D solver.
    Wr_flatten = Wr.reshape(-1, vec_sz)
    if vec_sz == 1:
        init_centroids = kmeans_flash1d(Wr_flatten, num_centroids)
    else:
        init_centroids = kmeans_sklearn(Wr_flatten, num_centroids)

    Wr_vec = Wr.reshape(Wr.shape[0], -1, vec_sz)  # (rows, cols // vec_sz, vec_sz)

    # Nearest-centroid assignment, batched over rows to bound peak memory.
    # Allocate index/one-hot tensors directly with the target dtype and
    # device (the original zeros(...).to(device).long() built a float32
    # CPU tensor first, then transferred and cast it).
    min_indices = torch.zeros(
        Wr.shape[0], Wr.shape[1] // vec_sz, dtype=torch.long, device=Wr.device
    )
    for s_idx in range(0, Wr.shape[0], batch_size):
        e_idx = min(s_idx + batch_size, Wr.shape[0])
        dist_sq = ((Wr_vec[s_idx:e_idx].unsqueeze(2)
                    - init_centroids.unsqueeze(0).unsqueeze(0)) ** 2).sum(dim=-1)
        min_indices[s_idx:e_idx] = dist_sq.argmin(dim=-1)  # (batch, cols // vec_sz)

    # One-hot encode the assignments; uint8 keeps the tensor compact.
    init_P = torch.zeros(
        Wr.shape[0], Wr.shape[1] // vec_sz, num_centroids,
        dtype=torch.uint8, device=Wr.device,
    )
    init_P.scatter_(2, min_indices.unsqueeze(-1), 1)

    return init_P, init_centroids.to(torch.float32)
34
+
35
def vq_quantize_mat(Wr, HRr, Wscale, vec_sz, lut_bits, iterations=6, use_hess=True):
    """Vector-quantize the weight matrix ``Wr`` with a 2**lut_bits-entry LUT.

    Seeds assignments/codebook with k-means (``simple_vq``); when
    ``use_hess`` is set, refines both with Hessian-weighted alternating
    least squares (``train_least_squares``). Errors are reported in the
    unscaled domain (after multiplying back by ``Wscale``).

    Returns:
        (packed qweight tensor, codebook C, dequantized hatWr, quant_info dict).
    """
    Wr = Wr.to(torch.float64)
    if use_hess:
        assert len(HRr.shape) == 3, "HRr must be a 3D tensor"
        assert HRr.shape[0] == HRr.shape[1], "HRr must be a square matrix"
        init_P, init_centroids = simple_vq(Wr, vec_sz, lut_bits)
        # train_least_squares runs on CPU/numpy; Hessian is moved to
        # leading-group layout (groups, n, n).
        P, C, log_dict = train_least_squares(
            W=Wr.detach().cpu().numpy(),
            init_P=init_P.detach().cpu().to(torch.float32).numpy(),
            init_centroids=init_centroids.detach().cpu().numpy(),
            H=HRr.permute(2, 0, 1).detach().cpu().numpy(),
            num_iterations=iterations,
        )
    else:
        # No Hessian refinement: keep the raw k-means assignment.
        init_P, init_centroids = simple_vq(Wr, vec_sz, lut_bits)
        P, C = init_P, init_centroids
    # NOTE(review): `.to(Wr.device)` implies P and C are torch tensors in
    # both branches — presumably train_least_squares returns tensors, not
    # numpy arrays; confirm against its implementation.
    P = P.to(Wr.device)
    C = C.to(Wr.device)

    # Dequantize: hard-assign each vector to its argmax codeword.
    P_ind = torch.argmax(P, dim=-1)
    hatWr = C[P_ind]
    hatWr = hatWr.view(hatWr.shape[0], -1)

    # Undo the per-row normalization so errors are in the original scale.
    Wr *= Wscale.view(-1, 1)
    hatWr *= Wscale.view(-1, 1)

    orig_err = (Wr - hatWr).pow(2).mean()
    err = (Wr - hatWr).pow(2).mean() / (Wr.pow(2).mean())  # relative MSE
    print(
        f'err {err.item()} orig_err {orig_err.item()}'
    )
    quant_info = {
        "quantizer": "vq_lnq",
        "vec_sz": vec_sz,
        "lut_bits": lut_bits,
        "use_hess": use_hess,
        "iterations": iterations,
        "orig_err": orig_err.item(),
        "err": err.item(),
    }

    # pack P appropriately to kernel
    packed = pack_qweight(P, vec_sz, lut_bits)
    return packed, C, hatWr, quant_info
79
+
80
def inc_linear_to_inc_vq_linear(inc_linear, HRr, lut_bits=4, vec_sz=2, scale_override=0.9, use_hess=True):
    """Replace the dense inner linear of ``inc_linear`` with a packed VQ linear.

    The weights are pre-scaled by ``scale_override`` (and the stored per-row
    scale compensated accordingly), quantized with ``vq_quantize_mat``, and
    the result installed as a ``VQLinearPackTensorCore`` module.

    Returns:
        (the mutated inc_linear, quant_info dict).
    """
    # Fold scale_override into the weights and compensate Wscale so the
    # overall layer output is unchanged.
    scaled_W = inc_linear.linear.weight.data * scale_override
    row_scale = inc_linear.Wscale.data / scale_override
    inc_linear.Wscale.data.copy_(row_scale)

    packed, lut, _hatWr, quant_info = vq_quantize_mat(
        scaled_W, HRr, row_scale, vec_sz, lut_bits, use_hess=use_hess
    )

    # Build the packed tensor-core linear and load the quantized payload.
    n_out, n_in = scaled_W.shape
    vq_layer = VQLinearPackTensorCore(
        n_in,
        n_out,
        lut_bits=lut_bits,
        vec_sz=vec_sz,
        bias=inc_linear.bias is not None,
        dtype=inc_linear.dtype,
    )
    vq_layer.qweight.data.copy_(packed)
    vq_layer.lut.data.copy_(lut.view(2 ** lut_bits, vec_sz))

    inc_linear.linear = vq_layer
    return inc_linear, quant_info
100
+
101
def linear_to_incoherent_for_vq(linear, HR, scale_override=0.9, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False):
    """Rotate ``linear`` into an incoherent basis ahead of VQ quantization.

    Applies random-sign flips (SU/SV) and Hadamard transforms to the weight
    matrix and the Hessian, normalizes each weight row, and stores the
    inverse transforms in a fresh ``IncoherentLinear``.

    Args:
        linear: source dense layer.
        HR: (in, in, groups) Hessian stack to rotate alongside the weights.
        scale_override: divides the per-row RMS scale (so the stored
            normalized weights are multiplied by it).
        SU / SV: optional precomputed sign vectors; random +/-1 otherwise.
        lnorm: accepted but unused here — presumably consumed by callers
            or sibling variants; verify before removing.
        left_only: if True, skip the output-side (SV/hadV) rotation.

    Returns:
        (inc_linear with rotated weights, rotated Hessian HRr).
    """
    dtype_ = torch.float32
    device = linear.weight.device
    inc_linear = IncoherentLinear(linear.in_features, linear.out_features, hadU, hadV, linear.bias is not None, dtype_)
    # Random +/-1 sign vectors for the incoherence transform.
    if SU is None:
        SU = ((torch.randn(linear.in_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)
    if SV is None:
        SV = ((torch.randn(linear.out_features, dtype=dtype_) > 0.0) * 2.0 - 1.0).to(device).to(dtype_)

    if left_only:
        # Output side untouched: neutralize SV.
        SV = torch.ones_like(SV)

    if linear.bias is not None:
        inc_linear.bias.data.copy_(linear.bias)

    W = linear.weight.data.to(dtype_)
    # Two-sided (or left-only) sign + Hadamard rotation of the weights.
    Wr = matmul_hadUt_head(matmul_hadUt_head(W.T.to(device) * SV, hadV).T * SU, hadU) if not left_only else matmul_hadUt_head(W * SU, hadU)
    # Per-row RMS scale in float64 for accuracy; scale_override < 1 leaves
    # headroom in the normalized weights.
    Wscale = Wr.to(torch.float64).square().mean(-1).sqrt().view(-1, 1).to(dtype_) / scale_override

    Wr = Wr / Wscale
    # Rotate each Hessian group with the same input-side transform
    # (conjugation by the sign-flipped Hadamard).
    HRr = torch.zeros_like(HR)
    for i in range(HR.shape[-1]):
        HRr[:,:,i] = matmul_hadUt_head(matmul_hadUt_head(HR[:,:,i].to(device).contiguous() * (1./ SU), hadU).T * (1./ SU), hadU)

    # Store the inverse sign vectors (signs are their own inverse up to
    # reciprocal) and the row scales for dequantization at runtime.
    inc_linear.SU.data.copy_(1./SU.to(dtype_))
    inc_linear.SV.data.copy_((1./SV).to(dtype_))
    inc_linear.Wscale.data.copy_(Wscale.view(-1))
    inc_linear.linear.weight.data.copy_(Wr.to(dtype_))
    inc_linear.rot_info = rot_info
    inc_linear.apply_rot_info()
    return inc_linear, HRr
132
+
133
def linear_to_vq_linear(target_layer, hess_path, scale_override=0.9, lut_bits=4, vec_sz=1, use_hess=True, SU=None, SV=None, lnorm=None, hadU=None, hadV=None, rot_info="all", left_only=False, ghess_key=""):
    """Quantize a dense linear layer into an incoherent VQ linear layer.

    Loads the (group) Hessian, rotates the layer into the incoherent basis,
    then vector-quantizes the rotated weights.

    Returns:
        (quantized fp16 layer, dict of quantization metadata incl. timing).
    """
    t0 = time.time()
    out_features, in_features = target_layer.weight.shape

    # Hessian source: per-layer file, grouped file keyed by ghess_key, or
    # an identity proxy when none is available.
    if ghess_key == "":
        if hess_path is not None:
            HR = load_hessian(hess_path).cuda()
        else:
            HR = torch.eye(in_features, device="cuda", dtype=torch.float64).unsqueeze(-1)
    else:
        HR = load_group_hessian(hess_path, layer_key=ghess_key).cuda()

    layer, HRr = linear_to_incoherent_for_vq(
        target_layer, HR, scale_override,
        SU=SU, SV=SV, lnorm=lnorm, hadU=hadU, hadV=hadV,
        rot_info=rot_info, left_only=left_only,
    )
    layer, quant_info = inc_linear_to_inc_vq_linear(
        layer.cuda(), HRr.cuda(),
        scale_override=1.0, lut_bits=lut_bits, vec_sz=vec_sz, use_hess=use_hess,
    )

    quant_info["scale_override"] = scale_override
    quant_info["hess_path"] = hess_path
    quant_info["time"] = time.time() - t0
    print("elapsed time", time.time() - t0)
    return layer.to(torch.float16), quant_info