yuekai committed
Commit efd7691 · verified · 1 Parent(s): 7c20f1f

Upload folder using huggingface_hub

export_onnx.py ADDED
@@ -0,0 +1,125 @@
+ # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+
+ import argparse
+ import logging
+ # Silence matplotlib's DEBUG chatter before any downstream import pulls it in.
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ import random
+
+ import onnxruntime
+ import torch
+ from hyperpyyaml import load_hyperpyyaml
+ from tqdm import tqdm
+
+
+ def get_dummy_input(batch_size, seq_len, out_channels, device):
+     """Build random tensors matching the estimator's input signature."""
+     x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+     mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+     mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+     t = torch.rand((batch_size,), dtype=torch.float32, device=device)
+     spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+     cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+     return x, mask, mu, t, spks, cond
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='export the flow decoder estimator to ONNX for deployment')
+     parser.add_argument('--model_dir',
+                         type=str,
+                         default='Step-Audio-2-mini/token2wav',
+                         help='local model directory')
+     parser.add_argument('--onnx_model',
+                         type=str,
+                         default='flow.decoder.estimator.fp32.dynamic_batch.onnx',
+                         help='onnx model name')
+     args = parser.parse_args()
+     print(args)
+     return args
+
+
+ @torch.no_grad()
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+
+     with open(f"{args.model_dir}/flow.yaml", "r") as f:
+         configs = load_hyperpyyaml(f)
+     flow_model = configs['flow']
+
+     device = torch.device('cuda')
+
+     # 1. export flow decoder estimator
+     flow_model.load_state_dict(torch.load(f"{args.model_dir}/flow.pt", map_location="cpu", weights_only=True), strict=True)
+     estimator = flow_model.decoder.estimator
+     estimator.eval()
+     estimator.to(device)
+
+     batch_size, seq_len = 2, 256
+     out_channels = flow_model.decoder.estimator.out_channels
+     x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+     torch.onnx.export(
+         estimator,
+         (x, mask, mu, t, spks, cond),
+         f'{args.model_dir}/{args.onnx_model}',
+         export_params=True,
+         opset_version=18,
+         do_constant_folding=True,
+         input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+         output_names=['estimator_out'],
+         # keep batch and sequence dimensions dynamic so one graph serves
+         # arbitrary batch sizes and sequence lengths
+         dynamic_axes={
+             'x': {0: 'batch_size', 2: 'seq_len'},
+             'mask': {0: 'batch_size', 2: 'seq_len'},
+             'mu': {0: 'batch_size', 2: 'seq_len'},
+             'cond': {0: 'batch_size', 2: 'seq_len'},
+             't': {0: 'batch_size'},
+             'spks': {0: 'batch_size'},
+             'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+         }
+     )
+
+     # 2. test computation consistency
+     option = onnxruntime.SessionOptions()
+     option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+     option.intra_op_num_threads = 1
+     providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+     estimator_onnx = onnxruntime.InferenceSession(f'{args.model_dir}/{args.onnx_model}',
+                                                   sess_options=option, providers=providers)
+
+     for _ in tqdm(range(10)):
+         x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+         output_pytorch = estimator(x, mask, mu, t, spks, cond)
+         ort_inputs = {
+             'x': x.cpu().numpy(),
+             'mask': mask.cpu().numpy(),
+             'mu': mu.cpu().numpy(),
+             't': t.cpu().numpy(),
+             'spks': spks.cpu().numpy(),
+             'cond': cond.cpu().numpy()
+         }
+         output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+         # torch.testing.assert_allclose is deprecated; assert_close is its replacement
+         torch.testing.assert_close(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+     logging.info('successfully exported estimator')
+
+
+ if __name__ == "__main__":
+     main()
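To sanity-check the exported graph outside this script, a minimal sketch along these lines should work, assuming the default --model_dir and --onnx_model values above. The out_channels value of 80 is an illustrative guess; the real value comes from flow.yaml.

import numpy as np
import onnxruntime

# Load the exported estimator; CPUExecutionProvider keeps the check runnable anywhere.
sess = onnxruntime.InferenceSession(
    'Step-Audio-2-mini/token2wav/flow.decoder.estimator.fp32.dynamic_batch.onnx',
    providers=['CPUExecutionProvider'])

# Shapes are illustrative; batch and seq_len can vary thanks to the dynamic axes.
batch_size, out_channels, seq_len = 1, 80, 128
inputs = {
    'x': np.random.rand(batch_size, out_channels, seq_len).astype(np.float32),
    'mask': np.ones((batch_size, 1, seq_len), dtype=np.float32),
    'mu': np.random.rand(batch_size, out_channels, seq_len).astype(np.float32),
    't': np.random.rand(batch_size).astype(np.float32),
    'spks': np.random.rand(batch_size, out_channels).astype(np.float32),
    'cond': np.random.rand(batch_size, out_channels, seq_len).astype(np.float32),
}
out = sess.run(['estimator_out'], inputs)[0]
print(out.shape)  # expected: (batch_size, out_channels, seq_len)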
flow.decoder.estimator.fp32.dynamic_batch.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c31e8c62abb6eacdd929baf37910cb58ce71c1b16bd71ced48d807ea6697da48
+ size 458962271
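The committed blob for the .onnx file is only this three-line Git LFS pointer (spec version, SHA-256 of the payload, and its size, roughly 459 MB); the weights themselves are fetched on checkout. As a hedged sketch, the file can also be pulled programmatically with huggingface_hub; the repo_id below is a placeholder, not taken from this commit:

from huggingface_hub import hf_hub_download

# NOTE: repo_id is hypothetical; substitute the repository this commit belongs to.
onnx_path = hf_hub_download(
    repo_id='your-org/your-repo',
    filename='flow.decoder.estimator.fp32.dynamic_batch.onnx',
)
print(onnx_path)  # local cache path of the resolved LFS file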