drain_module_parameters
- optimization.py +6 -10
- optimization_utils.py +9 -0
optimization.py CHANGED

@@ -12,8 +12,9 @@ from torchao.quantization import quantize_
 from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
 from torchao.quantization import Int8WeightOnlyConfig
 
-from optimization_utils import capture_component_call
 from optimization_utils import aoti_compile
+from optimization_utils import capture_component_call
+from optimization_utils import drain_module_parameters
 from optimization_utils import ZeroGPUCompiledModel
 
 
@@ -105,13 +106,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         else:
             return cp2(*args, **kwargs)
 
-
-
-
-    pipeline.transformer = combined_transformer_1
-    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
-    pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
+    pipeline.transformer.forward = combined_transformer_1
+    drain_module_parameters(pipeline.transformer)
 
-    pipeline.transformer_2 = combined_transformer_2
-    pipeline.transformer_2.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
-    pipeline.transformer_2.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]
+    pipeline.transformer_2.forward = combined_transformer_2
+    drain_module_parameters(pipeline.transformer_2)
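In optimization.py, the compiled callables are now patched in as the modules' forward methods instead of replacing pipeline.transformer / pipeline.transformer_2 outright, which removes the need to re-attach config and dtype by hand; drain_module_parameters then empties the parameters that the compiled callables no longer read. A minimal sketch of that pattern, with a toy Linear standing in for pipeline.transformer and a plain closure standing in for the ZeroGPU-compiled combined_transformer_1/2 (both stand-ins are assumptions, not part of the diff):

import torch
from optimization_utils import drain_module_parameters

module = torch.nn.Linear(8, 8)                 # stand-in for pipeline.transformer
x = torch.randn(2, 8)
reference = module(x)                          # original forward, original weights

# Stand-in for combined_transformer_1/2: it carries its own copy of the weights,
# much as the AoT-compiled artifact does, so the module's parameters go unused.
weight, bias = module.weight.detach().clone(), module.bias.detach().clone()

def compiled_forward(inp: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.linear(inp, weight, bias)

module.forward = compiled_forward              # nn.Module.__call__ now dispatches here
drain_module_parameters(module)                # drop the now-unused parameters

assert torch.allclose(module(x), reference)    # behaviour is unchanged
print(module.weight.shape)                     # torch.Size([0])

Because nn.Module.__call__ dispatches through the instance attribute forward, callers of pipeline.transformer(...) keep working unchanged, and the module's existing config and dtype attributes stay in place without the pyright-ignored re-assignments.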
optimization_utils.py CHANGED

@@ -96,3 +96,12 @@ def capture_component_call(
     except CapturedCallException as e:
         captured_call.args = e.args
         captured_call.kwargs = e.kwargs
+
+
+def drain_module_parameters(module: torch.nn.Module):
+    state_dict_meta = {name: tensor.to('meta') for name, tensor in module.state_dict().items()}
+    state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
+    module.load_state_dict(state_dict, assign=True)
+    for name, param in state_dict.items():
+        meta = state_dict_meta[name]
+        param.data = torch.Tensor([]).to(device=meta.device, dtype=meta.dtype)