yujiwang0606 committed
Commit adecc3c · 1 Parent(s): 4d42c48
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +15 -0
  2. README.md +5 -5
  3. app.py +175 -0
  4. imaginaire/__init__.py +14 -0
  5. imaginaire/callbacks/__init__.py +14 -0
  6. imaginaire/callbacks/every_n.py +84 -0
  7. imaginaire/callbacks/manual_gc.py +49 -0
  8. imaginaire/config.py +410 -0
  9. imaginaire/lazy_config/__init__.py +73 -0
  10. imaginaire/lazy_config/file_io.py +24 -0
  11. imaginaire/lazy_config/instantiate.py +119 -0
  12. imaginaire/lazy_config/lazy.py +442 -0
  13. imaginaire/lazy_config/omegaconf_patch.py +65 -0
  14. imaginaire/lazy_config/registry.py +74 -0
  15. imaginaire/model.py +137 -0
  16. imaginaire/trainer.py +322 -0
  17. imaginaire/utils/__init__.py +14 -0
  18. imaginaire/utils/callback.py +518 -0
  19. imaginaire/utils/checkpointer.py +282 -0
  20. imaginaire/utils/config_helper.py +201 -0
  21. imaginaire/utils/device.py +39 -0
  22. imaginaire/utils/distributed.py +444 -0
  23. imaginaire/utils/easy_io/__init__.py +14 -0
  24. imaginaire/utils/easy_io/backends/__init__.py +28 -0
  25. imaginaire/utils/easy_io/backends/base_backend.py +60 -0
  26. imaginaire/utils/easy_io/backends/http_backend.py +91 -0
  27. imaginaire/utils/easy_io/backends/local_backend.py +551 -0
  28. imaginaire/utils/easy_io/backends/registry_utils.py +125 -0
  29. imaginaire/utils/easy_io/easy_io.py +1034 -0
  30. imaginaire/utils/easy_io/file_client.py +448 -0
  31. imaginaire/utils/easy_io/handlers/__init__.py +29 -0
  32. imaginaire/utils/easy_io/handlers/base.py +44 -0
  33. imaginaire/utils/easy_io/handlers/byte_handler.py +39 -0
  34. imaginaire/utils/easy_io/handlers/csv_handler.py +42 -0
  35. imaginaire/utils/easy_io/handlers/gzip_handler.py +33 -0
  36. imaginaire/utils/easy_io/handlers/imageio_video_handler.py +168 -0
  37. imaginaire/utils/easy_io/handlers/json_handler.py +49 -0
  38. imaginaire/utils/easy_io/handlers/jsonl_handler.py +80 -0
  39. imaginaire/utils/easy_io/handlers/np_handler.py +89 -0
  40. imaginaire/utils/easy_io/handlers/pandas_handler.py +31 -0
  41. imaginaire/utils/easy_io/handlers/pickle_handler.py +42 -0
  42. imaginaire/utils/easy_io/handlers/pil_handler.py +96 -0
  43. imaginaire/utils/easy_io/handlers/registry_utils.py +82 -0
  44. imaginaire/utils/easy_io/handlers/tarfile_handler.py +39 -0
  45. imaginaire/utils/easy_io/handlers/torch_handler.py +34 -0
  46. imaginaire/utils/easy_io/handlers/torchjit_handler.py +34 -0
  47. imaginaire/utils/easy_io/handlers/txt_handler.py +34 -0
  48. imaginaire/utils/easy_io/handlers/yaml_handler.py +38 -0
  49. imaginaire/utils/ema.py +315 -0
  50. imaginaire/utils/fused_adam.py +398 -0
.gitignore ADDED
@@ -0,0 +1,15 @@
+ # Python cache and build files
+ __pycache__/
+ *.py[cod]
+ *.pyo
+ *.pyd
+
+ # Virtual environments
+ .venv/
+ venv/
+ env/
+
+ # IDE and misc
+ .idea/
+ .vscode/
+ .DS_Store
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: RCM Wan 720p
- emoji: 🌍
- colorFrom: purple
- colorTo: pink
+ title: rCM-Wan 720p
+ emoji: 🐠
+ colorFrom: green
+ colorTo: gray
  sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
@@ -11,4 +11,4 @@ license: apache-2.0
  short_description: rCM model for Wan2.1
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ This demo uses the unofficial rCM models for Wan from worstcoder/rcm-Wan.
app.py ADDED
@@ -0,0 +1,175 @@
+ import spaces
+ import gradio as gr
+ import time
+ import requests
+ from wan2pt1_t2v_rcm_infer import inference, prepare_models
+ from huggingface_hub import hf_hub_download
+ import random
+ from types import SimpleNamespace
+ import gc
+ import torch
+ from imaginaire.lazy_config import LazyCall as L, LazyDict, instantiate
+ from wan2pt1_t2v_rcm_infer import load_dit_model, WanModel
+
+
+ import flash_attn
+ print("flash_attn version: ", flash_attn.__version__)
+
+ WAN2PT1_1PT3B_T2V: LazyDict = L(WanModel)(
+     dim=1536,
+     eps=1e-06,
+     ffn_dim=8960,
+     freq_dim=256,
+     in_dim=16,
+     model_type="t2v",
+     num_heads=12,
+     num_layers=30,
+     out_dim=16,
+     text_len=512,
+ )
+
+ WAN2PT1_14B_T2V: LazyDict = L(WanModel)(
+     dim=5120,
+     eps=1e-06,
+     ffn_dim=13824,
+     freq_dim=256,
+     in_dim=16,
+     model_type="t2v",
+     num_heads=40,
+     num_layers=40,
+     out_dim=16,
+     text_len=512,
+ )
+
+ dit_configs = {"1.3B": WAN2PT1_1PT3B_T2V, "14B": WAN2PT1_14B_T2V}
+
+ dit_path_14B_720p = hf_hub_download(
+     repo_id="worstcoder/rcm-Wan",
+     filename="rCM_Wan2.1_T2V_14B_720p.pt",
+ )
+
+ vae_path = hf_hub_download(
+     repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+     filename="Wan2.1_VAE.pth"
+ )
+
+ text_encoder_path = hf_hub_download(
+     repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+     filename="models_t5_umt5-xxl-enc-bf16.pth"
+ )
+
+ net_14B_720p, tokenizer, t5_encoder = prepare_models(dit_path_14B_720p, vae_path, text_encoder_path)
+ print("Loaded models")
+ gc.collect()
+
+ def random_seed():
+     return random.randint(0, 2**32 - 1)
+
+ @spaces.GPU(duration=360)
+ def generate_videos(prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed):
+     if seed is None:
+         seed = random.randint(0, 2**32 - 1)
+
+     if "480p" in model_size:
+         resolution = "480p"
+     else:
+         resolution = "720p"
+
+     args = SimpleNamespace(
+         prompt=prompt,
+         model_size=model_size,
+         num_steps=num_steps,
+         num_samples=num_samples,
+         sigma_max=sigma_max,
+         num_frames=77,
+         resolution=resolution,
+         aspect_ratio=aspect_ratio,
+         seed=seed,
+     )
+
+     with torch.no_grad():
+         video_list = inference(args, net_14B_720p, tokenizer, t5_encoder)
+
+     if aspect_ratio == "16:9":
+         return video_list, None
+     else:
+         return None, video_list
+
+ def update_num_samples(model_choice):
+     if model_choice == "rCM-Wan2.1-T2V-1.3B-480p":
+         options = [1, 2, 3, 4]
+     elif model_choice == "rCM-Wan2.1-T2V-14B-480p":
+         options = [1, 2, 3]
+     else:
+         options = [1, 2, 3]
+     return gr.Dropdown(choices=options, value=options[0], label="num_samples")
+
+ def update_sigma_max(model_choice):
+     if "480p" in model_choice:
+         options = [80, 120, 200, 400, 800, 1600]
+     else:
+         options = [120, 200, 400, 800, 1600]
+     return gr.Dropdown(choices=options, value=options[0], label="sigma_max")
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## rCM model for Wan")
+
+     examples = [
+         ["A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."],
+         ["A close-up shot captures a steaming hot pot brimming with vegetables and dumplings, set on a rustic wooden table. The camera focuses on the bubbling broth as a woman, dressed in a light, patterned blouse, reaches in with chopsticks to lift a tender leaf of cabbage from the simmering mixture. Steam rises around her as she leans back slightly, her warm smile reflecting satisfaction and joy. Her movements are smooth and deliberate, showcasing her comfort and familiarity with the dining process. The background includes a small bowl of dipping sauce and a clay pot, adding to the cozy, communal dining atmosphere."],
+         ["A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond."]
+     ]
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             with gr.Row():
+                 prompt = gr.Textbox(label="Text prompt", placeholder="Text prompt for videos")
+                 model_size = gr.Radio(
+                     ["rCM-Wan2.1-T2V-14B-720p"],
+                     value="rCM-Wan2.1-T2V-14B-720p",
+                     label="Model"
+                 )
+
+             with gr.Row():
+                 num_samples = gr.Dropdown([1, 2], value=1, label="num_samples")
+                 aspect_ratio = gr.Radio(["16:9", "9:16"], value="16:9", label="aspect_ratio")
+                 sigma_max = gr.Dropdown([120, 200, 400, 800, 1600], value=120, label="sigma_max")
+
+             with gr.Row():
+                 num_steps = gr.Slider(1, 4, value=4, step=1, label="num_steps")
+                 seed = gr.Number(label="seed", value=random_seed(), interactive=True)
+
+             with gr.Row():
+                 regenerate_btn = gr.Button("New Seed")
+                 run_btn = gr.Button("Generate Videos")
+
+             with gr.Row():
+                 gr.Examples(
+                     examples,
+                     inputs=[prompt],
+                     label="Example prompts"
+                 )
+
+         with gr.Column(scale=1):
+             video_16_9 = gr.Video(label="Videos 16:9", width=832)
+             video_9_16 = gr.Video(label="Videos 9:16", width=480, visible=False)
+
+     def show_video(aspect):
+         if aspect == "16:9":
+             return gr.update(visible=True), gr.update(visible=False, value=None)
+         else:
+             return gr.update(visible=False, value=None), gr.update(visible=True)
+
+     model_size.change(fn=update_num_samples, inputs=model_size, outputs=num_samples)
+     model_size.change(fn=update_sigma_max, inputs=model_size, outputs=sigma_max)
+
+     aspect_ratio.change(show_video, inputs=aspect_ratio, outputs=[video_16_9, video_9_16])
+     regenerate_btn.click(fn=random_seed, outputs=seed)
+
+     run_btn.click(
+         fn=generate_videos,
+         inputs=[prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed],
+         outputs=[video_16_9, video_9_16],
+     )
+
+ demo.launch()
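
For reference, a minimal headless sketch of the same generation path the Gradio handler above wires up, assuming the models have been prepared as in app.py (the prompt and seed values are only illustrative):

    import torch
    from types import SimpleNamespace

    # Mirrors generate_videos() without the UI.
    args = SimpleNamespace(
        prompt="A red fox trotting through fresh snow at dawn",
        model_size="rCM-Wan2.1-T2V-14B-720p",
        num_steps=4,
        num_samples=1,
        sigma_max=120,
        num_frames=77,
        resolution="720p",
        aspect_ratio="16:9",
        seed=42,
    )
    with torch.no_grad():
        videos = inference(args, net_14B_720p, tokenizer, t5_encoder)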
imaginaire/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
imaginaire/callbacks/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
imaginaire/callbacks/every_n.py ADDED
@@ -0,0 +1,84 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import abstractmethod
+
+ import torch
+
+ from imaginaire.model import ImaginaireModel
+ from imaginaire.trainer import ImaginaireTrainer
+ from imaginaire.utils import distributed, log
+ from imaginaire.utils.callback import Callback
+
+
+ class EveryN(Callback):
+     def __init__(
+         self,
+         every_n: int | None = None,
+         step_size: int = 1,
+         barrier_after_run: bool = True,
+         run_at_start: bool = False,
+     ) -> None:
+         """Constructor for `EveryN`.
+
+         Args:
+             every_n (int): Frequency with which the callback is run during training.
+             step_size (int): Size of the iteration step count. Default 1.
+             barrier_after_run (bool): Whether to place a distributed barrier after each execution. Default True, to avoid timeouts.
+             run_at_start (bool): Whether to run at the beginning of training. Default False.
+         """
+         self.every_n = every_n
+         if self.every_n == 0:
+             log.warning(
+                 f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once at the beginning of training. Calls that would happen in on_training_step_end will be skipped."
+             )
+
+         self.step_size = step_size
+         self.barrier_after_run = barrier_after_run
+         self.run_at_start = run_at_start
+
+     def on_training_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         # every_n = 0 is a special case: every_n_impl is called only once at the beginning of training
+         if self.every_n != 0:
+             trainer = self.trainer
+             global_step = iteration // self.step_size
+             should_run = (iteration == 1 and self.run_at_start) or (global_step % self.every_n == 0)
+             if should_run:
+                 log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}")
+                 self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration)
+                 log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}")
+                 # add the necessary barrier to avoid timeouts
+                 if self.barrier_after_run:
+                     distributed.barrier()
+
+     @abstractmethod
+     def every_n_impl(
+         self,
+         trainer: ImaginaireTrainer,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int,
+     ) -> None: ...
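
To make the contract concrete, here is a minimal sketch of a concrete subclass (the class name and log message are illustrative, not part of this commit):

    from imaginaire.callbacks.every_n import EveryN
    from imaginaire.utils import log

    class LogLossEveryN(EveryN):
        """Illustrative subclass: log the training loss every N iterations."""

        def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
            log.info(f"iteration {iteration}: loss = {loss.item():.4f}")

    loss_logger = LogLossEveryN(every_n=100)  # fires whenever iteration // step_size is a multiple of 100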
imaginaire/callbacks/manual_gc.py ADDED
@@ -0,0 +1,49 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import gc
+
+ from imaginaire.callbacks.every_n import EveryN
+ from imaginaire.utils import log
+
+
+ class ManualGarbageCollection(EveryN):
+     """
+     Disable automatic garbage collection and trigger it manually every N iterations.
+     This is very useful for large-scale training, where it reduces GPU sync time
+     and can yield up to a 50% speedup.
+
+     Note that this callback disables gc only in the main process; automatic gc
+     remains enabled in subprocesses.
+
+     We only start disabling gc after warm_up iterations, to avoid disabling it in
+     subprocesses (such as dataloader workers), which can cause OOM.
+     """
+
+     def __init__(self, *args, warm_up: int = 5, **kwargs):
+         kwargs["barrier_after_run"] = False
+         super().__init__(*args, **kwargs)
+
+         self.counter = 0
+         self.warm = warm_up
+
+     def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
+         del trainer, model, data_batch, output_batch, loss
+         self.counter += 1
+         if self.counter < self.warm:
+             return
+         if self.counter == self.warm:
+             gc.disable()
+             log.critical("Garbage collection disabled")
+
+         gc.collect(1)  # collect only up to generation 1; cheaper than a full collection
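
A sketch of how this callback might be wired into the trainer's callback dict using the LazyCall style from imaginaire/config.py (the key name and frequency are illustrative assumptions):

    from imaginaire.callbacks.manual_gc import ManualGarbageCollection
    from imaginaire.lazy_config import LazyCall as L

    # Illustrative: trigger a manual generation-1 collection every 1000 iterations.
    callbacks = dict(
        manual_gc=L(ManualGarbageCollection)(every_n=1000, warm_up=5),
    )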
imaginaire/config.py ADDED
@@ -0,0 +1,410 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Training config system for Imaginaire4"""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any, TypeVar
+
+ import attrs
+ import torch
+ import torch.utils.data
+
+ from imaginaire.model import ImaginaireModel
+
+ try:
+     from megatron.core import ModelParallelConfig
+
+     USE_MEGATRON = True
+ except ImportError:
+     USE_MEGATRON = False
+     print("Megatron-core is not installed.")
+
+ import builtins
+
+ from imaginaire.lazy_config import LazyCall as L
+ from imaginaire.lazy_config import LazyDict
+ from imaginaire.utils import callback, distributed
+ from imaginaire.utils.misc import Color
+
+ T = TypeVar("T")
+
+
+ def _is_attrs_instance(obj: object) -> bool:
+     """
+     Helper function to check if an object is an instance of an attrs-defined class.
+
+     Args:
+         obj: The object to check.
+
+     Returns:
+         bool: True if the object is an instance of an attrs-defined class, False otherwise.
+     """
+     return hasattr(obj, "__attrs_attrs__")
+
+
+ def make_freezable(cls: T) -> T:
+     """
+     A decorator that adds the capability to freeze instances of an attrs-defined class.
+
+     NOTE: This requires the wrapped attrs class to be defined with attrs.define(slots=False), because we need
+     to hack on a "_is_frozen" attribute.
+
+     This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
+     Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
+     any attrs-defined objects that are attributes of the class.
+
+     Usage:
+         @make_freezable
+         @attrs.define(slots=False)
+         class MyClass:
+             attribute1: int
+             attribute2: str
+
+         obj = MyClass(1, 'a')
+         obj.freeze()  # Freeze the instance
+         obj.attribute1 = 2  # Raises AttributeError
+
+     Args:
+         cls: The class to be decorated.
+
+     Returns:
+         The decorated class with added freezing capability.
+     """
+
+     if not hasattr(cls, "__dict__"):
+         raise TypeError(
+             "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
+             "class was defined with `@attrs.define(slots=False)`"
+         )
+
+     original_setattr = cls.__setattr__
+
+     def setattr_override(self, key, value) -> None:
+         """
+         Override __setattr__ to allow modifications during initialization
+         and prevent modifications once the instance is frozen.
+         """
+         if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
+             raise AttributeError("Cannot modify frozen instance")
+         original_setattr(self, key, value)  # type: ignore
+
+     cls.__setattr__ = setattr_override  # type: ignore
+
+     def freeze(self: object) -> None:
+         """
+         Freeze the instance and all its attrs-defined attributes.
+         """
+         for _, value in attrs.asdict(self, recurse=False).items():
+             if _is_attrs_instance(value) and hasattr(value, "freeze"):
+                 value.freeze()
+         self._is_frozen = True  # type: ignore
+
+     cls.freeze = freeze  # type: ignore
+
+     return cls
+
+
+ def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
+     """
+     Recursively pretty-prints attrs objects, optionally with color.
+     """
+
+     assert attrs.has(obj.__class__)
+
+     lines: list[str] = []
+     for attribute in attrs.fields(obj.__class__):
+         value = getattr(obj, attribute.name)
+         if attrs.has(value.__class__):
+             if use_color:
+                 lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
+             else:
+                 lines.append(" " * indent + "* " + attribute.name + ":")
+             lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
+         else:
+             if use_color:
+                 lines.append(
+                     " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
+                 )
+             else:
+                 lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
+     return "\n".join(lines)
+
+
+ def pretty_print_overrides(overrides: list[str] | None = None, use_color: bool = False) -> str:
+     """
+     Pretty-prints overrides.
+     """
+
+     lines: list[str] = []
+     if use_color:
+         lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
+     else:
+         lines.append("* overrides: ")
+     for override in overrides or []:
+         if override == "--":
+             continue
+         if override.startswith("~"):
+             attribute_name = override[1:]
+             attribute_value = None
+         else:
+             attribute_name, attribute_value = override.split("=")
+         if use_color:
+             lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
+         else:
+             lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))
+
+     return "\n".join(lines)
+
+
+ @make_freezable
+ @attrs.define(slots=False)  # slots=False is required for make_freezable. See the make_freezable notes for more info.
+ class ObjectStoreConfig:
+     # Whether the file I/O is from object store instead of local disk.
+     enabled: bool = False
+     # Path to the object store credentials file.
+     credentials: str = ""
+     # Object store bucket to read from / write to the objects.
+     bucket: str = ""
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class JobConfig:
+     # Project name.
+     project: str = ""
+     # Experiment name.
+     group: str = ""
+     # Run/job name.
+     name: str = ""
+
+     @property
+     def path(self) -> str:
+         return f"{self.project}/{self.group}/{self.name}"
+
+     @property
+     def path_local(self) -> str:
+         local_root = os.environ.get("IMAGINAIRE_OUTPUT_ROOT", "checkpoints")
+         return f"{local_root}/{self.path}"
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class EMAConfig:
+     # Enable tracking a set of exponential moving average (EMA) weights.
+     enabled: bool = False
+     # EMA decay rate.
+     beta: float = 0.9999
+     # Enable removing "_orig_mod-" from buffer names, which is added by torch.compile.
+     torch_compile_buffer_renaming: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class PowerEMAConfig:
+     # Enable tracking a set of exponential moving average (EMA) weights.
+     enabled: bool = False
+     # EMA decay rate from the EDM2 paper.
+     s: float = 0.1
+     # Enable removing "_orig_mod-" from buffer names, which is added by torch.compile.
+     torch_compile_buffer_renaming: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class DDPConfig:
+     # Traverse the computation graph to find parameters that don't receive gradients.
+     find_unused_parameters: bool = False
+     # Set to True if the computation graph does not change during the whole training loop.
+     static_graph: bool = True
+     # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere.
+     broadcast_buffers: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class CuDNNConfig:
+     # Set to True for better reproducibility of the results (only using deterministic cudnn functions).
+     deterministic: bool = False
+     # If set to True, cudnn will benchmark several algorithms and pick the fastest one.
+     benchmark: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class JITConfig:
+     # Enable exporting a JIT-compiled model.
+     enabled: bool = False
+     # Input tensor shape, for example input.
+     input_shape: list[int] | None = None
+     # Device to compile onto.
+     device: str = "cuda"
+     # Data type to compile onto.
+     dtype: str = "bfloat16"
+     # Strict mode for PyTorch JIT.
+     strict: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class CheckpointConfig:
+     # Possible checkpoint class.
+     type: dict | None = None
+     # For dcp, whether to use async mode.
+     dcp_async_mode_enabled: bool = False
+     # Save the checkpoint every N iterations.
+     save_iter: int = 999999999
+     # Path of model weights to resume the checkpoint from.
+     load_path: str = ""
+     # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
+     load_training_state: bool = False
+     # Whether to load only the scheduler state from the checkpoint path. If load_training_state is True, this will be ignored.
+     only_load_scheduler_state: bool = False
+     # Load state_dict into the models in strict mode.
+     strict_resume: bool = True
+     # Configs for JIT-compiling the EMA model.
+     jit: JITConfig = attrs.field(factory=JITConfig)
+     # Print detailed information during checkpoint saving/loading.
+     verbose: bool = True
+     # Keys not to resume from the checkpoint; choices: ["model", "optim", "scheduler", "trainer"].
+     keys_not_to_resume: list[str] = []  # noqa: RUF008
+     # Whether to use the local filesystem for broadcasting checkpoint data (used for the Tensor Parallel Checkpointer).
+     broadcast_via_filesystem: bool = False
+     load_ema_to_reg: bool = False
+     # In the dcp planner, skip the weight-shape check and load weights into the model even if the weight shape is different.
+     dcp_allow_mismatched_size: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class NVTXConfig:
+     """Config for NVTX ranges used in the main training loop.
+
+     See tutorials/nanogpt for more details on how to integrate profiling into your model."""
+
+     # Enable the NVTX ranges.
+     enabled: bool = False
+     # Synchronize everything in each NVTX range.
+     cuda_synchronize: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class Profiling:
+     enable_profiling: bool = False
+     enable_memory_snapshot: bool = False
+     profile_freq: int = 1
+     first_n_rank: int = 8  # -1 means all ranks; n means the first n ranks dump profiling info
+     record_shape: bool = True
+     profile_memory: bool = True
+     with_stack: bool = True
+     with_modules: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class TrainerConfig:
+     from imaginaire.trainer import ImaginaireTrainer
+
+     type: builtins.type[ImaginaireTrainer] = ImaginaireTrainer
+     # Set the callback class.
+     # Defaults to the callbacks below.
+     callbacks: LazyDict[dict[str, callback.Callback]] = LazyDict(  # noqa: RUF009
+         dict(
+             ema=L(callback.EMAModelCallback)(),
+             progress_bar=L(callback.ProgressBarCallback)(),
+         )
+     )
+     # Distributed parallelism strategy.
+     distributed_parallelism: str = "ddp"
+     # Distributed data parallel configs.
+     ddp: DDPConfig = attrs.field(factory=DDPConfig)
+     # cuDNN configs.
+     cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
+     # Set the random seed.
+     seed: int = 0
+     # Gradient scaler arguments (for torch.amp.GradScaler).
+     grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
+     # Maximum number of iterations to train the model.
+     max_iter: int = 999999999
+     # Maximum number of iterations to validate the model. If None, validate on the entire dataset.
+     max_val_iter: int | None = None
+     # How often we log the training stats.
+     logging_iter: int = 100
+     # Whether we want to run the validation routines.
+     run_validation: bool = True
+     # How often we evaluate on the validation set.
+     validation_iter: int = 999999999
+     # Kill the process after N seconds since the last iteration (usually means a dead job).
+     timeout_period: int = 999999999
+     # Tensor memory organization format.
+     memory_format: torch.memory_format = torch.preserve_format
+     # Gradient accumulation (update step every N iterations).
+     grad_accum_iter: int = 1
+     # Profiling config.
+     profiling: Profiling = attrs.field(factory=Profiling)
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class Config:
+     """Config for an imaginaire4 job.
+
+     See /README.md/Configuration System for more info.
+     """
+
+     # Model configs.
+     model: LazyDict[ImaginaireModel]
+     # Optimizer configs.
+     optimizer: LazyDict[torch.optim.Optimizer]
+     # Scheduler configs.
+     scheduler: LazyDict[torch.optim.lr_scheduler.LRScheduler]
+     # Training data configs.
+     dataloader_train: LazyDict[torch.utils.data.DataLoader]
+     # Validation data configs.
+     dataloader_val: LazyDict[torch.utils.data.DataLoader]
+
+     # Training job configs.
+     job: JobConfig = attrs.field(factory=JobConfig)
+
+     # Trainer configs.
+     trainer: TrainerConfig = attrs.field(factory=TrainerConfig)
+
+     if USE_MEGATRON:
+         # Megatron-Core configs.
+         model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
+     else:
+         model_parallel: None = None
+
+     # Checkpointer configs.
+     checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)
+
+     def pretty_print(self, use_color: bool = False) -> str:
+         return _pretty_print_attrs_instance(self, 0, use_color)
+
+     def to_dict(self) -> dict[str, Any]:
+         return attrs.asdict(self)
+
+     def validate(self) -> None:
+         """Validate that the config has all required fields."""
+
+         # Broadcast job.name across all ranks to make sure it is consistent;
+         # otherwise, unaligned job names lead to unaligned paths for saving checkpoints.
+         job_name_tensor = torch.ByteTensor(bytearray(self.job.name, "utf-8")).cuda()
+         distributed.broadcast(job_name_tensor, 0)
+         self.job.name = job_name_tensor.cpu().numpy().tobytes().decode("utf-8")
+
+         assert self.job.project != ""
+         assert self.job.group != ""
+         assert self.job.name != ""
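
A small sketch of the freeze behaviour these config classes gain from make_freezable (the values here are illustrative):

    from imaginaire.config import EMAConfig

    ema = EMAConfig(enabled=True, beta=0.999)
    ema.freeze()  # recursively freezes any attrs-defined attributes too
    try:
        ema.beta = 0.99
    except AttributeError:
        print("frozen instances reject attribute writes")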
imaginaire/lazy_config/__init__.py ADDED
@@ -0,0 +1,73 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ from omegaconf import OmegaConf
+
+ from imaginaire.lazy_config.instantiate import instantiate
+ from imaginaire.lazy_config.lazy import LazyCall, LazyConfig, LazyDict
+ from imaginaire.lazy_config.omegaconf_patch import to_object
+
+ OmegaConf.to_object = to_object
+
+ PLACEHOLDER = None
+
+ __all__ = ["PLACEHOLDER", "LazyCall", "LazyConfig", "LazyDict", "instantiate"]
+
+
+ DOC_BUILDING = os.getenv("_DOC_BUILDING", False)  # set in docs/conf.py
+
+
+ def fixup_module_metadata(module_name, namespace, keys=None):
+     """
+     Fix the __qualname__ of module members to be their exported api name, so
+     when they are referenced in docs, sphinx can find them. Reference:
+     https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
+     """
+     if not DOC_BUILDING:
+         return
+     seen_ids = set()
+
+     def fix_one(qualname, name, obj):
+         # avoid infinite recursion (relevant when using
+         # typing.Generic, for example)
+         if id(obj) in seen_ids:
+             return
+         seen_ids.add(id(obj))
+
+         mod = getattr(obj, "__module__", None)
+         if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
+             obj.__module__ = module_name
+             # Modules, unlike everything else in Python, put fully-qualified
+             # names into their __name__ attribute. We check for "." to avoid
+             # rewriting these.
+             if hasattr(obj, "__name__") and "." not in obj.__name__:
+                 obj.__name__ = name
+                 obj.__qualname__ = qualname
+             if isinstance(obj, type):
+                 for attr_name, attr_value in obj.__dict__.items():
+                     fix_one(qualname + "." + attr_name, attr_name, attr_value)
+
+     if keys is None:
+         keys = namespace.keys()
+     for objname in keys:
+         if not objname.startswith("_"):
+             obj = namespace[objname]
+             fix_one(objname, objname, obj)
+
+
+ fixup_module_metadata(__name__, globals(), __all__)
+ del fixup_module_metadata
imaginaire/lazy_config/file_io.py ADDED
@@ -0,0 +1,24 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
+ from iopath.common.file_io import PathManager as PathManagerBase
+
+ __all__ = ["PathHandler", "PathManager"]
+
+
+ PathManager = PathManagerBase()
+ PathManager.register_handler(HTTPURLHandler())
+ PathManager.register_handler(OneDrivePathHandler())
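
The resulting PathManager dispatches on the path scheme. A hedged usage sketch (the URL is illustrative):

    from imaginaire.lazy_config.file_io import PathManager

    # Local paths go through the default handler; http(s) URLs are fetched
    # by the registered HTTPURLHandler, which caches the file locally.
    with PathManager.open("https://example.com/config.yaml", "r") as f:
        text = f.read()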
imaginaire/lazy_config/instantiate.py ADDED
@@ -0,0 +1,119 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections.abc as abc
+ import dataclasses
+ from typing import Any
+
+ import attrs
+
+ from imaginaire.lazy_config.registry import _convert_target_to_string, locate
+ from imaginaire.utils import log
+
+ __all__ = ["dump_dataclass", "instantiate"]
+
+
+ def is_dataclass_or_attrs(target):
+     return dataclasses.is_dataclass(target) or attrs.has(target)
+
+
+ def dump_dataclass(obj: Any):
+     """
+     Dump a dataclass recursively into a dict that can be later instantiated.
+
+     Args:
+         obj: a dataclass object
+
+     Returns:
+         dict
+     """
+     assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), (
+         "dump_dataclass() requires an instance of a dataclass."
+     )
+     ret = {"_target_": _convert_target_to_string(type(obj))}
+     for f in dataclasses.fields(obj):
+         v = getattr(obj, f.name)
+         if dataclasses.is_dataclass(v):
+             v = dump_dataclass(v)
+         if isinstance(v, (list, tuple)):
+             v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
+         ret[f.name] = v
+     return ret
+
+
+ def instantiate(cfg, *args, **kwargs):
+     """
+     Recursively instantiate objects defined in dictionaries by
+     "_target_" and arguments.
+
+     Args:
+         cfg: a dict-like object with "_target_" that defines the caller, and
+             other keys that define the arguments
+         args: Optional positional parameters pass-through.
+         kwargs: Optional named parameters pass-through.
+
+     Returns:
+         object instantiated by cfg
+     """
+     from omegaconf import DictConfig, ListConfig, OmegaConf
+
+     if isinstance(cfg, ListConfig):
+         lst = [instantiate(x) for x in cfg]
+         return ListConfig(lst, flags={"allow_objects": True})
+     if isinstance(cfg, list):
+         # Specialize for list, because many classes take
+         # list[objects] as arguments, such as ResNet, DatasetMapper
+         return [instantiate(x) for x in cfg]
+
+     # If the input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
+     # instantiate it to the actual dataclass.
+     if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
+         return OmegaConf.to_object(cfg)
+
+     if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
+         # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
+         # but faster: https://github.com/facebookresearch/hydra/issues/1200
+         is_recursive = getattr(cfg, "_recursive_", True)
+         if is_recursive:
+             cfg = {k: instantiate(v) for k, v in cfg.items()}
+         else:
+             cfg = dict(cfg)
+         # pop the _recursive_ key to avoid passing it as a parameter
+         if "_recursive_" in cfg:
+             cfg.pop("_recursive_")
+         cls = cfg.pop("_target_")
+         cls = instantiate(cls)
+
+         if isinstance(cls, str):
+             cls_name = cls
+             cls = locate(cls_name)
+             assert cls is not None, cls_name
+         else:
+             try:
+                 cls_name = cls.__module__ + "." + cls.__qualname__
+             except Exception:
+                 # target could be anything, so the above could fail
+                 cls_name = str(cls)
+         assert callable(cls), f"_target_ {cls} does not define a callable object"
+         try:
+             # override config with kwargs
+             instantiate_kwargs = {}
+             instantiate_kwargs.update(cfg)
+             instantiate_kwargs.update(kwargs)
+             return cls(*args, **instantiate_kwargs)
+         except TypeError:
+             log.error(f"Error when instantiating {cls_name}!")
+             raise
+     return cfg  # return as-is if we don't know what to do
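
A brief sketch of the `_target_` convention this function consumes, paired with LazyCall from lazy.py (the layer and its arguments are illustrative):

    import torch.nn as nn
    from imaginaire.lazy_config import LazyCall as L, instantiate

    # Describe the call lazily; nothing is constructed yet.
    conv_cfg = L(nn.Conv2d)(in_channels=3, out_channels=16, kernel_size=3)
    conv_cfg.out_channels = 32    # the config stays editable until instantiation
    conv = instantiate(conv_cfg)  # now an actual nn.Conv2d(3, 32, 3) is built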
imaginaire/lazy_config/lazy.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ast
17
+ import builtins
18
+ import collections.abc as abc
19
+ import importlib
20
+ import inspect
21
+ import logging
22
+ import os
23
+ import pickle
24
+ import uuid
25
+ from collections import OrderedDict
26
+ from contextlib import contextmanager
27
+ from copy import deepcopy
28
+ from dataclasses import is_dataclass
29
+ from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast
30
+
31
+ import attrs
32
+ import yaml
33
+ from omegaconf import DictConfig, ListConfig, OmegaConf
34
+
35
+ from imaginaire.utils import log
36
+
37
+ try:
38
+ import dill as dill_pickle
39
+ except ImportError:
40
+ dill_pickle = None
41
+
42
+ try:
43
+ import cloudpickle
44
+ except ImportError:
45
+ cloudpickle = None
46
+
47
+ from imaginaire.lazy_config.file_io import PathManager
48
+ from imaginaire.lazy_config.registry import _convert_target_to_string
49
+
50
+ __all__ = ["LazyCall", "LazyConfig", "LazyDict"]
51
+
52
+ T = TypeVar("T")
53
+
54
+
55
+ def sort_dict(d: dict[str, Any]) -> OrderedDict[str, Any]:
56
+ return OrderedDict(sorted(d.items(), key=lambda x: x[0]))
57
+
58
+
59
+ def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
60
+ return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
61
+
62
+
63
+ def sort_recursive(obj: dict[str, Any] | list[Any] | Any) -> OrderedDict[str, Any] | list[Any] | Any:
64
+ if isinstance(obj, dict):
65
+ return sort_dict({k: sort_recursive(v) for k, v in obj.items()})
66
+ elif isinstance(obj, list):
67
+ return [sort_recursive(item) for item in obj]
68
+ return obj
69
+
70
+
71
+ yaml.add_representer(OrderedDict, dict_representer)
72
+
73
+ OmegaConf.register_new_resolver("add", lambda *vals: sum(vals))
74
+ OmegaConf.register_new_resolver("subtract", lambda *vals: vals[0] - sum(vals[1:]))
75
+
76
+
77
+ def get_default_params(cls_or_func):
78
+ if callable(cls_or_func):
79
+ # inspect signature for function
80
+ signature = inspect.signature(cls_or_func)
81
+ else:
82
+ # inspect signature for class
83
+ signature = inspect.signature(cls_or_func.__init__)
84
+ params = signature.parameters
85
+ default_params = {
86
+ name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty
87
+ }
88
+ return default_params
89
+
90
+
91
+ if TYPE_CHECKING:
92
+ # Have `LazyDict[T]` behave as `T`, so that attribute access works. Ideally, it
93
+ # would be a subclass of `T`, but this doesn't seem to be possible in the type
94
+ # system yet.
95
+ LazyDict: TypeAlias = T
96
+ else:
97
+ LazyDict = DictConfig
98
+
99
+
100
+ class LazyCall(Generic[T]):
101
+ """
102
+ Wrap a callable so that when it's called, the call will not be executed,
103
+ but returns a dict that describes the call.
104
+
105
+ LazyCall object has to be called with only keyword arguments. Positional
106
+ arguments are not yet supported.
107
+
108
+ Examples:
109
+ ::
110
+ from detectron2.config import instantiate, LazyCall
111
+
112
+ layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
113
+ layer_cfg.out_channels = 64 # can edit it afterwards
114
+ layer = instantiate(layer_cfg)
115
+ """
116
+
117
+ def __init__(self, target: type[T]):
118
+ if not (callable(target) or isinstance(target, (str, abc.Mapping))):
119
+ raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
120
+ self._target = target
121
+
122
+ def __call__(self, **kwargs) -> LazyDict[T]:
123
+ if is_dataclass(self._target) or attrs.has(self._target):
124
+ # omegaconf object cannot hold dataclass type
125
+ # https://github.com/omry/omegaconf/issues/784
126
+ target = _convert_target_to_string(self._target)
127
+ else:
128
+ target = self._target
129
+ kwargs["_target_"] = target
130
+
131
+ _final_params = get_default_params(self._target)
132
+ _final_params.update(kwargs)
133
+
134
+ return cast(LazyDict[T], DictConfig(content=_final_params, flags={"allow_objects": True}))
135
+
136
+
137
+ def _visit_dict_config(cfg, func):
138
+ """
139
+ Apply func recursively to all DictConfig in cfg.
140
+ """
141
+ if isinstance(cfg, DictConfig):
142
+ func(cfg)
143
+ for v in cfg.values():
144
+ _visit_dict_config(v, func)
145
+ elif isinstance(cfg, ListConfig):
146
+ for v in cfg:
147
+ _visit_dict_config(v, func)
148
+
149
+
150
+ def _validate_py_syntax(filename):
151
+ # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
152
+ with PathManager.open(filename, "r") as f:
153
+ content = f.read()
154
+ try:
155
+ ast.parse(content)
156
+ except SyntaxError as e:
157
+ raise SyntaxError(f"Config file {filename} has syntax error!") from e
158
+
159
+
160
+ def _cast_to_config(obj):
161
+ # if given a dict, return DictConfig instead
162
+ if isinstance(obj, dict):
163
+ return DictConfig(obj, flags={"allow_objects": True})
164
+ return obj
165
+
166
+
167
+ _CFG_PACKAGE_NAME = "detectron2._cfg_loader"
168
+ """
169
+ A namespace to put all imported config into.
170
+ """
171
+
172
+
173
+ def _random_package_name(filename):
174
+ # generate a random package name when loading config files
175
+ return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
176
+
177
+
178
+ @contextmanager
179
+ def _patch_import():
180
+ """
181
+ Enhance relative import statements in config files, so that they:
182
+ 1. locate files purely based on relative location, regardless of packages.
183
+ e.g. you can import file without having __init__
184
+ 2. do not cache modules globally; modifications of module states has no side effect
185
+ 3. support other storage system through PathManager, so config files can be in the cloud
186
+ 4. imported dict are turned into omegaconf.DictConfig automatically
187
+ """
188
+ old_import = builtins.__import__
189
+
190
+ def find_relative_file(original_file, relative_import_path, level):
191
+ # NOTE: "from . import x" is not handled. Because then it's unclear
192
+ # if such import should produce `x` as a python module or DictConfig.
193
+ # This can be discussed further if needed.
194
+ relative_import_err = """
195
+ Relative import of directories is not allowed within config files.
196
+ Within a config file, relative import can only import other config files.
197
+ """.replace("\n", " ")
198
+ if not len(relative_import_path):
199
+ raise ImportError(relative_import_err)
200
+
201
+ cur_file = os.path.dirname(original_file)
202
+ for _ in range(level - 1):
203
+ cur_file = os.path.dirname(cur_file)
204
+ cur_name = relative_import_path.lstrip(".")
205
+ for part in cur_name.split("."):
206
+ cur_file = os.path.join(cur_file, part)
207
+ if not cur_file.endswith(".py"):
208
+ cur_file += ".py"
209
+ if not PathManager.isfile(cur_file):
210
+ cur_file_no_suffix = cur_file[: -len(".py")]
211
+ if PathManager.isdir(cur_file_no_suffix):
212
+ raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
213
+ else:
214
+ raise ImportError(
215
+ f"Cannot import name {relative_import_path} from {original_file}: {cur_file} does not exist."
216
+ )
217
+ return cur_file
218
+
219
+ def new_import(name, globals=None, locals=None, fromlist=(), level=0):
220
+ if (
221
+ # Only deal with relative imports inside config files
222
+ level != 0 and globals is not None and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
223
+ ):
224
+ cur_file = find_relative_file(globals["__file__"], name, level)
225
+ _validate_py_syntax(cur_file)
226
+ spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
227
+ module = importlib.util.module_from_spec(spec)
228
+ module.__file__ = cur_file
229
+ with PathManager.open(cur_file) as f:
230
+ content = f.read()
231
+ exec(compile(content, cur_file, "exec"), module.__dict__)
232
+ for name in fromlist: # turn imported dict into DictConfig automatically
233
+ val = _cast_to_config(module.__dict__[name])
234
+ module.__dict__[name] = val
235
+ return module
236
+ return old_import(name, globals, locals, fromlist=fromlist, level=level)
237
+
238
+ builtins.__import__ = new_import
239
+ yield new_import
240
+ builtins.__import__ = old_import
241
+
242
+
243
+ class LazyConfig:
244
+ """
245
+ Provide methods to save, load, and overrides an omegaconf config object
246
+ which may contain definition of lazily-constructed objects.
247
+ """
248
+
249
+ @staticmethod
250
+ def load_rel(filename: str, keys: None | str | tuple[str, ...] = None):
251
+ """
252
+ Similar to :meth:`load()`, but load path relative to the caller's
253
+ source file.
254
+
255
+ This has the same functionality as a relative import, except that this method
256
+ accepts filename as a string, so more characters are allowed in the filename.
257
+ """
258
+ caller_frame = inspect.stack()[1]
259
+ caller_fname = caller_frame[0].f_code.co_filename
260
+ assert caller_fname != "<string>", "load_rel Unable to find caller"
261
+ caller_dir = os.path.dirname(caller_fname)
262
+ filename = os.path.join(caller_dir, filename)
263
+ return LazyConfig.load(filename, keys)
264
+
265
+ @staticmethod
266
+ def load(filename: str, keys: None | str | tuple[str, ...] = None):
267
+ """
268
+ Load a config file.
269
+
270
+ Args:
271
+ filename: absolute path or relative path w.r.t. the current working directory
272
+ keys: keys to load and return. If not given, return all keys
273
+ (whose values are config objects) in a dict.
274
+ """
275
+ has_keys = keys is not None
276
+ filename = filename.replace("/./", "/") # redundant
277
+ if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
278
+ raise ValueError(f"Config file {filename} has to be a python or yaml file.")
279
+ if filename.endswith(".py"):
280
+ _validate_py_syntax(filename)
281
+
282
+ with _patch_import():
283
+ # Record the filename
284
+ module_namespace = {
285
+ "__file__": filename,
286
+ "__package__": _random_package_name(filename),
287
+ }
288
+ with PathManager.open(filename) as f:
289
+ content = f.read()
290
+ # Compile first with filename to:
291
+ # 1. make filename appears in stacktrace
292
+ # 2. make load_rel able to find its parent's (possibly remote) location
293
+ exec(compile(content, filename, "exec"), module_namespace)
294
+
295
+ ret = module_namespace
296
+ else:
297
+ with PathManager.open(filename) as f:
298
+ obj = yaml.unsafe_load(f)
299
+ ret = OmegaConf.create(obj, flags={"allow_objects": True})
300
+
301
+ if has_keys:
302
+ if isinstance(keys, str):
303
+ return _cast_to_config(ret[keys])
304
+ else:
305
+ return tuple(_cast_to_config(ret[a]) for a in keys)
306
+ else:
307
+ if filename.endswith(".py"):
308
+ # when not specified, only load those that are config objects
309
+ ret = DictConfig(
310
+ {
311
+ name: _cast_to_config(value)
312
+ for name, value in ret.items()
313
+ if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_")
314
+ },
315
+ flags={"allow_objects": True},
316
+ )
317
+ return ret
318
+
319
+     @staticmethod
+     def save_pkl(cfg, filename: str) -> str:
+         """
+         Saves a Config object to a file using pickle serialization. This method is typically used
+         when the configuration object contains complex objects, such as lambdas, that are not supported by
+         simpler serialization methods like YAML. The function attempts to create a deep copy of the
+         configuration object before serialization to ensure that the original object remains unmodified.
+
+         Args:
+             cfg: A Config object to be serialized and saved.
+             filename: The path and name of the file where the configuration should be saved. The function
+                 assumes the file extension indicates a pickle format (e.g., .pkl).
+
+         Returns:
+             str: The filename to which the configuration was saved. This can be used to verify the file
+                 location or log the outcome.
+
+         Notes:
+             - The function logs a warning if the configuration is successfully saved using pickle.
+             - If saving fails, an error is logged with the exception details.
+         """
+         try:
+             cfg = deepcopy(cfg)
+         except Exception:
+             pass
+
+         try:
+             with PathManager.open(filename, "wb") as f:
+                 pickle.dump(cfg, f)
+             log.warning(f"Config is saved using pickle at {filename}.")
+         except Exception as e:
+             log.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead")
+             if dill_pickle:
+                 try:
+                     with PathManager.open(filename, "wb") as f:
+                         pickle.dump(dill_pickle.dumps(cfg, recurse=True), f)
+                     log.warning(f"Config is saved using dill at {filename}.")
+                 except Exception as e:
+                     log.error(f"Failed to save config to {filename}: {e}.")
+                     if cloudpickle:
+                         try:
+                             with PathManager.open(filename, "wb") as f:
+                                 pickle.dump(cloudpickle.dumps(cfg), f)
+                             log.warning(f"Config is saved using cloudpickle at {filename}.")
+                         except Exception as e:
+                             log.error(f"Failed to save config to {filename}: {e}.")
+                     else:
+                         log.error("cloudpickle is not available. Cannot save the config.")
+                         raise e
+
+         return filename
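Because the dill/cloudpickle fallbacks pickle a serialized bytes blob rather than the config object itself, reading such a file back is a two-step process. A minimal loading sketch (illustrative; assumes the dill fallback was the one that succeeded, and the file name is a placeholder):

    import pickle

    import dill

    with open("config.pkl", "rb") as f:
        blob = pickle.load(f)  # bytes produced by dill.dumps(cfg, recurse=True)
    cfg = dill.loads(blob)     # reconstruct the original config object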
+
+     @staticmethod
+     def save_yaml(cfg, filename: str) -> str:
+         """
+         Saves a Config object to a file using YAML serialization. This method is beneficial when the
+         configuration object's content needs to be human-readable and easily editable. YAML is suitable
+         for configurations that do not contain complex types like lambdas, which must be handled
+         differently. The function converts unserializable items to strings before saving to ensure
+         compatibility with YAML serialization.
+
+         Args:
+             cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types.
+             filename: The path and name of the file where the configuration should be saved. The function
+                 does not require a specific file extension but typically uses '.yaml'.
+
+         Returns:
+             str: The filename to which the configuration was saved. This can be used to verify the file
+                 location or log the outcome.
+
+         Notes:
+             - The function logs a warning if the configuration is successfully saved using YAML.
+             - If saving fails, an error is logged with the exception details.
+         """
+         logger = logging.getLogger(__name__)
+         try:
+             cfg = deepcopy(cfg)
+         except Exception:
+             pass
+
+         # Check whether an item is serializable to YAML.
+         def is_serializable(item):
+             try:
+                 OmegaConf.to_yaml(item)
+                 return True
+             except Exception:
+                 return False
+
+         # Convert unserializable items to strings.
+         def serialize_config(config):
+             if isinstance(config, DictConfig):
+                 for key, value in config.items():
+                     if isinstance(value, (DictConfig, ListConfig)):
+                         try:
+                             if "_target_" in value:
+                                 default_params = get_default_params(value["_target_"])
+                                 for default_key, default_v in default_params.items():
+                                     if default_key not in value:
+                                         value[default_key] = default_v
+                         except Exception as e:
+                             log.error(f"Failed to add default argument values: {e}")
+
+                         serialize_config(value)
+                     else:
+                         if not is_serializable(value) and value is not None:
+                             config[key] = str(value)
+             elif isinstance(config, ListConfig):
+                 for i, item in enumerate(config):
+                     if isinstance(item, (DictConfig, ListConfig)):
+                         serialize_config(item)
+                     else:
+                         if not is_serializable(item) and item is not None:
+                             config[i] = str(item)
+             else:
+                 raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
+             return config
+
+         # Convert the Config object to a DictConfig object.
+         config_dict = attrs.asdict(cfg)
+         config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
+
+         # Serialize the DictConfig object by converting non-serializable objects to strings.
+         config_omegaconf = serialize_config(config_omegaconf)
+
+         config_dict: dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)
+         sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict)
+         with open(filename, "w") as f:
+             yaml.dump(sorted_config, f, default_flow_style=False)
+         log.warning(f"Config is saved using omegaconf at {filename}.")
+         return filename
imaginaire/lazy_config/omegaconf_patch.py ADDED
@@ -0,0 +1,65 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any
+
+ from omegaconf import OmegaConf
+ from omegaconf.base import DictKeyType, SCMode
+ from omegaconf.dictconfig import DictConfig  # pragma: no cover
+
+
+ def to_object(cfg: Any) -> dict[DictKeyType, Any] | list[Any] | None | str | Any:
+     """
+     Converts an OmegaConf configuration object to a native Python container (dict or list), unless
+     the configuration is specifically created by LazyCall, in which case the original configuration
+     is returned directly.
+
+     This function serves as a modification of the original `to_object` method from OmegaConf,
+     preventing DictConfig objects created by LazyCall from being automatically converted to Python
+     dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
+     structure and behavior.
+
+     Differences from OmegaConf's original `to_object`:
+     - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.
+
+     Reference:
+     - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595
+
+     Args:
+         cfg (Any): The OmegaConf configuration object to convert.
+
+     Returns:
+         Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
+         `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.
+
+     Examples:
+         >>> cfg = DictConfig({"key": "value", "_target_": "Model"})
+         >>> to_object(cfg)
+         DictConfig({"key": "value", "_target_": "Model"})
+
+         >>> cfg = DictConfig({"list": [1, 2, 3]})
+         >>> to_object(cfg)
+         {'list': [1, 2, 3]}
+     """
+     if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
+         return cfg
+
+     return OmegaConf.to_container(
+         cfg=cfg,
+         resolve=True,
+         throw_on_missing=True,
+         enum_to_str=False,
+         structured_config_mode=SCMode.INSTANTIATE,
+     )
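A behavior sketch (illustrative; the field names are arbitrary): unlike the stock OmegaConf.to_object, this variant leaves LazyCall-style configs untouched:

    from omegaconf import DictConfig

    lazy_cfg = DictConfig({"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 1})
    plain_cfg = DictConfig({"lr": 0.1})

    assert to_object(lazy_cfg) is lazy_cfg       # kept as a DictConfig for lazy instantiation
    assert to_object(plain_cfg) == {"lr": 0.1}   # converted to a native dict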
imaginaire/lazy_config/registry.py ADDED
@@ -0,0 +1,74 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import pydoc
+ from typing import Any
+
+ from fvcore.common.registry import Registry  # for backward compatibility.
+
+ """
+ ``Registry`` and ``locate`` provide ways to map a string (typically found
+ in config files) to callable objects.
+ """
+
+ __all__ = ["Registry", "locate"]
+
+
+ def _convert_target_to_string(t: Any) -> str:
+     """
+     Inverse of ``locate()``.
+
+     Args:
+         t: any object with ``__module__`` and ``__qualname__``
+     """
+     module, qualname = t.__module__, t.__qualname__
+
+     # Compress the path to this object, e.g. ``module.submodule._impl.class``
+     # may become ``module.submodule.class``, if the latter also resolves to the same
+     # object. This simplifies the string, and also is less affected by moving the
+     # class implementation.
+     module_parts = module.split(".")
+     for k in range(1, len(module_parts)):
+         prefix = ".".join(module_parts[:k])
+         candidate = f"{prefix}.{qualname}"
+         try:
+             if locate(candidate) is t:
+                 return candidate
+         except ImportError:
+             pass
+     return f"{module}.{qualname}"
+
+
+ def locate(name: str) -> Any:
+     """
+     Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
+     such as "module.submodule.class_name".
+
+     Raise Exception if it cannot be found.
+     """
+     obj = pydoc.locate(name)
+
+     # Some cases (e.g. torch.optim.sgd.SGD) are not handled correctly
+     # by pydoc.locate. Try a private function from hydra.
+     if obj is None:
+         try:
+             # from hydra.utils import get_method - will print many errors
+             from hydra.utils import _locate
+         except ImportError as e:
+             raise ImportError(f"Cannot dynamically locate object {name}!") from e
+         else:
+             obj = _locate(name)  # it raises if it fails
+
+     return obj
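A usage sketch (illustrative): locate resolves a dotted path to the object it names, and _convert_target_to_string is its inverse, possibly shortening the module path:

    from torch.optim import SGD

    assert locate("torch.optim.SGD") is SGD
    assert locate(_convert_target_to_string(SGD)) is SGD  # round trip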
imaginaire/model.py ADDED
@@ -0,0 +1,137 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any
+
+ import torch
+
+ from imaginaire.lazy_config import LazyDict, instantiate
+
+
+ class ImaginaireModel(torch.nn.Module):
+     """The base model class of Imaginaire. It is inherited from torch.nn.Module.
+
+     All models in Imaginaire should inherit ImaginaireModel. It should include the implementations for all the
+     computation graphs. All inheriting child classes should implement the following methods:
+     - training_step(): The training step of the model, including the loss computation.
+     - validation_step(): The validation step of the model, including the loss computation.
+     - forward(): The computation graph for model inference.
+     The following methods have default implementations in ImaginaireModel:
+     - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
+     """
+
+     def __init__(self) -> None:
+         super().__init__()
+
+     def init_optimizer_scheduler(
+         self,
+         optimizer_config: LazyDict[torch.optim.Optimizer],
+         scheduler_config: LazyDict[torch.optim.lr_scheduler.LRScheduler],
+     ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
+         """Creates the optimizer and scheduler for the model.
+
+         Args:
+             optimizer_config (LazyDict): The config object for the optimizer.
+             scheduler_config (LazyDict): The config object for the scheduler.
+
+         Returns:
+             optimizer (torch.optim.Optimizer): The model optimizer.
+             scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
+         """
+         optimizer_config.params = self.parameters()
+         optimizer = instantiate(optimizer_config)
+         scheduler_config.optimizer = optimizer
+         scheduler = instantiate(scheduler_config)
+         return optimizer, scheduler
+
+     def training_step(
+         self, data_batch: dict[str, torch.Tensor], iteration: int
+     ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
+         """The training step of the model, including the loss computation.
+
+         Args:
+             data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
+             iteration (int): Current iteration number.
+
+         Returns:
+             output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
+             loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
+         """
+         raise NotImplementedError
+
+     @torch.no_grad()
+     def validation_step(
+         self, data_batch: dict[str, torch.Tensor], iteration: int
+     ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
+         """The validation step of the model, including the loss computation.
+
+         Args:
+             data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
+             iteration (int): Current iteration number.
+
+         Returns:
+             output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
+             loss (torch.Tensor): The total loss (weighted sum of various losses).
+         """
+         raise NotImplementedError
+
+     @torch.inference_mode()
+     def forward(self, *args: Any, **kwargs: Any) -> Any:
+         """The computation graph for model inference.
+
+         Args:
+             *args: Whatever you decide to pass into the forward method.
+             **kwargs: Keyword arguments are also possible.
+
+         Return:
+             Your model's output.
+         """
+         raise NotImplementedError
+
+     def on_model_init_start(self, set_barrier=False) -> None:
+         return
+
+     def on_model_init_end(self, set_barrier=False) -> None:
+         return
+
+     def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
+         """The model preparation before the training is launched.
+
+         Args:
+             memory_format (torch.memory_format): Memory format of the model.
+         """
+         pass
+
+     def on_before_zero_grad(
+         self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
+     ) -> None:
+         """Hook before zero_grad() is called.
+
+         Args:
+             optimizer (torch.optim.Optimizer): The model optimizer.
+             scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
+             iteration (int): Current iteration number.
+         """
+         pass
+
+     def on_after_backward(self, iteration: int = 0) -> None:
+         """Hook after loss.backward() is called.
+
+         This method is called immediately after the backward pass, allowing for custom operations
+         or modifications to be performed on the gradients before the optimizer step.
+
+         Args:
+             iteration (int): Current iteration number.
+         """
+         pass
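A minimal subclass sketch (illustrative; the linear layer, MSE loss, and the "x"/"y" batch keys are placeholders, not part of this commit):

    import torch

    from imaginaire.model import ImaginaireModel


    class ToyRegressor(ImaginaireModel):
        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Linear(8, 1)

        def training_step(self, data_batch, iteration):
            pred = self.net(data_batch["x"])
            loss = torch.nn.functional.mse_loss(pred, data_batch["y"])
            return {"pred": pred}, loss  # (auxiliary outputs, total loss)

        @torch.no_grad()
        def validation_step(self, data_batch, iteration):
            return self.training_step(data_batch, iteration)

        @torch.inference_mode()
        def forward(self, x):
            return self.net(x)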
imaginaire/trainer.py ADDED
@@ -0,0 +1,322 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ import inspect
+ import os
+ import signal
+
+ import torch
+ import torch.distributed as dist
+ import torch.utils.data
+
+ from imaginaire.utils.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
+
+ try:
+     from megatron.core import parallel_state
+
+     USE_MEGATRON = True
+ except ImportError:
+     USE_MEGATRON = False
+     print("Megatron-core is not installed.")
+
+
+ from imaginaire.lazy_config import LazyConfig, instantiate
+ from imaginaire.model import ImaginaireModel
+ from imaginaire.utils import callback, distributed, log, misc
+ from imaginaire.utils.checkpointer import Checkpointer
+
+
+ class ImaginaireTrainer:
+     """The base trainer class of Imaginaire.
+
+     All trainers in Imaginaire should inherit ImaginaireTrainer. It contains the basic functionality for model training
+     (particularly suited for large-scale training), including data parallelism (DDP/FSDP), model weight averaging (EMA),
+     and mixed-precision training (fp16/bf16).
+
+     Attributes:
+         checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
+         training_timer (misc.Timer): Timer object to time code blocks and functions.
+     """
+
+     def __init__(self, config):
+         """Constructor of the trainer.
+
+         Args:
+             config (Config): The config object for the Imaginaire codebase.
+         """
+         super().__init__()
+         self.config = config
+         # Set up the distributed computing environment.
+         with misc.timer("init_distributed"):
+             distributed.init()
+         # Set up parallel states.
+         if hasattr(config.model, "context_parallel_size"):
+             if config.model_parallel.context_parallel_size > 1:
+                 raise ValueError(
+                     "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
+                     "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
+                 )
+             else:
+                 log.critical(
+                     "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
+                 )
+                 config.model_parallel.context_parallel_size = config.model.context_parallel_size
+         if USE_MEGATRON:
+             if (
+                 "create_gloo_process_groups"
+                 in inspect.signature(parallel_state.initialize_model_parallel).parameters
+             ):
+                 parallel_state.initialize_model_parallel(
+                     pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
+                     tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
+                     context_parallel_size=config.model_parallel.context_parallel_size,
+                     create_gloo_process_groups=False,
+                 )
+             else:
+                 parallel_state.initialize_model_parallel(
+                     pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
+                     tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
+                     context_parallel_size=config.model_parallel.context_parallel_size,
+                 )
+             # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
+             # It is not part of the original `parallel_state` API, so we need to set it manually.
+             parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
+             if parallel_state.sequence_parallel:
+                 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+
+         # Create the local job directory, save the config file, and pipe to a local log.
+         if distributed.is_rank0():
+             os.makedirs(config.job.path_local, exist_ok=True)
+             # Save the config as .pkl for reproducibility.
+             LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
+             # Save the config as .yaml for reading or parsing experiment hyperparameters.
+             LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
+         dist.barrier()
+         log.init_loguru_file(f"{config.job.path_local}/stdout.log")
+         if distributed.is_rank0():
+             # Print important environment variables and the effective config.
+             log.info("Config:\n" + config.pretty_print(use_color=True))
+             misc.print_environ_variables(["TORCH_HOME", "IMAGINAIRE_OUTPUT_ROOT"])
+         # Set the random seed. If multi-GPU, different ranks are set with different seeds.
+         misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
+         # Initialize cuDNN.
+         torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
+         torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
+         # Floating-point precision settings.
+         torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
+         # Initialize the callback functions.
+         self.callbacks = callback.CallBackGroup(config=config, trainer=self)
+         # Initialize the model checkpointer.
+         if config.checkpoint.type is None:
+             self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
+         else:
+             self.checkpointer: Checkpointer = instantiate(
+                 config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
+             )
+         # Initialize the timer for speed benchmarking.
+         self.training_timer = misc.TrainingTimer()
+         # Send a TimeoutError if a training step takes over timeout_period seconds.
+         signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period))  # type: ignore
+
+     def train(
+         self,
+         model: ImaginaireModel,
+         dataloader_train: torch.utils.data.DataLoader,
+         dataloader_val: torch.utils.data.DataLoader,
+     ) -> None:
+         """The training function.
+
+         Args:
+             model (ImaginaireModel): The PyTorch model.
+             dataloader_train (torch.utils.data.DataLoader): The training data loader.
+             dataloader_val (torch.utils.data.DataLoader): The validation data loader.
+         """
+         # Leaving this for backward compatibility for now, but we can think about moving this to
+         # model.on_train_start for all models.
+         model = model.to("cuda", memory_format=self.config.trainer.memory_format)  # type: ignore
+         model.on_train_start(self.config.trainer.memory_format)
+
+         # Initialize the optimizer, scheduler, and grad_scaler.
+         self.callbacks.on_optimizer_init_start()
+         optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
+         grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
+         self.callbacks.on_optimizer_init_end()
+         # Load the model checkpoint and get the starting iteration number.
+         iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
+         grad_accum_iter = 0
+         log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
+         if self.config.trainer.distributed_parallelism == "ddp":
+             # Create a DDP model wrapper.
+             model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
+         elif self.config.trainer.distributed_parallelism == "fsdp":
+             model_ddp = model
+         else:
+             raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
+         log.info("Starting training...")
+         self.callbacks.on_train_start(model, iteration=iteration)
+         # Initial validation.
+         if self.config.trainer.run_validation and iteration == 0:
+             self.validate(model, dataloader_val, iteration=iteration)
+             log.info("Initial validation done.")
+         _end_training = False
+         with (
+             maybe_enable_profiling(self.config, global_step=iteration) as torch_profiler,
+             maybe_enable_memory_snapshot(self.config, global_step=iteration) as memory_profiler,
+         ):
+             while True:
+                 dataloader_train_iter = iter(dataloader_train)
+                 while True:
+                     self.callbacks.on_before_dataloading(iteration)
+                     try:
+                         with self.training_timer("dataloader_train"):
+                             data_batch = next(dataloader_train_iter)
+                     except StopIteration:
+                         break
+                     finally:
+                         self.callbacks.on_after_dataloading(iteration)
+                     # If max_iter is reached, exit the training loop.
+                     if iteration >= self.config.trainer.max_iter:
+                         _end_training = True
+                         break
+                     # Move all tensors in the data batch to the GPU device.
+                     data_batch = misc.to(data_batch, device="cuda")
+                     # The actual training step.
+                     self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
+                     self.callbacks.on_training_step_batch_start(model, data_batch, iteration=iteration)
+                     if not model.training:
+                         model_ddp.train()
+                     assert model_ddp.training, "model_ddp is not in training mode."
+                     assert model.training, "model is not in training mode."
+                     output_batch, loss, grad_accum_iter = self.training_step(
+                         model_ddp,
+                         optimizer,
+                         scheduler,
+                         grad_scaler,
+                         data_batch,
+                         iteration=iteration,
+                         grad_accum_iter=grad_accum_iter,
+                     )
+                     self.callbacks.on_training_step_batch_end(
+                         model, data_batch, output_batch, loss, iteration=iteration
+                     )
+                     # If the gradients are still being accumulated, continue to load the next training batch.
+                     if grad_accum_iter != 0:
+                         continue
+                     # Do the following when an actual optimizer (update) step has been made.
+                     iteration += 1
+                     # Save checkpoint.
+                     if iteration % self.config.checkpoint.save_iter == 0:
+                         self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
+                     self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
+                     # Validation.
+                     if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
+                         self.validate(model, dataloader_val, iteration=iteration)
+                     # This iteration is successful; reset the timeout signal.
+                     signal.alarm(self.config.trainer.timeout_period)
+                     if torch_profiler:
+                         torch_profiler.step()
+                     if memory_profiler:
+                         memory_profiler.step()
+                 if _end_training:
+                     break
+         log.success("Done with training.")
+         if iteration % self.config.checkpoint.save_iter != 0:
+             self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
+         self.callbacks.on_train_end(model, iteration=iteration)
+         self.checkpointer.finalize()
+         distributed.barrier()
+         self.callbacks.on_app_end()
+
+     def training_step(
+         self,
+         model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         grad_scaler: torch.amp.GradScaler,
+         data: dict[str, torch.Tensor],
+         iteration: int = 0,
+         grad_accum_iter: int = 0,
+     ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
+         """The training step.
+
+         Args:
+             model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper,
+                 or the bare module, depending on whether distributed training is enabled.
+             optimizer (torch.optim.Optimizer): The model optimizer.
+             scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
+             grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
+             data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
+             iteration (int): Current iteration number.
+             grad_accum_iter (int): Number of gradient accumulation iterations.
+
+         Returns:
+             output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
+             loss (torch.Tensor): The total loss of the training data batch.
+             grad_accum_iter (int): The updated gradient accumulation counter (0 right after an optimizer step).
+         """
+         # Only let DDP sync gradients at the last iteration of the gradient accumulation window.
+         with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
+             self.callbacks.on_before_forward(iteration=iteration)
+             with self.training_timer("forward"):
+                 output_batch, loss = model_ddp.training_step(data, iteration)
+             self.callbacks.on_after_forward(iteration=iteration)
+             self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
+             with self.training_timer("backward"):
+                 loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
+                 loss_scaled.backward()
+                 if self.config.trainer.distributed_parallelism == "ddp":
+                     model_ddp.module.on_after_backward()
+                 else:
+                     model_ddp.on_after_backward()
+             self.callbacks.on_after_backward(model_ddp, iteration=iteration)
+         grad_accum_iter += 1
+         if grad_accum_iter == self.config.trainer.grad_accum_iter:
+             with self.training_timer("optimizer_step"):
+                 self.callbacks.on_before_optimizer_step(
+                     model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
+                 )
+                 grad_scaler.step(optimizer)
+                 grad_scaler.update()
+                 scheduler.step()
+                 self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
+                 if self.config.trainer.distributed_parallelism == "ddp":
+                     model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
+                 else:
+                     model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
+                 optimizer.zero_grad(set_to_none=True)
+             grad_accum_iter = 0
+         return output_batch, loss, grad_accum_iter
+
+     @torch.no_grad()
+     def validate(self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
+         """Validate on the full validation dataset.
+
+         Args:
+             model (ImaginaireModel): The PyTorch model.
+             dataloader_val (torch.utils.data.DataLoader): The validation data loader.
+             iteration (int): Current iteration number.
+         """
+         log.info(f"Validating at iteration {iteration}...")
+         self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
+         model.eval()
+         # Evaluate on the full validation set.
+         with model.pipe.ema_scope(context="Validation", is_cpu=False):
+             for val_iter, data_batch in enumerate(dataloader_val):
+                 if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
+                     break
+                 data_batch = misc.to(data_batch, device="cuda")
+                 self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
+                 output_batch, loss = model.validation_step(data_batch, iteration)
+                 self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
+         self.callbacks.on_validation_end(model, iteration=iteration)
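For clarity, the gradient-accumulation arithmetic in training_step(): each micro-batch backpropagates loss / grad_accum_iter, so after grad_accum_iter micro-batches the accumulated gradient equals the gradient of the mean loss, and only then does the optimizer step. The same pattern in isolation (a sketch; model, optimizer, and batches are placeholders):

    grad_accum_iter = 4
    for step, batch in enumerate(batches):
        loss = model(batch)
        (loss / grad_accum_iter).backward()  # gradients sum across micro-batches
        if (step + 1) % grad_accum_iter == 0:
            optimizer.step()                 # one update per accumulation window
            optimizer.zero_grad(set_to_none=True)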
imaginaire/utils/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
imaginaire/utils/callback.py ADDED
@@ -0,0 +1,518 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import time
+ import warnings
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any
+
+ import omegaconf
+ import torch
+ import torch.utils.data
+ import tqdm
+
+ from imaginaire.lazy_config import instantiate
+ from imaginaire.utils import distributed, log
+ from imaginaire.utils.misc import get_local_tensor_if_DTensor
+
+ try:
+     from megatron.core import parallel_state
+ except ImportError:
+     parallel_state = None
+     print("Megatron-core is not installed.")
+
+
+ if TYPE_CHECKING:
+     from imaginaire.config import Config
+     from imaginaire.model import ImaginaireModel
+     from imaginaire.trainer import ImaginaireTrainer
+
+
+ class CallBackGroup:
+     """A class for hosting a collection of callback objects.
+
+     It is used to execute callback functions of multiple callback objects with the same method name.
+     When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
+     self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.
+
+     Attributes:
+         _callbacks (list[Callback]): List of callback objects.
+     """
+
+     def __init__(self, config: Config, trainer: ImaginaireTrainer) -> None:
+         """Initializes the list of callback objects.
+
+         Args:
+             config (Config): The config object for the Imaginaire codebase.
+             trainer (ImaginaireTrainer): The main trainer.
+         """
+         self._callbacks = []
+         callback_configs = config.trainer.callbacks
+         if callback_configs:
+             if isinstance(callback_configs, (list, omegaconf.listconfig.ListConfig)):
+                 warnings.warn(
+                     "The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
+                     "Please update your code.",
+                     DeprecationWarning,
+                     stacklevel=2,
+                 )
+                 callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
+             for callback_name, current_callback_cfg in callback_configs.items():
+                 if "_target_" not in current_callback_cfg:
+                     log.critical(
+                         f"Callback {callback_name} is missing the '_target_' field. Skipping {current_callback_cfg}"
+                     )
+                     continue
+                 log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}")
+                 _callback = instantiate(current_callback_cfg)
+                 assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
+                 _callback.config = config
+                 _callback.trainer = trainer
+                 self._callbacks.append(_callback)
+
+     def __getattr__(self, method_name: str) -> Callable:
+         """Loops through the callback objects to call the corresponding callback function.
+
+         Args:
+             method_name (str): Callback method name.
+         """
+
+         def multi_callback_wrapper(*args, **kwargs) -> None:
+             for callback in self._callbacks:
+                 assert hasattr(callback, method_name)
+                 method = getattr(callback, method_name)
+                 assert callable(method)
+                 _ = method(*args, **kwargs)
+
+         return multi_callback_wrapper
+
+
+ class Callback:
+     """The base class for all callbacks.
+
+     All callbacks should inherit from this class and adhere to the established method names and signatures.
+     """
+
+     def __init__(self, config: Config | None = None, trainer: ImaginaireTrainer | None = None):
+         """Initializes a Callback object.
+
+         Args:
+             config (Optional[Config]): The configuration object for the Imaginaire codebase, if available.
+             trainer (Optional[ImaginaireTrainer]): The main trainer handling the training loop, if available.
+
+         Notes:
+             The config and trainer parameters are optional to maintain backward compatibility.
+             In future releases, these parameters will be removed. Upon using these parameters, a deprecation
+             warning will be issued.
+         """
+         if config is not None or trainer is not None:
+             warnings.warn(
+                 "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. "
+                 "Please update your code to create Callback instances without these parameters.",
+                 DeprecationWarning,
+                 stacklevel=2,
+             )
+         del config, trainer
+
+     def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         pass
+
+     def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
+         """
+         Called before the training step, for each batch. This is paired with on_training_step_end(), but note that
+         when using gradient accumulation, while on_training_step_end() is only called when the optimizer is updated,
+         this function is called for every batch.
+         Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
+         for every batch, albeit with the same iteration number.
+         """
+         pass
+
+     def on_training_step_batch_start(
+         self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
+     ) -> None:
+         """
+         Called before the training step, for each batch, similarly to on_training_step_start(). This function is paired with
+         on_training_step_batch_end(), and both functions are called for every batch even when using gradient accumulation.
+         Note that the iteration is only updated when the optimizer is updated, and therefore it may be the same for
+         multiple invocations.
+         """
+         pass
+
+     def on_before_forward(self, iteration: int = 0) -> None:
+         pass
+
+     def on_after_forward(self, iteration: int = 0) -> None:
+         pass
+
+     def on_before_backward(
+         self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
+     ) -> None:
+         pass
+
+     def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
+         pass
+
+     def on_before_dataloading(self, iteration: int = 0) -> None:
+         pass
+
+     def on_after_dataloading(self, iteration: int = 0) -> None:
+         pass
+
+     def on_optimizer_init_start(self) -> None:
+         pass
+
+     def on_optimizer_init_end(self) -> None:
+         pass
+
+     def on_before_optimizer_step(
+         self,
+         model_ddp: distributed.DistributedDataParallel,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         grad_scaler: torch.amp.GradScaler,
+         iteration: int = 0,
+     ) -> None:
+         pass
+
+     def on_before_zero_grad(
+         self,
+         model_ddp: distributed.DistributedDataParallel,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         iteration: int = 0,
+     ) -> None:
+         pass
+
+     def on_training_step_batch_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         """
+         Called at the end of a training step for every batch, even when using gradient accumulation.
+         This is paired with on_training_step_batch_start(). Note that the iteration is only updated when the
+         optimizer is updated, and therefore it may be the same for multiple batches.
+         """
+         pass
+
+     def on_training_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         """
+         Called at the end of a training step, but note that when using gradient accumulation, this is only called
+         when the optimizer is updated and the iteration incremented, whereas on_training_step_start is called every time.
+         Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
+         for every batch.
+         """
+         pass
+
+     def on_validation_start(
+         self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
+     ) -> None:
+         pass
+
+     def on_validation_step_start(
+         self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
+     ) -> None:
+         pass
+
+     def on_validation_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         pass
+
+     def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         pass
+
+     def on_load_checkpoint_start(self, model: ImaginaireModel) -> None:
+         pass
+
+     def on_load_checkpoint_end(
+         self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: str | None = None
+     ) -> None:
+         pass
+
+     def on_load_checkpoint(self, model: ImaginaireModel, state_dict: dict[str, Any]) -> None:
+         pass
+
+     def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         """
+         Called when checkpoint saving is about to start.
+         """
+         pass
+
+     def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         """
+         Called when the synchronous part of checkpointing is finished. This function can be used
+         along with on_save_checkpoint_start() to measure the exposed (synchronous) checkpoint time.
+         Note that for asynchronous checkpoints, the checkpoint may still be ongoing, so this function
+         does not mean the checkpoint is finished in the asynchronous case; use on_save_checkpoint_success()
+         for that.
+         """
+         pass
+
+     def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None:
+         """
+         Called when checkpoint saving has fully finished and succeeded. Not called if the checkpoint failed.
+         For synchronous checkpoints, it is called at the same time as on_save_checkpoint_end(), but for asynchronous
+         checkpoints, it is called after the asynchronous part has also finished. For checkpointers with out-of-process
+         checkpointing, this function is called as soon as the notification is received from the checkpointer process,
+         which may not be immediately after the checkpoint has completed but later on. Therefore, if you need to measure
+         the full checkpoint duration for the asynchronous part, use the elapsed_time parameter; do not measure it directly,
+         as this would be a significant overestimate.
+         """
+         pass
+
+     def on_save_checkpoint(self, model: ImaginaireModel, state_dict: dict[str, Any]) -> None:
+         pass
+
+     def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         pass
+
+     def on_app_end(self) -> None:
+         pass
+
+
+ class EMAModelCallback(Callback):
+     """The callback class for tracking EMA model weights."""
+
+     def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         # Set up the EMA model weight tracker.
+         if model.config.ema.enabled:
+             assert hasattr(model, "ema"), "EMA should be initialized from ImaginaireModel"
+             # EMA model must be kept in FP32 precision.
+             model.ema = model.ema.to(dtype=torch.float32)
+         else:
+             assert not hasattr(model, "ema"), "There should be no EMA initialized."
+
+     def on_training_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         # Update the EMA model with the new regular weights.
+         if model.config.ema.enabled:
+             model.ema.update_average(model, iteration)
+
+
+ class ProgressBarCallback(Callback):
+     """The callback class for visualizing the training/validation progress bar in the console."""
+
+     @distributed.rank0_only
+     def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
+
+     @distributed.rank0_only
+     def on_training_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         self.train_pbar.update()
+
+     @distributed.rank0_only
+     def on_validation_start(
+         self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
+     ) -> None:
+         if self.config.trainer.max_val_iter is not None:
+             num_iter = self.config.trainer.max_val_iter
+         else:
+             num_iter = len(dataloader_val)
+         assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}"
+         self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False)
+
+     @distributed.rank0_only
+     def on_validation_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         self.val_pbar.update()
+
+     @distributed.rank0_only
+     def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         self.val_pbar.close()
+
+     @distributed.rank0_only
+     def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         self.trainer.checkpointer.finalize()
+         self.train_pbar.close()
+
+
+ class IterationLoggerCallback(Callback):
+     """The callback class for logging the per-iteration time and loss to the console."""
+
+     @distributed.rank0_only
+     def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         # self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
+         self.start_iteration_time = time.time()
+         self.elapsed_iteration_time = 0
+
+     @distributed.rank0_only
+     def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
+         self.start_iteration_time = time.time()
+
+     @distributed.rank0_only
+     def on_training_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         self.elapsed_iteration_time += time.time() - self.start_iteration_time
+
+         if iteration % self.config.trainer.logging_iter == 0:
+             avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter
+             log.info(f"Iteration: {iteration}, average iter time: {avg_time:.2f}, total loss {loss.item():.4f}")
+
+             self.elapsed_iteration_time = 0
+
+
+ class LowPrecisionCallback(Callback):
+     """The callback class handling low-precision training.
+
+     A config field with a non-primitive type is difficult to override, so the callback gets the
+     precision from model.precision instead. It is also automatically disabled when using fp32.
+     """
+
+     def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int):
+         self.update_iter = update_iter
+
+     def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
+         assert model.precision in [
+             torch.bfloat16,
+             torch.float16,
+             torch.half,
+         ], "LowPrecisionCallback must use a low precision dtype."
+         self.precision_type = model.precision
+
+     def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
+         for k, v in data.items():
+             if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
+                 data[k] = v.to(dtype=self.precision_type)
+
+     def on_validation_step_start(
+         self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
+     ) -> None:
+         for k, v in data.items():
+             if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
+                 data[k] = v.to(dtype=self.precision_type)
+
+     def on_before_zero_grad(
+         self,
+         model_ddp: distributed.DistributedDataParallel,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         iteration: int = 0,
+     ) -> None:
+         if iteration % self.update_iter == 0:
+             if getattr(optimizer, "master_weights", False):
+                 params, master_params = [], []
+                 for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master, strict=False):
+                     for p, p_master in zip(group["params"], group_master["params"], strict=False):
+                         params.append(get_local_tensor_if_DTensor(p.data))
+                         master_params.append(p_master.data)
+                 torch._foreach_copy_(params, master_params)
+
+
+ class NVTXCallback(Callback):
+     """The callback for creating NVTX ranges"""
+
+     def __init__(
+         self,
+         synchronize: bool = False,
+         config: Config | None = None,
+         trainer: ImaginaireTrainer | None = None,
+     ):
+         super().__init__(config, trainer)
+         self.synchronize = synchronize
+
+     def on_before_forward(self, iteration: int = 0) -> None:
+         if self.synchronize:
+             torch.cuda.synchronize()
+         torch.cuda.nvtx.range_push("forward")
+
+     def on_after_forward(self, iteration: int = 0) -> None:
+         if self.synchronize:
+             torch.cuda.synchronize()
+         torch.cuda.nvtx.range_pop()
+
+     def on_before_backward(
+         self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
+     ) -> None:
+         if self.synchronize:
+             torch.cuda.synchronize()
+         torch.cuda.nvtx.range_push("backward")
+
+     def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
+         if self.synchronize:
+             torch.cuda.synchronize()
+         torch.cuda.nvtx.range_pop()
+
+     def on_before_optimizer_step(
+         self,
+         model_ddp: distributed.DistributedDataParallel,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         grad_scaler: torch.amp.GradScaler,
+         iteration: int = 0,
+     ) -> None:
+         if self.synchronize:
+             torch.cuda.synchronize()
+         torch.cuda.nvtx.range_push("optimizer_step")
+
+     def on_before_zero_grad(
+         self,
+         model_ddp: distributed.DistributedDataParallel,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         iteration: int = 0,
+     ) -> None:
+         if self.synchronize:
+             torch.cuda.synchronize()
+         torch.cuda.nvtx.range_pop()
+
+     def on_before_dataloading(self, iteration: int = 0) -> None:
+         torch.cuda.nvtx.range_push("dataloading")
+
+     def on_after_dataloading(self, iteration: int = 0) -> None:
+         torch.cuda.nvtx.range_pop()
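A configuration sketch (illustrative; the entry names and the synchronize flag value are arbitrary) showing how callbacks from this module could be registered through config.trainer.callbacks, which CallBackGroup instantiates via each entry's _target_:

    config.trainer.callbacks = {
        "iter_logger": {"_target_": "imaginaire.utils.callback.IterationLoggerCallback"},
        "nvtx": {"_target_": "imaginaire.utils.callback.NVTXCallback", "synchronize": True},
    }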
imaginaire/utils/checkpointer.py ADDED
@@ -0,0 +1,282 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import os
+ import threading
+ from typing import TYPE_CHECKING, NamedTuple
+
+ import torch
+ import torch.distributed as dist
+ from torch import nn
+
+ from imaginaire.model import ImaginaireModel
+ from imaginaire.utils import callback, distributed, log, misc
+ from imaginaire.utils.parallelism import ModelWrapper
+
+ if TYPE_CHECKING:
+     from imaginaire.config import CheckpointConfig, JobConfig
+
+
+ class Checkpointer:
+     """The checkpointer class. Supports checkpoint saving/loading to local disk."""
+
+     def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
+         """Constructor of the checkpointer.
+
+         Args:
+             config_checkpoint (CheckpointConfig): The config object for the checkpointer.
+             config_job (JobConfig): The config object for the job.
+             callbacks (callback.CallBackGroup): The callback group to notify about checkpoint events.
+         """
+         # Set the callback functions.
+         self.callbacks = callbacks
+         self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
+         self.strict_resume = config_checkpoint.strict_resume
+         self.load_path = config_checkpoint.load_path or None
+         self.load_training_state = config_checkpoint.load_training_state
+         self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
+         self.save_thread = None
+
+     def save(
+         self,
+         model: ImaginaireModel,
+         optimizer: torch.optim.Optimizer,
+         scheduler: torch.optim.lr_scheduler.LRScheduler,
+         grad_scaler: torch.amp.GradScaler,
+         iteration: int,
+     ) -> None:
+         """Save network weights, optimizer parameters, and scheduler parameters to a checkpoint.
+
+         Args:
+             model (ImaginaireModel): The PyTorch model.
+             optimizer (torch.optim.Optimizer): The model optimizer.
+             scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
+             grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
+             iteration (int): Current iteration number.
+         """
+         self.callbacks.on_save_checkpoint_start(model, iteration)
+
+         checkpoint_file = f"iter_{iteration:09}.pt"
+
+         if distributed.get_rank() == 0:
+             state_dict = dict(
+                 model=model.state_dict(),
+                 optimizer=optimizer.state_dict(),
+                 scheduler=scheduler.state_dict(),
+                 grad_scaler=grad_scaler.state_dict(),
+                 iteration=iteration,
+             )
+             state_dict = misc.to(state_dict, device="cpu")
+             self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
+             # Wait for the previous saver thread to end.
+             if self.save_thread:
+                 self.save_thread.join()
+             # Run the checkpoint saver in a separate thread.
+             self.save_thread = threading.Thread(
+                 target=self._save_worker_local,
+                 daemon=False,
+                 args=(state_dict, checkpoint_file, distributed.get_rank()),
+             )
+             self.save_thread.start()
+
+         # Note: Checkpoints are saved on a separate thread, so this callback is not accurate.
+         # Please check logs from on_save_checkpoint_success() for better accuracy.
+         self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)
+
+     @misc.timer("checkpoint saving (local)")
+     def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None:
+         """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).
+
+         Args:
+             state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
+             checkpoint_file (str): The file name of the model checkpoint.
+             rank (int): GPU device (default: 0).
+         """
+         checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
+         os.makedirs(self.checkpoint_dir_local, exist_ok=True)
+         try:
+             torch.save(state_dict, checkpoint_path)
+             if rank == 0:
+                 self._write_latest_checkpoint_file(checkpoint_file)
+             log.success(f"Saved checkpoint (local): {checkpoint_path}")
+             iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
+             self.callbacks.on_save_checkpoint_success(iteration=iteration)
+         except Exception as e:
+             log.exception(f"Checkpoint failed to save (local): {e}")
+
+ @misc.timer("checkpoint loading")
120
+ def load(
121
+ self,
122
+ model: ImaginaireModel,
123
+ optimizer: torch.optim.Optimizer | None = None,
124
+ scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
125
+ grad_scaler: torch.amp.GradScaler | None = None,
126
+ ) -> int:
127
+ """Load network weights and optimizer states from a checkpoint in a single process.
128
+
129
+ The priority of the checkpoint loading logic is:
130
+ 1. Attempt to resume training if possible by looking for latest_checkpoint.txt in the checkpoint directory.
131
+ 2. If no latest checkpoint is found, load the model weights specified by config_checkpoint.load_path.
132
+ - This is typically used for inference mode.
133
+ - If config_checkpoint.load_training_state is True, then also load the optimizer and scheduler states.
134
+ 3. If none of the above, randomly initialize the model parameters and train from scratch.
135
+
136
+ Args:
137
+ model (ImaginaireModel): The PyTorch model.
138
+ optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
139
+ scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
140
+ grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).
141
+
142
+ Returns:
143
+ iteration (int): the iteration number to start/resume from.
144
+ """
145
+ self.callbacks.on_load_checkpoint_start(model)
146
+
147
+ latest_checkpoint_file = self._read_latest_checkpoint_file()
148
+ if latest_checkpoint_file is not None:
149
+ # 1. Resume training from latest_checkpoint.txt in the checkpoint directory.
150
+ checkpoint_dir = self.checkpoint_dir_local
151
+ checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
152
+ resume = True
153
+ only_resume_scheduler = True
154
+ else:
155
+ if self.load_path:
156
+ # 2. Load the model weights specified by config_checkpoint.load_path.
157
+ checkpoint_path = self.load_path
158
+ resume = self.load_training_state
159
+ only_resume_scheduler = self.only_load_scheduler_state
160
+ else:
161
+ # 3. Randomly initialize the model parameters and train from scratch.
162
+ checkpoint_path = None
163
+ resume = False
164
+ only_resume_scheduler = False
165
+ # Load checkpoint.
166
+ if checkpoint_path is not None:
167
+ self._check_checkpoint_exists(checkpoint_path)
168
+ log.info(f"Loading checkpoint (local): {checkpoint_path}")
169
+ state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
170
+ log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
171
+ self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
172
+ # Load the state dicts.
173
+ log.info("- Loading the model...")
174
+ model.load_state_dict(state_dict["model"], strict=self.strict_resume)
175
+ if resume or only_resume_scheduler:
176
+ iteration = state_dict["iteration"]
177
+ assert scheduler
178
+ log.info("- Loading the scheduler...")
179
+ scheduler.load_state_dict(state_dict["scheduler"])
180
+ scheduler.last_epoch = iteration
181
+ else:
182
+ iteration = 0
183
+ if resume:
184
+ assert optimizer
185
+ log.info("- Loading the optimizer...")
186
+ optimizer.load_state_dict(state_dict["optimizer"])
187
+ log.info("- Loading the gradient scaler...")
188
+ grad_scaler.load_state_dict(state_dict["grad_scaler"])
189
+ log.success(f"Done with loading the checkpoint (iteration {iteration}).")
190
+ else:
191
+ log.success("Done with loading the checkpoint.")
192
+ else:
193
+ # Checkpoint not found and not specified. We will train everything from scratch.
194
+ iteration = 0
195
+ log.info("Training from scratch.")
196
+ torch.cuda.empty_cache()
197
+
198
+ self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)
199
+
200
+ return iteration
201
+
202
+ def _read_latest_checkpoint_file(self) -> str | None:
203
+ """Get the file name of the latest saved checkpoint. If it doesn't exist, return None.
204
+
205
+ Returns:
206
+ checkpoint_file (str | None): file name of the latest saved checkpoint.
207
+ """
208
+ checkpoint_file = None
209
+ latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
210
+ if os.path.isfile(latest_path):
211
+ with open(latest_path) as file:
+ checkpoint_file = file.read().strip()
212
+ return checkpoint_file
213
+
214
+ def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
215
+ """Track the file name of the latest saved checkpoint.
216
+
217
+ Args:
218
+ checkpoint_file (str): file name of the latest saved checkpoint.
219
+ """
220
+ content = f"{checkpoint_file}\n"
221
+ latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
222
+ with open(latest_path, "w") as file:
223
+ file.write(content)
224
+
225
+ def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
226
+ """If the file checkpoint_path does not exist, raise an error.
227
+
228
+ Args:
229
+ checkpoint_path (str): full path to the checkpoint.
230
+ """
231
+ if not os.path.exists(checkpoint_path):
232
+ raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
233
+
234
+ def finalize(self) -> None:
235
+ """Finalize the checkpointer."""
236
+ if self.save_thread:
237
+ self.save_thread.join()
238
+
239
+
240
+ class _IncompatibleKeys(
241
+ NamedTuple(
242
+ "IncompatibleKeys",
243
+ [
244
+ ("missing_keys", list[str]),
245
+ ("unexpected_keys", list[str]),
246
+ ("incorrect_shapes", list[tuple[str, tuple[int], tuple[int]]]),
247
+ ],
248
+ )
249
+ ):
250
+ pass
251
+
252
+
253
+ def load_checkpoint(
254
+ model_parts: list[nn.Module],
255
+ ckpt_dir,
256
+ model_ckpt_key_map: dict[str, str] = {}, # noqa: B006
257
+ ):
258
+ log.info(f"Loading checkpoint from {ckpt_dir}.")
259
+
260
+ _model_wrapper = ModelWrapper(model_parts)
261
+ state_dict = _model_wrapper.state_dict()
262
+ # remove _extra_state
263
+ state_dict = {k: v for k, v in state_dict.items() if not k.endswith("._extra_state")}
264
+
265
+ # remap keys if needed
266
+ if model_ckpt_key_map:
267
+ for model_key, checkpoint_key in model_ckpt_key_map.items():
268
+ state_dict[checkpoint_key] = state_dict.pop(model_key)
269
+ log.info(f"Re-mapping {model_key} to {checkpoint_key}")
270
+
271
+ fs_storage_reader = dist.checkpoint.FileSystemReader(ckpt_dir)
272
+ dist.checkpoint.load(state_dict=state_dict, storage_reader=fs_storage_reader)
273
+
274
+ # inverse the remapping if needed
275
+ if model_ckpt_key_map:
276
+ for model_key, checkpoint_key in model_ckpt_key_map.items():
277
+ state_dict[model_key] = state_dict.pop(checkpoint_key)
278
+ log.info(f"Inverse re-mapping {checkpoint_key} to {model_key}")
279
+
280
+ _model_wrapper.load_state_dict(state_dict)
281
+
282
+ log.info(f"Finished loading checkpoint from {ckpt_dir}.")
imaginaire/utils/config_helper.py ADDED
@@ -0,0 +1,201 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import os
18
+ import pkgutil
19
+ import sys
20
+ from dataclasses import fields as dataclass_fields
21
+ from dataclasses import is_dataclass
22
+ from typing import Any
23
+
24
+ import attr
25
+ import attrs
26
+ from hydra import compose, initialize
27
+ from hydra.core.config_store import ConfigStore
28
+ from hydra.core.global_hydra import GlobalHydra
29
+ from omegaconf import DictConfig, OmegaConf
30
+
31
+ from imaginaire.config import Config
32
+ from imaginaire.utils import log
33
+
34
+
35
+ def is_attrs_or_dataclass(obj) -> bool:
36
+ """
37
+ Check if the object is an instance of an attrs class or a dataclass.
38
+
39
+ Args:
40
+ obj: The object to check.
41
+
42
+ Returns:
43
+ bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
44
+ """
45
+ return is_dataclass(obj) or attr.has(type(obj))
46
+
47
+
48
+ def get_fields(obj):
49
+ """
50
+ Get the fields of an attrs class or a dataclass.
51
+
52
+ Args:
53
+ obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
54
+
55
+ Returns:
56
+ list: A list of field names.
57
+
58
+ Raises:
59
+ ValueError: If the object is neither an attrs class nor a dataclass.
60
+ """
61
+ if is_dataclass(obj):
62
+ return [field.name for field in dataclass_fields(obj)]
63
+ elif attr.has(type(obj)):
64
+ return [field.name for field in attr.fields(type(obj))]
65
+ else:
66
+ raise ValueError("The object is neither an attrs class nor a dataclass.")
67
+
68
+
69
+ def override(config: Config, overrides: list[str] | None = None) -> Config:
70
+ """
71
+ :param config: the instance of class `Config` (usually from `make_config`)
72
+ :param overrides: list of overrides for config
73
+ :return: the composed instance of class `Config`
74
+ """
75
+ # Store the class of the config for reconstruction after overriding.
76
+ # config_class = type(config)
77
+
78
+ # Convert Config object to a DictConfig object
79
+ config_dict = attrs.asdict(config)
80
+ config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
81
+ # Enforce "--" separator between the script arguments and overriding configs.
82
+ if overrides:
83
+ if overrides[0] != "--":
84
+ raise ValueError('Hydra config overrides must be separated with a "--" token.')
85
+ overrides = overrides[1:]
86
+ # Use Hydra to handle overrides
87
+ cs = ConfigStore.instance()
88
+ cs.store(name="config", node=config_omegaconf)
89
+ if not GlobalHydra().is_initialized():
90
+ with initialize(version_base=None):
91
+ config_omegaconf = compose(config_name="config", overrides=overrides)
92
+ OmegaConf.resolve(config_omegaconf)
93
+ else:
94
+ config_omegaconf = compose(config_name="config", overrides=overrides)
95
+ OmegaConf.resolve(config_omegaconf)
96
+
97
+ def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
98
+ """
99
+ Construct an instance of the same type as ref_instance from the provided dictionary, primitive, or unstructured data
100
+
101
+ Args:
102
+ ref_instance: The reference instance to determine the type and fields when needed
103
+ kwargs: A dictionary of keyword arguments for constructing the new instance, or primitive/unstructured data to pass through
104
+
105
+ Returns:
106
+ Any: A new instance of the same type as ref_instance constructed from kwargs, or the primitive/unstructured data itself
107
+
108
+ Raises:
109
+ AssertionError: If the fields do not match or if extra keys are found.
110
+ Exception: If there is an error constructing the new instance.
111
+ """
112
+ is_type = is_attrs_or_dataclass(ref_instance)
113
+ if not is_type:
114
+ return kwargs
115
+ else:
116
+ ref_fields = set(get_fields(ref_instance))
117
+ assert isinstance(kwargs, (dict, DictConfig)), (
118
+ "kwargs must be a dictionary or a DictConfig"
119
+ )
120
+ keys = set(kwargs.keys())
121
+
122
+ # ref_fields must be equal to, or a superset of, keys
123
+ extra_keys = keys - ref_fields
124
+ assert ref_fields == keys or keys.issubset(ref_fields), (
125
+ f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
126
+ )
127
+
128
+ resolved_kwargs: dict[str, Any] = {}
129
+ for f in keys:
130
+ resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
131
+ try:
132
+ new_instance = type(ref_instance)(**resolved_kwargs)
133
+ except Exception as e:
134
+ log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
135
+ log.error(e)
136
+ raise e
137
+ return new_instance
138
+
139
+ config = config_from_dict(config, config_omegaconf)
140
+
141
+ return config
142
+
143
+
144
+ def get_config_module(config_file: str) -> str:
145
+ if not config_file.endswith(".py"):
146
+ log.error("Config file cannot be specified as module.")
147
+ log.error("Please provide the path to the Python config file (relative to the Imaginaire4 root).")
148
+ assert os.path.isfile(config_file), f"Imaginaire4 config file ({config_file}) not found."
149
+ # Convert to importable module format.
150
+ config_module = config_file.replace("/", ".").replace(".py", "")
151
+ return config_module
152
+
153
+
154
+ def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
155
+ """
156
+ Import all modules from the specified package path recursively.
157
+
158
+ This function is typically used in conjunction with Hydra to ensure that all modules
159
+ within a specified package are imported, which is necessary for registering configurations.
160
+
161
+ Example usage:
162
+ ```python
163
+ import_all_modules_from_package("projects.cosmos.diffusion.v1.config.experiment", reload=True, skip_underscore=False)
164
+ ```
165
+
166
+ Args:
167
+ package_path (str): The dotted path to the package from which to import all modules.
168
+ reload (bool): Flag to determine whether to reload modules if they're already imported.
169
+ skip_underscore (bool): If True, skips importing modules that start with an underscore.
170
+ """
171
+ log.critical(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
172
+ package = importlib.import_module(package_path)
173
+ package_directory = package.__path__
174
+
175
+ def import_modules_recursively(directory: str, prefix: str) -> None:
176
+ """
177
+ Recursively imports or reloads all modules in the given directory.
178
+
179
+ Args:
180
+ directory (str): The file system path to the current package directory.
181
+ prefix (str): The module prefix (e.g., 'projects.cosmos.diffusion.v1.config').
182
+ """
183
+ for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
184
+ if skip_underscore and module_name.startswith("_"):
185
+ log.debug(f"Skipping module {module_name} as it starts with an underscore")
186
+ continue
187
+
188
+ full_module_name = f"{prefix}.{module_name}"
189
+ log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
190
+
191
+ if full_module_name in sys.modules and reload:
192
+ importlib.reload(sys.modules[full_module_name])
193
+ else:
194
+ importlib.import_module(full_module_name)
195
+
196
+ if is_pkg:
197
+ sub_package_directory = os.path.join(directory, module_name)
198
+ import_modules_recursively(sub_package_directory, full_module_name)
199
+
200
+ for directory in package_directory:
201
+ import_modules_recursively(directory, package_path)
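The `override()` helper above flattens an attrs config into a `DictConfig`, lets Hydra apply dotted-path overrides, and then rebuilds the typed config via `config_from_dict`. A toy round trip of the same idea; `TrainerConfig`/`OptimizerConfig` are hypothetical stand-ins, and a plain dotlist merge substitutes for the Hydra compose step:

```python
import attrs
from omegaconf import OmegaConf


@attrs.define
class OptimizerConfig:
    lr: float = 1e-4


@attrs.define
class TrainerConfig:
    max_iter: int = 1000
    optimizer: OptimizerConfig = attrs.field(factory=OptimizerConfig)


config = TrainerConfig()
# attrs -> untyped DictConfig, apply overrides, then back to the typed class.
conf = OmegaConf.create(attrs.asdict(config))
conf.merge_with_dotlist(["optimizer.lr=0.0003", "max_iter=5000"])
rebuilt = TrainerConfig(
    max_iter=conf.max_iter,
    optimizer=OptimizerConfig(**conf.optimizer),
)
print(rebuilt)  # TrainerConfig(max_iter=5000, optimizer=OptimizerConfig(lr=0.0003))
```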
imaginaire/utils/device.py ADDED
@@ -0,0 +1,39 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ import os
18
+
19
+ import pynvml
20
+
21
+
22
+ class Device:
23
+ _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore
24
+
25
+ def __init__(self, device_idx: int):
26
+ super().__init__()
27
+ self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
28
+
29
+ def get_name(self) -> str:
30
+ return pynvml.nvmlDeviceGetName(self.handle)
31
+
32
+ def get_cpu_affinity(self) -> list[int]:
33
+ affinity_string = ""
34
+ for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
35
+ # assume nvml returns list of 64 bit ints
36
+ affinity_string = f"{j:064b}" + affinity_string
37
+ affinity_list = [int(x) for x in affinity_string]
38
+ affinity_list.reverse() # so core 0 is in 0th element of list
39
+ return [i for i, e in enumerate(affinity_list) if e != 0]
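`get_cpu_affinity()` unpacks NVML's array of 64-bit affinity masks into a list of CPU core indices. The bit manipulation can be checked without a GPU by feeding in a fake mask, as in this sketch:

```python
# Pretend NVML returned a single 64-bit word with cores 0, 1, and 3 set.
masks = [0b1011]

affinity_string = ""
for word in masks:
    affinity_string = f"{word:064b}" + affinity_string
affinity_list = [int(bit) for bit in affinity_string]
affinity_list.reverse()  # bit 0 (core 0) ends up at index 0
cores = [i for i, set_bit in enumerate(affinity_list) if set_bit]
assert cores == [0, 1, 3]
```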
imaginaire/utils/distributed.py ADDED
@@ -0,0 +1,444 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import collections
19
+ import collections.abc
20
+ import ctypes
21
+ import functools
22
+ import os
23
+ from collections.abc import Callable, Container
24
+ from contextlib import contextmanager
25
+ from datetime import timedelta
26
+ from typing import TYPE_CHECKING, Any
27
+
28
+ import pynvml
29
+ import torch
30
+ import torch.distributed as dist
31
+ from torch.distributed import get_process_group_ranks
32
+
33
+ from imaginaire.utils.device import Device
34
+
35
+ if dist.is_available():
36
+ from torch.distributed.distributed_c10d import _get_default_group
37
+ from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes
38
+
39
+ from imaginaire.utils import log
40
+
41
+ if TYPE_CHECKING:
42
+ from imaginaire.config import DDPConfig
43
+
44
+ try:
45
+ from megatron.core import parallel_state
46
+ except ImportError:
47
+ print("Megatron-core is not installed.")
48
+
49
+
50
+ def init() -> int | None:
51
+ """Initialize distributed training."""
52
+ if dist.is_initialized():
53
+ return torch.cuda.current_device()
54
+
55
+ # Set GPU affinity.
56
+ pynvml.nvmlInit()
57
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
58
+ try:
59
+ device = Device(local_rank)
60
+ os.sched_setaffinity(0, device.get_cpu_affinity())
61
+ except (OSError, pynvml.NVMLError) as e:
62
+ log.warning(f"Failed to set device affinity: {e}")
63
+ # Set up NCCL communication.
64
+ os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
65
+ os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
66
+ if dist.is_available():
67
+ torch.cuda.set_device(local_rank)
68
+ # Get the timeout value from environment variable
69
+ timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
70
+ # Convert the timeout to an integer (if it isn't already) and then to a timedelta
71
+ timeout_timedelta = timedelta(seconds=int(timeout_seconds))
72
+ dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
73
+ log.info(
74
+ f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
75
+ rank0_only=False,
76
+ )
77
+ # Increase the L2 fetch granularity for faster speed.
78
+ _libcudart = ctypes.CDLL("libcudart.so")
79
+ # Set device limit on the current device.
80
+ p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
81
+ _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
82
+ _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
83
+ log.info(f"Training with {get_world_size()} GPUs.")
84
+
85
+
86
+ def get_rank(group: dist.ProcessGroup | None = None) -> int:
87
+ """Get the rank (GPU device) of the worker.
88
+
89
+ Returns:
90
+ rank (int): The rank of the worker.
91
+ """
92
+ rank = 0
93
+ if dist.is_available() and dist.is_initialized():
94
+ rank = dist.get_rank(group)
95
+ return rank
96
+
97
+
98
+ def get_world_size(group: dist.ProcessGroup | None = None) -> int:
99
+ """Get world size. How many GPUs are available in this job.
100
+
101
+ Returns:
102
+ world_size (int): The total number of GPUs available in this job.
103
+ """
104
+ world_size = 1
105
+ if dist.is_available() and dist.is_initialized():
106
+ world_size = dist.get_world_size(group)
107
+ return world_size
108
+
109
+
110
+ def is_rank0() -> bool:
111
+ """Check if current process is the master GPU.
112
+
113
+ Returns:
114
+ (bool): True if this function is called from the master GPU, else False.
115
+ """
116
+ return get_rank() == 0
117
+
118
+
119
+ def is_local_rank0() -> bool:
120
+ """Check if current process is the local master GPU in the current node.
121
+
122
+ Returns:
123
+ (bool): True if this function is called from the local master GPU, else False.
124
+ """
125
+ return torch.cuda.current_device() == 0
126
+
127
+
128
+ def rank0_only(func: Callable) -> Callable:
129
+ """Apply this function only to the master GPU.
130
+
131
+ Example usage:
132
+ @rank0_only
133
+ def func(x):
134
+ return x + 3
135
+
136
+ Args:
137
+ func (Callable): a function.
138
+
139
+ Returns:
140
+ (Callable): A function wrapper executing the function only on the master GPU.
141
+ """
142
+
143
+ @functools.wraps(func)
144
+ def wrapper(*args, **kwargs):
145
+ if is_rank0():
146
+ return func(*args, **kwargs)
147
+ else:
148
+ return None
149
+
150
+ return wrapper
151
+
152
+
153
+ def barrier() -> None:
154
+ """Barrier for all GPUs."""
155
+ if dist.is_available() and dist.is_initialized():
156
+ dist.barrier()
157
+
158
+
159
+ def rank0_first(func: Callable) -> Callable:
160
+ """run the function on rank 0 first, then on other ranks."""
161
+
162
+ @functools.wraps(func)
163
+ def wrapper(*args, **kwargs):
164
+ if is_rank0():
165
+ result = func(*args, **kwargs)
166
+ barrier()
167
+ if not is_rank0():
168
+ result = func(*args, **kwargs)
169
+ return result
170
+
171
+ return wrapper
172
+
173
+
174
+ def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
175
+ """Wraps the model to enable data parallalism for training across multiple GPU devices.
176
+
177
+ Args:
178
+ config_ddp (DDPConfig): The data parallel config.
179
+ model (torch.nn.Module): The PyTorch module.
180
+
181
+ Returns:
182
+ model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
183
+ if distributed environment is available, otherwise return the original model.
184
+ """
185
+ if dist.is_available() and dist.is_initialized():
186
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
187
+ try:
188
+ ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
189
+ except Exception as e:
190
+ log.info(e)
191
+ log.info("parallel_state not initialized, treating all GPUs equally for DDP")
192
+ ddp_group = None
193
+
194
+ model = DistributedDataParallel(
195
+ model,
196
+ device_ids=[local_rank],
197
+ output_device=local_rank,
198
+ find_unused_parameters=config_ddp.find_unused_parameters,
199
+ static_graph=config_ddp.static_graph,
200
+ broadcast_buffers=config_ddp.broadcast_buffers,
201
+ process_group=ddp_group,
202
+ )
203
+ return model
204
+
205
+
206
+ class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
207
+ """This extends torch.nn.parallel.DistributedDataParallel with .training_step().
208
+
209
+ This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an ImaginaireModel such that
210
+ model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
211
+ model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
212
+ training_step), allowing us to preserve the function names and signatures.
213
+ """
214
+
215
+ def __init__(self, model: torch.nn.Module, *args, **kwargs):
216
+ super().__init__(model, *args, **kwargs)
217
+ self.show_sync_grad_static_graph_warning = True
218
+
219
+ def training_step(self, *args, **kwargs) -> Any:
220
+ # Cache the original model.forward() method.
221
+ original_forward = self.module.forward
222
+
223
+ def wrapped_training_step(*_args, **_kwargs):
224
+ # Unpatch immediately before calling training_step() because itself may want to call the real forward.
225
+ self.module.forward = original_forward
226
+ # The actual .training_step().
227
+ return self.module.training_step(*_args, **_kwargs)
228
+
229
+ # Patch the original_module's forward so we can redirect the arguments back to the real method.
230
+ self.module.forward = wrapped_training_step
231
+ # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
232
+ # Without calling self.forward() or model.forward() explicitly, implicit hooks are also executed.
233
+ return self(*args, **kwargs)
234
+
235
+
236
+ @contextmanager
237
+ def ddp_sync_grad(model, enabled):
238
+ r"""
239
+ Context manager to enable/disable gradient synchronizations across DDP processes for DDP model.
240
+ Modified from:
241
+ https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
242
+ Note that this is incompatible with static_graph=True and will be a no-op if static_graph=True.
243
+
244
+ Within this context, gradients will be accumulated on module
245
+ variables, which will later be synchronized in the first
246
+ forward-backward pass exiting the context.
247
+
248
+ .. warning::
249
+ The forward pass should be included inside the context manager, or
250
+ else gradients will still be synchronized.
251
+ """
252
+ assert isinstance(model, torch.nn.Module)
253
+ if isinstance(model, DistributedDataParallel):
254
+ old_require_backward_grad_sync = model.require_backward_grad_sync
255
+ if model.static_graph and model.require_backward_grad_sync != enabled:
256
+ if model.show_sync_grad_static_graph_warning:
257
+ log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
258
+ model.show_sync_grad_static_graph_warning = False
259
+ else:
260
+ model.require_backward_grad_sync = enabled
261
+ try:
262
+ yield
263
+ finally:
264
+ if isinstance(model, DistributedDataParallel):
265
+ model.require_backward_grad_sync = old_require_backward_grad_sync
266
+
267
+
268
+ def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
269
+ """Aggregate the list of data batches from all devices and process the results.
270
+
271
+ This is used for gathering validation data batches with imaginaire.utils.dataloader.DistributedEvalSampler.
272
+ It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
273
+ in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
274
+ created before calling dis.all_gather().
275
+
276
+ Args:
277
+ data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
278
+ leaf entries are tensors.
279
+
280
+ Returns:
281
+ data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
282
+ leaf entries are concatenated tensors.
283
+ """
284
+ if isinstance(data_batches[0], torch.Tensor):
285
+ # Concatenate the local data batches.
286
+ data_concat = torch.cat(data_batches, dim=0) # type: ignore
287
+ # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
288
+ max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
289
+ dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
290
+ if len(data_concat) < max_num_local_samples:
291
+ assert len(data_concat) + 1 == max_num_local_samples
292
+ dummy = torch.empty_like(data_concat[:1])
293
+ data_concat = torch.cat([data_concat, dummy], dim=0)
294
+ dummy_count = torch.tensor(1, device="cuda")
295
+ else:
296
+ dummy_count = torch.tensor(0, device="cuda")
297
+ # Get all concatenated batches from all ranks and concatenate again.
298
+ dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
299
+ data_concat = all_gather_tensor(data_concat.contiguous())
300
+ data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
301
+ # Remove the dummy samples.
302
+ if dummy_count > 0:
303
+ data_collate = data_collate[:-dummy_count]
304
+ elif isinstance(data_batches[0], collections.abc.Mapping):
305
+ data_collate = dict()
306
+ for key in data_batches[0].keys():
307
+ data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore
308
+ else:
309
+ raise TypeError
310
+ return data_collate
311
+
312
+
313
+ @torch.no_grad()
314
+ def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
315
+ """Gather the corresponding tensor from all GPU devices to a list.
316
+
317
+ Args:
318
+ tensor (torch.Tensor): Pytorch tensor.
319
+
320
+ Returns:
321
+ tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices.
322
+ """
323
+ tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
324
+ dist.all_gather(tensor_list, tensor)
325
+ return tensor_list
326
+
327
+
328
+ def broadcast(tensor, src, group=None, async_op=False):
329
+ world_size = get_world_size()
330
+ if world_size < 2:
331
+ return tensor
332
+ dist.broadcast(tensor, src=src, group=group, async_op=async_op)
333
+
334
+
335
+ def dist_reduce_tensor(tensor, rank=0, reduce="mean"):
336
+ r"""Reduce to rank 0"""
337
+ world_size = get_world_size()
338
+ if world_size < 2:
339
+ return tensor
340
+ with torch.no_grad():
341
+ dist.reduce(tensor, dst=rank)
342
+ if get_rank() == rank:
343
+ if reduce == "mean":
344
+ tensor /= world_size
345
+ elif reduce == "sum":
346
+ pass
347
+ else:
348
+ raise NotImplementedError
349
+ return tensor
350
+
351
+
352
+ def sync_model_states(
353
+ model: torch.nn.Module,
354
+ process_group: dist.ProcessGroup | None = None,
355
+ src: int = 0,
356
+ params_and_buffers_to_ignore: Container[str] | None = None,
357
+ broadcast_buffers: bool = True,
358
+ ):
359
+ """
360
+ Modified from the DDP source code.
361
+ Synchronizes the parameters and buffers of a model across different processes in a distributed setting.
362
+
363
+ This function ensures that all processes in the specified process group have the same initial parameters and
364
+ buffers from the source rank, typically rank 0. It is useful when different processes start with different model
365
+ states and a synchronization is required to ensure consistency across all ranks.
366
+
367
+ Args:
368
+ model (nn.Module): The model whose parameters and buffers are to be synchronized.
369
+ process_group (dist.ProcessGroup, optional): The process group for communication. If None,
370
+ the default group is used. Defaults to None.
371
+ src (int, optional): The source rank from which parameters and buffers will be broadcasted.
372
+ Defaults to 0.
373
+ params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer
374
+ names to exclude from synchronization. Defaults to None, which means all parameters and buffers are
375
+ included.
376
+ broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True.
377
+
378
+ Side Effects:
379
+ This function modifies the state of the model in-place to synchronize it with the source rank's model state.
380
+
381
+ Raises:
382
+ RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised.
383
+
384
+ Examples:
385
+ >>> # downloading duplicated model weights from s3 in each rank and save network bandwidth
386
+ >>> # useful and save our time when model weights are huge
387
+ >>> if dist.get_rank == 0:
388
+ >>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path))
389
+ >>> dist.barrir()
390
+ >>> sync_model_states(model) # sync rank0 weights to other ranks
391
+ """
392
+ if not dist.is_available() or not dist.is_initialized():
393
+ return
394
+ if process_group is None:
395
+ process_group = _get_default_group()
396
+ if not params_and_buffers_to_ignore:
397
+ params_and_buffers_to_ignore = set()
398
+
399
+ log.info(
400
+ f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}."
401
+ )
402
+
403
+ # Build tuple of (module, parameter) for all parameters that require grads.
404
+ modules_and_parameters = [
405
+ (module, parameter)
406
+ for module_name, module in model.named_modules()
407
+ for parameter in [
408
+ param
409
+ # Note that we access module.named_parameters instead of
410
+ # parameters(module). parameters(module) is only needed in the
411
+ # single-process multi device case, where it accesses replicated
412
+ # parameters through _former_parameters.
413
+ for param_name, param in module.named_parameters(recurse=False)
414
+ if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
415
+ # if param.requires_grad
416
+ # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
417
+ ]
418
+ ]
419
+
420
+ # Deduplicate any parameters that might be shared across child modules.
421
+ memo = set()
422
+ modules_and_parameters = [
423
+ # "p not in memo" is the deduplication check.
424
+ # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
425
+ (m, p)
426
+ for m, p in modules_and_parameters
427
+ if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
428
+ ]
429
+
430
+ # Build list of parameters.
431
+ parameters = [parameter for _, parameter in modules_and_parameters]
432
+ if len(parameters) == 0:
433
+ return
434
+
435
+ _verify_param_shape_across_processes(process_group, parameters)
436
+
437
+ _sync_module_states(
438
+ module=model,
439
+ process_group=process_group,
440
+ broadcast_bucket_size=(250 * 1024 * 1024),
441
+ src=src,
442
+ params_and_buffers_to_ignore=params_and_buffers_to_ignore,
443
+ broadcast_buffers=broadcast_buffers,
444
+ )
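A typical use of `ddp_sync_grad` is gradient accumulation: skip the all-reduce on every micro-batch except the last one in each window. A hedged sketch follows; with a plain (non-DDP) module the context manager is a no-op, so it also runs single-process:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from imaginaire.utils.distributed import ddp_sync_grad

model = torch.nn.Linear(8, 1)  # stand-in for a DDP-wrapped model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = DataLoader(TensorDataset(torch.randn(16, 8), torch.randn(16, 1)), batch_size=4)

accum_steps = 2
for step, (x, y) in enumerate(data):
    last_micro_batch = (step + 1) % accum_steps == 0
    # Gradients are only synchronized across ranks on the last micro-batch.
    with ddp_sync_grad(model, enabled=last_micro_batch):
        loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps
        loss.backward()
    if last_micro_batch:
        optimizer.step()
        optimizer.zero_grad()
```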
imaginaire/utils/easy_io/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
imaginaire/utils/easy_io/backends/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
17
+ from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
18
+ from imaginaire.utils.easy_io.backends.local_backend import LocalBackend
19
+ from imaginaire.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend
20
+
21
+ __all__ = [
22
+ "BaseStorageBackend",
23
+ "HTTPBackend",
24
+ "LocalBackend",
25
+ "backends",
26
+ "prefix_to_backends",
27
+ "register_backend",
28
+ ]
imaginaire/utils/easy_io/backends/base_backend.py ADDED
@@ -0,0 +1,60 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import os.path as osp
18
+ from abc import ABCMeta, abstractmethod
19
+
20
+
21
+ def mkdir_or_exist(dir_name, mode=0o777):
22
+ if dir_name == "":
23
+ return
24
+ dir_name = osp.expanduser(dir_name)
25
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
26
+
27
+
28
+ def has_method(obj, method):
29
+ return hasattr(obj, method) and callable(getattr(obj, method))
30
+
31
+
32
+ class BaseStorageBackend(metaclass=ABCMeta):
33
+ """Abstract class of storage backends.
34
+
35
+ All backends need to implement two APIs: :meth:`get()` and
36
+ :meth:`get_text()`.
37
+
38
+ - :meth:`get()` reads the file as a byte stream.
39
+ - :meth:`get_text()` reads the file as texts.
40
+ """
41
+
42
+ # a flag to indicate whether the backend can create a symlink for a file
43
+ # This attribute will be deprecated in the future.
44
+ _allow_symlink = False
45
+
46
+ @property
47
+ def allow_symlink(self):
48
+ return self._allow_symlink
49
+
50
+ @property
51
+ def name(self):
52
+ return self.__class__.__name__
53
+
54
+ @abstractmethod
55
+ def get(self, filepath):
56
+ pass
57
+
58
+ @abstractmethod
59
+ def get_text(self, filepath):
60
+ pass
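Since `get()` and `get_text()` are the only abstract methods, a concrete backend can be very small. A sketch of a hypothetical in-memory backend (not part of the library) that also exercises the inherited `name` property:

```python
from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend


class InMemoryBackend(BaseStorageBackend):
    """Toy backend that serves canned bytes instead of touching storage."""

    def __init__(self, store: dict[str, bytes] | None = None):
        self.store = store or {}

    def get(self, filepath) -> bytes:
        return self.store[str(filepath)]

    def get_text(self, filepath, encoding: str = "utf-8") -> str:
        return self.get(filepath).decode(encoding)


backend = InMemoryBackend({"a.txt": b"hello world"})
assert backend.name == "InMemoryBackend"
assert backend.get_text("a.txt") == "hello world"
```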
imaginaire/utils/easy_io/backends/http_backend.py ADDED
@@ -0,0 +1,91 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import tempfile
18
+ from collections.abc import Generator
19
+ from contextlib import contextmanager
20
+ from pathlib import Path
21
+ from urllib.request import urlopen
22
+
23
+ from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
24
+
25
+
26
+ class HTTPBackend(BaseStorageBackend):
27
+ """HTTP and HTTPS storage bachend."""
28
+
29
+ def get(self, filepath: str) -> bytes:
30
+ """Read bytes from a given ``filepath``.
31
+
32
+ Args:
33
+ filepath (str): Path to read data.
34
+
35
+ Returns:
36
+ bytes: Expected bytes object.
37
+
38
+ Examples:
39
+ >>> backend = HTTPBackend()
40
+ >>> backend.get('http://path/of/file')
41
+ b'hello world'
42
+ """
43
+ return urlopen(filepath).read()
44
+
45
+ def get_text(self, filepath, encoding="utf-8") -> str:
46
+ """Read text from a given ``filepath``.
47
+
48
+ Args:
49
+ filepath (str): Path to read data.
50
+ encoding (str): The encoding format used to open the ``filepath``.
51
+ Defaults to 'utf-8'.
52
+
53
+ Returns:
54
+ str: Expected text reading from ``filepath``.
55
+
56
+ Examples:
57
+ >>> backend = HTTPBackend()
58
+ >>> backend.get_text('http://path/of/file')
59
+ 'hello world'
60
+ """
61
+ return urlopen(filepath).read().decode(encoding)
62
+
63
+ @contextmanager
64
+ def get_local_path(self, filepath: str) -> Generator[str | Path, None, None]:
65
+ """Download a file from ``filepath`` to a local temporary directory,
66
+ and return the temporary path.
67
+
68
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
69
+ can be called with a ``with`` statement, and when exiting from the
70
+ ``with`` statement, the temporary path will be released.
71
+
72
+ Args:
73
+ filepath (str): Download a file from ``filepath``.
74
+
75
+ Yields:
76
+ Iterable[str]: Only yield one temporary path.
77
+
78
+ Examples:
79
+ >>> backend = HTTPBackend()
80
+ >>> # After exiting from the ``with`` clause,
81
+ >>> # the path will be removed
82
+ >>> with backend.get_local_path('http://path/of/file') as path:
83
+ ... # do something here
84
+ """
85
+ try:
86
+ f = tempfile.NamedTemporaryFile(delete=False)
87
+ f.write(self.get(filepath))
88
+ f.close()
89
+ yield f.name
90
+ finally:
91
+ os.remove(f.name)
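Usage is symmetric with the other backends: `get()` returns raw bytes, while `get_local_path()` materializes them as a temporary file that is deleted on exit from the `with` block. A short sketch with a placeholder URL (requires network access):

```python
from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend

backend = HTTPBackend()
with backend.get_local_path("http://example.com/") as path:
    print(path)  # temporary file path; removed once the with block exits
```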
imaginaire/utils/easy_io/backends/local_backend.py ADDED
@@ -0,0 +1,551 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import io
17
+ import os
18
+ import os.path as osp
19
+ import shutil
20
+ from collections.abc import Generator, Iterator
21
+ from contextlib import contextmanager
22
+ from pathlib import Path
23
+
24
+ from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist
25
+
26
+
27
+ class LocalBackend(BaseStorageBackend):
28
+ """Raw local storage backend."""
29
+
30
+ _allow_symlink = True
31
+
32
+ def get(self, filepath: str | Path) -> bytes:
33
+ """Read bytes from a given ``filepath`` with 'rb' mode.
34
+
35
+ Args:
36
+ filepath (str or Path): Path to read data.
37
+
38
+ Returns:
39
+ bytes: Expected bytes object.
40
+
41
+ Examples:
42
+ >>> backend = LocalBackend()
43
+ >>> filepath = '/path/of/file'
44
+ >>> backend.get(filepath)
45
+ b'hello world'
46
+ """
47
+ with open(filepath, "rb") as f:
48
+ value = f.read()
49
+ return value
50
+
51
+ def get_text(self, filepath: str | Path, encoding: str = "utf-8") -> str:
52
+ """Read text from a given ``filepath`` with 'r' mode.
53
+
54
+ Args:
55
+ filepath (str or Path): Path to read data.
56
+ encoding (str): The encoding format used to open the ``filepath``.
57
+ Defaults to 'utf-8'.
58
+
59
+ Returns:
60
+ str: Expected text reading from ``filepath``.
61
+
62
+ Examples:
63
+ >>> backend = LocalBackend()
64
+ >>> filepath = '/path/of/file'
65
+ >>> backend.get_text(filepath)
66
+ 'hello world'
67
+ """
68
+ with open(filepath, encoding=encoding) as f:
69
+ text = f.read()
70
+ return text
71
+
72
+ def put(self, obj: bytes | io.BytesIO, filepath: str | Path) -> None:
73
+ """Write bytes to a given ``filepath`` with 'wb' mode.
74
+
75
+ Note:
76
+ ``put`` will create a directory if the directory of
77
+ ``filepath`` does not exist.
78
+
79
+ Args:
80
+ obj (bytes): Data to be written.
81
+ filepath (str or Path): Path to write data.
82
+
83
+ Examples:
84
+ >>> backend = LocalBackend()
85
+ >>> filepath = '/path/of/file'
86
+ >>> backend.put(b'hello world', filepath)
87
+ """
88
+ mkdir_or_exist(osp.dirname(filepath))
89
+ if isinstance(obj, io.BytesIO):
90
+ obj.seek(0)
91
+ obj = obj.getvalue()
92
+ with open(filepath, "wb") as f:
93
+ f.write(obj)
94
+
95
+ def put_text(self, obj: str, filepath: str | Path, encoding: str = "utf-8") -> None:
96
+ """Write text to a given ``filepath`` with 'w' mode.
97
+
98
+ Note:
99
+ ``put_text`` will create a directory if the directory of
100
+ ``filepath`` does not exist.
101
+
102
+ Args:
103
+ obj (str): Data to be written.
104
+ filepath (str or Path): Path to write data.
105
+ encoding (str): The encoding format used to open the ``filepath``.
106
+ Defaults to 'utf-8'.
107
+
108
+ Examples:
109
+ >>> backend = LocalBackend()
110
+ >>> filepath = '/path/of/file'
111
+ >>> backend.put_text('hello world', filepath)
112
+ """
113
+ mkdir_or_exist(osp.dirname(filepath))
114
+ with open(filepath, "w", encoding=encoding) as f:
115
+ f.write(obj)
116
+
117
+ def exists(self, filepath: str | Path) -> bool:
118
+ """Check whether a file path exists.
119
+
120
+ Args:
121
+ filepath (str or Path): Path to be checked whether exists.
122
+
123
+ Returns:
124
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
125
+
126
+ Examples:
127
+ >>> backend = LocalBackend()
128
+ >>> filepath = '/path/of/file'
129
+ >>> backend.exists(filepath)
130
+ True
131
+ """
132
+ return osp.exists(filepath)
133
+
134
+ def isdir(self, filepath: str | Path) -> bool:
135
+ """Check whether a file path is a directory.
136
+
137
+ Args:
138
+ filepath (str or Path): Path to be checked whether it is a
139
+ directory.
140
+
141
+ Returns:
142
+ bool: Return ``True`` if ``filepath`` points to a directory,
143
+ ``False`` otherwise.
144
+
145
+ Examples:
146
+ >>> backend = LocalBackend()
147
+ >>> filepath = '/path/of/dir'
148
+ >>> backend.isdir(filepath)
149
+ True
150
+ """
151
+ return osp.isdir(filepath)
152
+
153
+ def isfile(self, filepath: str | Path) -> bool:
154
+ """Check whether a file path is a file.
155
+
156
+ Args:
157
+ filepath (str or Path): Path to be checked whether it is a file.
158
+
159
+ Returns:
160
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
161
+ otherwise.
162
+
163
+ Examples:
164
+ >>> backend = LocalBackend()
165
+ >>> filepath = '/path/of/file'
166
+ >>> backend.isfile(filepath)
167
+ True
168
+ """
169
+ return osp.isfile(filepath)
170
+
171
+ def join_path(self, filepath: str | Path, *filepaths: str | Path) -> str:
172
+ r"""Concatenate all file paths.
173
+
174
+ Join one or more filepath components intelligently. The return value
175
+ is the concatenation of filepath and any members of \*filepaths.
176
+
177
+ Args:
178
+ filepath (str or Path): Path to be concatenated.
179
+
180
+ Returns:
181
+ str: The result of concatenation.
182
+
183
+ Examples:
184
+ >>> backend = LocalBackend()
185
+ >>> filepath1 = '/path/of/dir1'
186
+ >>> filepath2 = 'dir2'
187
+ >>> filepath3 = 'path/of/file'
188
+ >>> backend.join_path(filepath1, filepath2, filepath3)
189
+ '/path/of/dir/dir2/path/of/file'
190
+ """
191
+ return osp.join(filepath, *filepaths)
192
+
193
+ @contextmanager
194
+ def get_local_path(self, filepath: str) -> Generator[str, None, None]:
195
+ """Download data from filepath to local path with a context manager.
196
+
197
+ If filepath exists on the local host, it just returns filepath.
198
+ If filepath doesn't exist on the local host, it will download the data
199
+ to a local path and return that path; the path will be removed
200
+ after exiting the ``with`` statement.
201
+
202
+ Args:
203
+ filepath (str): Path to read data from.
204
+
205
+ Yields:
206
+ str: Local path.
207
+
208
+ Examples:
209
+ >>> with backend.get_local_path('http://example.com/abc.jpg') as path:
210
+ ... # do something here
211
+ """
212
+ yield filepath
213
+
214
+ def copyfile(
215
+ self,
216
+ src: str | Path,
217
+ dst: str | Path,
218
+ ) -> str:
219
+ """Copy a file src to dst and return the destination file.
220
+
221
+ src and dst should have the same prefix. If dst specifies a directory,
222
+ the file will be copied into dst using the base filename from src. If
223
+ dst specifies a file that already exists, it will be replaced.
224
+
225
+ Args:
226
+ src (str or Path): A file to be copied.
227
+ dst (str or Path): Copy file to dst.
228
+
229
+ Returns:
230
+ str: The destination file.
231
+
232
+ Raises:
233
+ SameFileError: If src and dst are the same file, a SameFileError
234
+ will be raised.
235
+
236
+ Examples:
237
+ >>> backend = LocalBackend()
238
+ >>> # dst is a file
239
+ >>> src = '/path/of/file'
240
+ >>> dst = '/path1/of/file1'
241
+ >>> # src will be copied to '/path1/of/file1'
242
+ >>> backend.copyfile(src, dst)
243
+ '/path1/of/file1'
244
+
245
+ >>> # dst is a directory
246
+ >>> dst = '/path1/of/dir'
247
+ >>> # src will be copied to '/path1/of/dir/file'
248
+ >>> backend.copyfile(src, dst)
249
+ '/path1/of/dir/file'
250
+ """
251
+ return shutil.copy(src, dst)
252
+
253
+ def copytree(
254
+ self,
255
+ src: str | Path,
256
+ dst: str | Path,
257
+ ) -> str:
258
+ """Recursively copy an entire directory tree rooted at src to a
259
+ directory named dst and return the destination directory.
260
+
261
+ src and dst should have the same prefix and dst must not already exist.
262
+
263
+ Args:
264
+ src (str or Path): A directory to be copied.
265
+ dst (str or Path): Copy directory to dst.
266
+
267
+ Returns:
268
+ str: The destination directory.
269
+
270
+ Raises:
271
+ FileExistsError: If dst had already existed, a FileExistsError will
272
+ be raised.
273
+
274
+ Examples:
275
+ >>> backend = LocalBackend()
276
+ >>> src = '/path/of/dir1'
277
+ >>> dst = '/path/of/dir2'
278
+ >>> backend.copytree(src, dst)
279
+ '/path/of/dir2'
280
+ """
281
+ return shutil.copytree(src, dst)
282
+
283
+ def copyfile_from_local(
284
+ self,
285
+ src: str | Path,
286
+ dst: str | Path,
287
+ ) -> str:
288
+ """Copy a local file src to dst and return the destination file. Same
289
+ as :meth:`copyfile`.
290
+
291
+ Args:
292
+ src (str or Path): A local file to be copied.
293
+ dst (str or Path): Copy file to dst.
294
+
295
+ Returns:
296
+ str: If dst specifies a directory, the file will be copied into dst
297
+ using the base filename from src.
298
+
299
+ Raises:
300
+ SameFileError: If src and dst are the same file, a SameFileError
301
+ will be raised.
302
+
303
+ Examples:
304
+ >>> backend = LocalBackend()
305
+ >>> # dst is a file
306
+ >>> src = '/path/of/file'
307
+ >>> dst = '/path1/of/file1'
308
+ >>> # src will be copied to '/path1/of/file1'
309
+ >>> backend.copyfile_from_local(src, dst)
310
+ '/path1/of/file1'
311
+
312
+ >>> # dst is a directory
313
+ >>> dst = '/path1/of/dir'
314
+ >>> # src will be copied to '/path1/of/dir/file'
315
+ >>> backend.copyfile_from_local(src, dst)
316
+ '/path1/of/dir/file'
317
+ """
318
+ return self.copyfile(src, dst)
319
+
320
+ def copytree_from_local(
321
+ self,
322
+ src: str | Path,
323
+ dst: str | Path,
324
+ ) -> str:
325
+ """Recursively copy an entire directory tree rooted at src to a
326
+ directory named dst and return the destination directory. Same as
327
+ :meth:`copytree`.
328
+
329
+ Args:
330
+ src (str or Path): A local directory to be copied.
331
+ dst (str or Path): Copy directory to dst.
332
+
333
+ Returns:
334
+ str: The destination directory.
335
+
336
+ Examples:
337
+ >>> backend = LocalBackend()
338
+ >>> src = '/path/of/dir1'
339
+ >>> dst = '/path/of/dir2'
340
+ >>> backend.copytree_from_local(src, dst)
341
+ '/path/of/dir2'
342
+ """
343
+ return self.copytree(src, dst)
344
+
345
+ def copyfile_to_local(
346
+ self,
347
+ src: str | Path,
348
+ dst: str | Path,
349
+ dst_type: str | None = None,
350
+ ) -> str:
351
+ """Copy the file src to local dst and return the destination file. Same
352
+ as :meth:`copyfile`.
353
+
354
+ If dst specifies a directory, the file will be copied into dst using
355
+ the base filename from src. If dst specifies a file that already
356
+ exists, it will be replaced.
357
+
358
+ Args:
359
+ src (str or Path): A file to be copied.
360
+ dst (str or Path): Copy file to local dst.
+ dst_type (str, optional): Ignored by this backend; kept for interface compatibility. Defaults to None.
361
+
362
+ Returns:
363
+ str: If dst specifies a directory, the file will be copied into dst
364
+ using the base filename from src.
365
+
366
+ Examples:
367
+ >>> backend = LocalBackend()
368
+ >>> # dst is a file
369
+ >>> src = '/path/of/file'
370
+ >>> dst = '/path1/of/file1'
371
+ >>> # src will be copied to '/path1/of/file1'
372
+ >>> backend.copyfile_to_local(src, dst)
373
+ '/path1/of/file1'
374
+
375
+ >>> # dst is a directory
376
+ >>> dst = '/path1/of/dir'
377
+ >>> # src will be copied to '/path1/of/dir/file'
378
+ >>> backend.copyfile_to_local(src, dst)
379
+ '/path1/of/dir/file'
380
+ """
381
+ return self.copyfile(src, dst)
382
+
383
+ def copytree_to_local(
384
+ self,
385
+ src: str | Path,
386
+ dst: str | Path,
387
+ ) -> str:
388
+ """Recursively copy an entire directory tree rooted at src to a local
389
+ directory named dst and return the destination directory.
390
+
391
+ Args:
392
+ src (str or Path): A directory to be copied.
393
+ dst (str or Path): Copy directory to local dst.
396
+
397
+ Returns:
398
+ str: The destination directory.
399
+
400
+ Examples:
401
+ >>> backend = LocalBackend()
402
+ >>> src = '/path/of/dir1'
403
+ >>> dst = '/path/of/dir2'
404
+ >>> backend.copytree_from_local(src, dst)
405
+ '/path/of/dir2'
406
+ """
407
+ return self.copytree(src, dst)
408
+
409
+ def remove(self, filepath: str | Path) -> None:
410
+ """Remove a file.
411
+
412
+ Args:
413
+ filepath (str or Path): Path to be removed.
414
+
415
+ Raises:
416
+ IsADirectoryError: If filepath is a directory, an IsADirectoryError
417
+ will be raised.
418
+ FileNotFoundError: If filepath does not exist, a FileNotFoundError
419
+ will be raised.
420
+
421
+ Examples:
422
+ >>> backend = LocalBackend()
423
+ >>> filepath = '/path/of/file'
424
+ >>> backend.remove(filepath)
425
+ """
426
+ if not self.exists(filepath):
427
+ raise FileNotFoundError(f"filepath {filepath} does not exist")
428
+
429
+ if self.isdir(filepath):
430
+ raise IsADirectoryError("filepath should be a file")
431
+
432
+ os.remove(filepath)
433
+
434
+ def rmtree(self, dir_path: str | Path) -> None:
435
+ """Recursively delete a directory tree.
436
+
437
+ Args:
438
+ dir_path (str or Path): A directory to be removed.
439
+
440
+ Examples:
441
+ >>> dir_path = '/path/of/dir'
442
+ >>> backend.rmtree(dir_path)
443
+ """
444
+ shutil.rmtree(dir_path)
445
+
446
+ def copy_if_symlink_fails(
447
+ self,
448
+ src: str | Path,
449
+ dst: str | Path,
450
+ ) -> bool:
451
+ """Create a symbolic link pointing to src named dst.
452
+
453
+ If creating a symbolic link pointing to src fails, src is copied
454
+ directly to dst instead.
455
+
456
+ Args:
457
+ src (str or Path): Create a symbolic link pointing to src.
458
+ dst (str or Path): Create a symbolic link named dst.
459
+
460
+ Returns:
461
+ bool: Return True if a symbolic link pointing to src is created
462
+ successfully. Otherwise, return False.
463
+
464
+ Examples:
465
+ >>> backend = LocalBackend()
466
+ >>> src = '/path/of/file'
467
+ >>> dst = '/path1/of/file1'
468
+ >>> backend.copy_if_symlink_fails(src, dst)
469
+ True
470
+ >>> src = '/path/of/dir'
471
+ >>> dst = '/path1/of/dir1'
472
+ >>> backend.copy_if_symlink_fails(src, dst)
473
+ True
474
+ """
475
+ try:
476
+ os.symlink(src, dst)
477
+ return True
478
+ except Exception:
479
+ if self.isfile(src):
480
+ self.copyfile(src, dst)
481
+ else:
482
+ self.copytree(src, dst)
483
+ return False
484
+
485
+ def list_dir_or_file(
486
+ self,
487
+ dir_path: str | Path,
488
+ list_dir: bool = True,
489
+ list_file: bool = True,
490
+ suffix: str | tuple[str] | None = None,
491
+ recursive: bool = False,
492
+ ) -> Iterator[str]:
493
+ """Scan a directory to find the interested directories or files in
494
+ arbitrary order.
495
+
496
+ Note:
497
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
498
+
499
+ Args:
500
+ dir_path (str or Path): Path of the directory.
501
+ list_dir (bool): List the directories. Defaults to True.
502
+ list_file (bool): List the path of files. Defaults to True.
503
+ suffix (str or tuple[str], optional): File suffix that we are
504
+ interested in. Defaults to None.
505
+ recursive (bool): If set to True, recursively scan the directory.
506
+ Defaults to False.
507
+
508
+ Yields:
509
+ Iterable[str]: A relative path to ``dir_path``.
510
+
511
+ Examples:
512
+ >>> backend = LocalBackend()
513
+ >>> dir_path = '/path/of/dir'
514
+ >>> # list those files and directories in current directory
515
+ >>> for file_path in backend.list_dir_or_file(dir_path):
516
+ ... print(file_path)
517
+ >>> # only list files
518
+ >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False):
519
+ ... print(file_path)
520
+ >>> # only list directories
521
+ >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False):
522
+ ... print(file_path)
523
+ >>> # only list files ending with specified suffixes
524
+ >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'):
525
+ ... print(file_path)
526
+ >>> # list all files and directory recursively
527
+ >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True):
528
+ ... print(file_path)
529
+ """
530
+ if list_dir and suffix is not None:
531
+ raise TypeError("`suffix` should be None when `list_dir` is True")
532
+
533
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
534
+ raise TypeError("`suffix` must be a string or tuple of strings")
535
+
536
+ root = dir_path
537
+
538
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
539
+ for entry in os.scandir(dir_path):
540
+ if not entry.name.startswith(".") and entry.is_file():
541
+ rel_path = osp.relpath(entry.path, root)
542
+ if (suffix is None or rel_path.endswith(suffix)) and list_file:
543
+ yield rel_path
544
+ elif osp.isdir(entry.path):
545
+ if list_dir:
546
+ rel_dir = osp.relpath(entry.path, root)
547
+ yield rel_dir
548
+ if recursive:
549
+ yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive)
550
+
551
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
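
A quick usage sketch of the LocalBackend methods above; the temporary directory and file names are placeholders invented for this example, not part of the library:

# Minimal, hedged sketch exercising LocalBackend.copyfile, list_dir_or_file
# and remove as defined above; all paths are illustrative placeholders.
import tempfile
from pathlib import Path

from imaginaire.utils.easy_io.backends.local_backend import LocalBackend

backend = LocalBackend()
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "a.txt"
    src.write_text("hello")
    dst = backend.copyfile(src, Path(tmp) / "b.txt")  # returns the destination path
    print(sorted(backend.list_dir_or_file(tmp, list_dir=False, suffix=".txt")))
    # -> ['a.txt', 'b.txt']
    backend.remove(dst)  # removes files only; a directory raises IsADirectoryError
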
imaginaire/utils/easy_io/backends/registry_utils.py ADDED
@@ -0,0 +1,125 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+
18
+ from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
19
+ from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
20
+ from imaginaire.utils.easy_io.backends.local_backend import LocalBackend
21
+
22
+ backends: dict = {}
23
+ prefix_to_backends: dict = {}
24
+
25
+
26
+ def _register_backend(
27
+ name: str,
28
+ backend: type[BaseStorageBackend],
29
+ force: bool = False,
30
+ prefixes: str | list | tuple | None = None,
31
+ ):
32
+ """Register a backend.
33
+
34
+ Args:
35
+ name (str): The name of the registered backend.
36
+ backend (BaseStorageBackend): The backend class to be registered,
37
+ which must be a subclass of :class:`BaseStorageBackend`.
38
+ force (bool): Whether to override the backend if the name has already
39
+ been registered. Defaults to False.
40
+ prefixes (str or list[str] or tuple[str], optional): The prefix
41
+ of the registered storage backend. Defaults to None.
42
+ """
43
+ global backends, prefix_to_backends
44
+
45
+ if not isinstance(name, str):
46
+ raise TypeError(f"the backend name should be a string, but got {type(name)}")
47
+
48
+ if not inspect.isclass(backend):
49
+ raise TypeError(f"backend should be a class, but got {type(backend)}")
50
+ if not issubclass(backend, BaseStorageBackend):
51
+ raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")
52
+
53
+ if name in backends and not force:
54
+ raise ValueError(
55
+ f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
56
+ )
57
+ backends[name] = backend
58
+
59
+ if prefixes is not None:
60
+ if isinstance(prefixes, str):
61
+ prefixes = [prefixes]
62
+ else:
63
+ assert isinstance(prefixes, (list, tuple))
64
+
65
+ for prefix in prefixes:
66
+ if prefix in prefix_to_backends and not force:
67
+ raise ValueError(
68
+ f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it'
69
+ )
70
+
71
+ prefix_to_backends[prefix] = backend
72
+
73
+
74
+ def register_backend(
75
+ name: str,
76
+ backend: type[BaseStorageBackend] | None = None,
77
+ force: bool = False,
78
+ prefixes: str | list | tuple | None = None,
79
+ ):
80
+ """Register a backend.
81
+
82
+ Args:
83
+ name (str): The name of the registered backend.
84
+ backend (class, optional): The backend class to be registered,
85
+ which must be a subclass of :class:`BaseStorageBackend`.
86
+ When this method is used as a decorator, backend is None.
87
+ Defaults to None.
88
+ force (bool): Whether to override the backend if the name has already
89
+ been registered. Defaults to False.
90
+ prefixes (str or list[str] or tuple[str], optional): The prefix
91
+ of the registered storage backend. Defaults to None.
92
+
93
+ This method can be used as a normal method or a decorator.
94
+
95
+ Examples:
96
+
97
+ >>> class NewBackend(BaseStorageBackend):
98
+ ... def get(self, filepath):
99
+ ... return filepath
100
+ ...
101
+ ... def get_text(self, filepath):
102
+ ... return filepath
103
+ >>> register_backend('new', NewBackend)
104
+
105
+ >>> @register_backend('new')
106
+ ... class NewBackend(BaseStorageBackend):
107
+ ... def get(self, filepath):
108
+ ... return filepath
109
+ ...
110
+ ... def get_text(self, filepath):
111
+ ... return filepath
112
+ """
113
+ if backend is not None:
114
+ _register_backend(name, backend, force=force, prefixes=prefixes)
115
+ return
116
+
117
+ def _register(backend_cls):
118
+ _register_backend(name, backend_cls, force=force, prefixes=prefixes)
119
+ return backend_cls
120
+
121
+ return _register
122
+
123
+
124
+ register_backend("local", LocalBackend, prefixes="")
125
+ register_backend("http", HTTPBackend, prefixes=["http", "https"])
imaginaire/utils/easy_io/easy_io.py ADDED
@@ -0,0 +1,1034 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import warnings
18
+ from collections.abc import Generator, Iterator
19
+ from contextlib import contextmanager
20
+ from io import BytesIO, StringIO
21
+ from pathlib import Path
22
+ from typing import IO, Any
23
+
24
+ from imaginaire.utils.easy_io.backends import backends, prefix_to_backends
25
+ from imaginaire.utils.easy_io.file_client import FileClient
26
+ from imaginaire.utils.easy_io.handlers import file_handlers
27
+
28
+ backend_instances: dict = {}
29
+
30
+
31
+ def is_filepath(filepath):
32
+ return isinstance(filepath, (str, Path))
33
+
34
+
35
+ def _parse_uri_prefix(uri: str | Path) -> str:
36
+ """Parse the prefix of uri.
37
+
38
+ Args:
39
+ uri (str or Path): Uri to be parsed that contains the file prefix.
40
+
41
+ Examples:
42
+ >>> _parse_uri_prefix('/home/path/of/your/file')
43
+ ''
44
+ >>> _parse_uri_prefix('http://path/of/your/file')
45
+ 'http'
46
+
47
+ Returns:
48
+ str: Return the prefix of uri if the uri contains '://'. Otherwise,
49
+ return ''.
50
+ """
51
+ assert is_filepath(uri)
52
+ uri = str(uri)
53
+ # if uri does not contains '://', the uri will be handled by
54
+ # LocalBackend by default
55
+ if "://" not in uri:
56
+ return ""
57
+ else:
58
+ prefix, _ = uri.split("://")
59
+ if ":" in prefix:
60
+ _, prefix = prefix.split(":")
61
+ return prefix
62
+
63
+
64
+ def _get_file_backend(prefix: str, backend_args: dict):
65
+ """Return a file backend based on the prefix or backend_args.
66
+
67
+ Args:
68
+ prefix (str): Prefix of uri.
69
+ backend_args (dict): Arguments to instantiate the corresponding
70
+ backend.
71
+ """
72
+ # backend name has a higher priority
73
+ if "backend" in backend_args:
74
+ # backend_args should not be modified
75
+ backend_args_bak = backend_args.copy()
76
+ backend_name = backend_args_bak.pop("backend")
77
+ backend = backends[backend_name](**backend_args_bak)
78
+ else:
79
+ backend = prefix_to_backends[prefix](**backend_args)
80
+ return backend
81
+
82
+
83
+ def get_file_backend(
84
+ uri: str | Path | None = None,
85
+ *,
86
+ backend_args: dict | None = None,
87
+ enable_singleton: bool = False,
88
+ backend_key: str | None = None,
89
+ ):
90
+ """Return a file backend based on the prefix of uri or backend_args.
91
+
92
+ Args:
93
+ uri (str or Path): Uri to be parsed that contains the file prefix.
94
+ backend_args (dict, optional): Arguments to instantiate the
95
+ corresponding backend. Defaults to None.
96
+ enable_singleton (bool): Whether to enable the singleton pattern.
97
+ If it is True, the backend created will be reused if the
98
+ signature is same with the previous one. Defaults to False.
99
+ backend_key (str, optional): The key to register the backend. Defaults to None.
100
+
101
+ Returns:
102
+ BaseStorageBackend: Instantiated Backend object.
103
+
104
+ Examples:
105
+ >>> # get file backend based on the prefix of uri
106
+ >>> uri = 'http://path/of/your/file'
107
+ >>> backend = get_file_backend(uri)
108
+ >>> # get file backend based on the backend_args
109
+ >>> backend = get_file_backend(backend_args={'backend': 'http'})
110
+ >>> # backend name has a higher priority if 'backend' in backend_args
111
+ >>> backend = get_file_backend(uri, backend_args={'backend': 'http'})
112
+ """
113
+ global backend_instances
114
+ if backend_key is not None:
115
+ if backend_key in backend_instances:
116
+ return backend_instances[backend_key]
117
+
118
+ if backend_args is None:
119
+ backend_args = {}
120
+
121
+ if uri is None and "backend" not in backend_args and backend_key is None:
122
+ raise ValueError('uri should not be None when "backend" does not exist in backend_args and backend_key is None')
123
+
124
+ if uri is not None:
125
+ prefix = _parse_uri_prefix(uri)
126
+ else:
127
+ prefix = ""
128
+
129
+ if enable_singleton:
130
+ unique_key = f"{prefix}:{json.dumps(backend_args)}"
131
+ if unique_key in backend_instances:
132
+ return backend_instances[unique_key]
133
+
134
+ backend = _get_file_backend(prefix, backend_args)
135
+ backend_instances[unique_key] = backend
136
+ if backend_key is not None:
137
+ backend_instances[backend_key] = backend
138
+ return backend
139
+ else:
140
+ backend = _get_file_backend(prefix, backend_args)
141
+ return backend
142
+
143
+
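
The singleton semantics documented above compose as follows; a minimal sketch assuming a fresh interpreter (so the module-level backend_instances cache starts empty):

# Hedged sketch of get_file_backend's singleton and backend_key caching.
from imaginaire.utils.easy_io.easy_io import get_file_backend

b1 = get_file_backend("/tmp/x", enable_singleton=True, backend_key="local")
assert get_file_backend(backend_key="local") is b1  # fetched by its key
b2 = get_file_backend("/tmp/y", enable_singleton=True)
assert b2 is b1  # same prefix ('') and backend_args -> cached instance
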
144
+ def get(
145
+ filepath: str | Path,
146
+ backend_args: dict | None = None,
147
+ backend_key: str | None = None,
148
+ ) -> bytes:
149
+ """Read bytes from a given ``filepath`` with 'rb' mode.
150
+
151
+ Args:
152
+ filepath (str or Path): Path to read data.
153
+ backend_args (dict, optional): Arguments to instantiate the
154
+ corresponding backend. Defaults to None.
155
+ backend_key (str, optional): The key to get the backend from register.
156
+
157
+ Returns:
158
+ bytes: Expected bytes object.
159
+
160
+ Examples:
161
+ >>> filepath = '/path/of/file'
162
+ >>> get(filepath)
163
+ b'hello world'
164
+ """
165
+ backend = get_file_backend(
166
+ filepath,
167
+ backend_args=backend_args,
168
+ enable_singleton=True,
169
+ backend_key=backend_key,
170
+ )
171
+ return backend.get(filepath)
172
+
173
+
174
+ def get_text(
175
+ filepath: str | Path,
176
+ encoding="utf-8",
177
+ backend_args: dict | None = None,
178
+ backend_key: str | None = None,
179
+ ) -> str:
180
+ """Read text from a given ``filepath`` with 'r' mode.
181
+
182
+ Args:
183
+ filepath (str or Path): Path to read data.
184
+ encoding (str): The encoding format used to open the ``filepath``.
185
+ Defaults to 'utf-8'.
186
+ backend_args (dict, optional): Arguments to instantiate the
187
+ corresponding backend. Defaults to None.
188
+ backend_key (str, optional): The key to get the backend from register.
189
+
190
+ Returns:
191
+ str: Expected text reading from ``filepath``.
192
+
193
+ Examples:
194
+ >>> filepath = '/path/of/file'
195
+ >>> get_text(filepath)
196
+ 'hello world'
197
+ """
198
+ backend = get_file_backend(
199
+ filepath,
200
+ backend_args=backend_args,
201
+ enable_singleton=True,
202
+ backend_key=backend_key,
203
+ )
204
+ return backend.get_text(filepath, encoding)
205
+
206
+
207
+ def put(
208
+ obj: bytes,
209
+ filepath: str | Path,
210
+ backend_args: dict | None = None,
211
+ backend_key: str | None = None,
212
+ ) -> None:
213
+ """Write bytes to a given ``filepath`` with 'wb' mode.
214
+
215
+ Note:
216
+ ``put`` should create a directory if the directory of
217
+ ``filepath`` does not exist.
218
+
219
+ Args:
220
+ obj (bytes): Data to be written.
221
+ filepath (str or Path): Path to write data.
222
+ backend_args (dict, optional): Arguments to instantiate the
223
+ corresponding backend. Defaults to None.
224
+ backend_key (str, optional): The key to get the backend from register.
225
+
226
+ Examples:
227
+ >>> filepath = '/path/of/file'
228
+ >>> put(b'hello world', filepath)
229
+ """
230
+ backend = get_file_backend(
231
+ filepath,
232
+ backend_args=backend_args,
233
+ enable_singleton=True,
234
+ backend_key=backend_key,
235
+ )
236
+ backend.put(obj, filepath)
237
+
238
+
239
+ def put_text(
240
+ obj: str,
241
+ filepath: str | Path,
242
+ backend_args: dict | None = None,
243
+ backend_key: str | None = None,
244
+ ) -> None:
245
+ """Write text to a given ``filepath`` with 'w' mode.
246
+
247
+ Note:
248
+ ``put_text`` should create a directory if the directory of
249
+ ``filepath`` does not exist.
250
+
251
+ Args:
252
+ obj (str): Data to be written.
253
+ filepath (str or Path): Path to write data.
254
+ encoding (str, optional): The encoding format used to open the
255
+ ``filepath``. Defaults to 'utf-8'.
256
+ backend_args (dict, optional): Arguments to instantiate the
257
+ corresponding backend. Defaults to None.
258
+ backend_key (str, optional): The key to get the backend from register.
259
+
260
+ Examples:
261
+ >>> filepath = '/path/of/file'
262
+ >>> put_text('hello world', filepath)
263
+ """
264
+ backend = get_file_backend(
265
+ filepath,
266
+ backend_args=backend_args,
267
+ enable_singleton=True,
268
+ backend_key=backend_key,
269
+ )
270
+ backend.put_text(obj, filepath)
271
+
272
+
273
+ def exists(
274
+ filepath: str | Path,
275
+ backend_args: dict | None = None,
276
+ backend_key: str | None = None,
277
+ ) -> bool:
278
+ """Check whether a file path exists.
279
+
280
+ Args:
281
+ filepath (str or Path): Path to be checked whether exists.
282
+ backend_args (dict, optional): Arguments to instantiate the
283
+ corresponding backend. Defaults to None.
284
+ backend_key (str, optional): The key to get the backend from register.
285
+
286
+ Returns:
287
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
288
+
289
+ Examples:
290
+ >>> filepath = '/path/of/file'
291
+ >>> exists(filepath)
292
+ True
293
+ """
294
+ backend = get_file_backend(
295
+ filepath,
296
+ backend_args=backend_args,
297
+ enable_singleton=True,
298
+ backend_key=backend_key,
299
+ )
300
+ return backend.exists(filepath)
301
+
302
+
303
+ def isdir(
304
+ filepath: str | Path,
305
+ backend_args: dict | None = None,
306
+ backend_key: str | None = None,
307
+ ) -> bool:
308
+ """Check whether a file path is a directory.
309
+
310
+ Args:
311
+ filepath (str or Path): Path to be checked whether it is a
312
+ directory.
313
+ backend_args (dict, optional): Arguments to instantiate the
314
+ corresponding backend. Defaults to None.
315
+ backend_key (str, optional): The key to get the backend from register.
316
+
317
+ Returns:
318
+ bool: Return ``True`` if ``filepath`` points to a directory,
319
+ ``False`` otherwise.
320
+
321
+ Examples:
322
+ >>> filepath = '/path/of/dir'
323
+ >>> isdir(filepath)
324
+ True
325
+ """
326
+ backend = get_file_backend(
327
+ filepath,
328
+ backend_args=backend_args,
329
+ enable_singleton=True,
330
+ backend_key=backend_key,
331
+ )
332
+ return backend.isdir(filepath)
333
+
334
+
335
+ def isfile(
336
+ filepath: str | Path,
337
+ backend_args: dict | None = None,
338
+ backend_key: str | None = None,
339
+ ) -> bool:
340
+ """Check whether a file path is a file.
341
+
342
+ Args:
343
+ filepath (str or Path): Path to be checked whether it is a file.
344
+ backend_args (dict, optional): Arguments to instantiate the
345
+ corresponding backend. Defaults to None.
346
+ backend_key (str, optional): The key to get the backend from register.
347
+
348
+ Returns:
349
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
350
+ otherwise.
351
+
352
+ Examples:
353
+ >>> filepath = '/path/of/file'
354
+ >>> isfile(filepath)
355
+ True
356
+ """
357
+ backend = get_file_backend(
358
+ filepath,
359
+ backend_args=backend_args,
360
+ enable_singleton=True,
361
+ backend_key=backend_key,
362
+ )
363
+ return backend.isfile(filepath)
364
+
365
+
366
+ def join_path(
367
+ filepath: str | Path,
368
+ *filepaths: str | Path,
369
+ backend_args: dict | None = None,
370
+ backend_key: str | None = None,
371
+ ) -> str | Path:
372
+ r"""Concatenate all file paths.
373
+
374
+ Join one or more filepath components intelligently. The return value
375
+ is the concatenation of filepath and any members of \*filepaths.
376
+
377
+ Args:
378
+ filepath (str or Path): Path to be concatenated.
379
+ *filepaths (str or Path): Other paths to be concatenated.
380
+ backend_args (dict, optional): Arguments to instantiate the
381
+ corresponding backend. Defaults to None.
382
+ backend_key (str, optional): The key to get the backend from register.
383
+
384
+ Returns:
385
+ str: The result of concatenation.
386
+
387
+ Examples:
388
+ >>> filepath1 = '/path/of/dir1'
389
+ >>> filepath2 = 'dir2'
390
+ >>> filepath3 = 'path/of/file'
391
+ >>> join_path(filepath1, filepath2, filepath3)
392
+ '/path/of/dir/dir2/path/of/file'
393
+ """
394
+ backend = get_file_backend(
395
+ filepath,
396
+ backend_args=backend_args,
397
+ enable_singleton=True,
398
+ backend_key=backend_key,
399
+ )
400
+ return backend.join_path(filepath, *filepaths)
401
+
402
+
403
+ @contextmanager
404
+ def get_local_path(
405
+ filepath: str | Path,
406
+ backend_args: dict | None = None,
407
+ backend_key: str | None = None,
408
+ ) -> Generator[str | Path, None, None]:
409
+ """Download data from ``filepath`` and write the data to local path.
410
+
411
+ ``get_local_path`` is decorated by :func:`contextlib.contextmanager`. It
412
+ can be called with a ``with`` statement, and when exiting the ``with``
413
+ statement, the temporary path will be released.
414
+
415
+ Note:
416
+ If the ``filepath`` is a local path, just return itself and it will
417
+ not be released (removed).
418
+
419
+ Args:
420
+ filepath (str or Path): Path to be read data.
421
+ backend_args (dict, optional): Arguments to instantiate the
422
+ corresponding backend. Defaults to None.
423
+
424
+ Yields:
425
+ Iterable[str]: Only yield one path.
426
+
427
+ Examples:
428
+ >>> with get_local_path('http://example.com/file.jpg') as path:
429
+ ... # do something here
430
+ """
431
+ backend = get_file_backend(
432
+ filepath,
433
+ backend_args=backend_args,
434
+ enable_singleton=True,
435
+ backend_key=backend_key,
436
+ )
437
+ with backend.get_local_path(str(filepath)) as local_path:
438
+ yield local_path
439
+
440
+
441
+ def copyfile(
442
+ src: str | Path,
443
+ dst: str | Path,
444
+ backend_args: dict | None = None,
445
+ backend_key: str | None = None,
446
+ ) -> str | Path:
447
+ """Copy a file src to dst and return the destination file.
448
+
449
+ src and dst should have the same prefix. If dst specifies a directory,
450
+ the file will be copied into dst using the base filename from src. If
451
+ dst specifies a file that already exists, it will be replaced.
452
+
453
+ Args:
454
+ src (str or Path): A file to be copied.
455
+ dst (str or Path): Copy file to dst.
456
+ backend_args (dict, optional): Arguments to instantiate the
457
+ corresponding backend. Defaults to None.
458
+
459
+ Returns:
460
+ str: The destination file.
461
+
462
+ Raises:
463
+ SameFileError: If src and dst are the same file, a SameFileError will
464
+ be raised.
465
+
466
+ Examples:
467
+ >>> # dst is a file
468
+ >>> src = '/path/of/file'
469
+ >>> dst = '/path1/of/file1'
470
+ >>> # src will be copied to '/path1/of/file1'
471
+ >>> copyfile(src, dst)
472
+ '/path1/of/file1'
473
+
474
+ >>> # dst is a directory
475
+ >>> dst = '/path1/of/dir'
476
+ >>> # src will be copied to '/path1/of/dir/file'
477
+ >>> copyfile(src, dst)
478
+ '/path1/of/dir/file'
479
+ """
480
+ backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
481
+ return backend.copyfile(src, dst)
482
+
483
+
484
+ def copytree(
485
+ src: str | Path,
486
+ dst: str | Path,
487
+ backend_args: dict | None = None,
488
+ backend_key: str | None = None,
489
+ ) -> str | Path:
490
+ """Recursively copy an entire directory tree rooted at src to a directory
491
+ named dst and return the destination directory.
492
+
493
+ src and dst should have the same prefix and dst must not already exist.
494
+
495
+ Args:
496
+ src (str or Path): A directory to be copied.
497
+ dst (str or Path): Copy directory to dst.
498
+ backend_args (dict, optional): Arguments to instantiate the
499
+ corresponding backend. Defaults to None.
500
+ backend_key (str, optional): The key to get the backend from register.
501
+
502
+ Returns:
503
+ str: The destination directory.
504
+
505
+ Raises:
506
+ FileExistsError: If dst already exists, a FileExistsError will be
507
+ raised.
508
+
509
+ Examples:
510
+ >>> src = '/path/of/dir1'
511
+ >>> dst = '/path/of/dir2'
512
+ >>> copytree(src, dst)
513
+ '/path/of/dir2'
514
+ """
515
+ backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
516
+ return backend.copytree(src, dst)
517
+
518
+
519
+ def copyfile_from_local(
520
+ src: str | Path,
521
+ dst: str | Path,
522
+ backend_args: dict | None = None,
523
+ backend_key: str | None = None,
524
+ ) -> str | Path:
525
+ """Copy a local file src to dst and return the destination file.
526
+
527
+ Note:
528
+ If the backend is the instance of LocalBackend, it does the same
529
+ thing with :func:`copyfile`.
530
+
531
+ Args:
532
+ src (str or Path): A local file to be copied.
533
+ dst (str or Path): Copy file to dst.
534
+ backend_args (dict, optional): Arguments to instantiate the
535
+ corresponding backend. Defaults to None.
536
+
537
+ Returns:
538
+ str: If dst specifies a directory, the file will be copied into dst
539
+ using the base filename from src.
540
+
541
+ Examples:
542
+ >>> # dst is a file
543
+ >>> src = '/path/of/file'
544
+ >>> dst = 'http://example.com/file1'
545
+ >>> # src will be copied to 'http://example.com/file1'
546
+ >>> copyfile_from_local(src, dst)
547
+ http://example.com/file1
548
+
549
+ >>> # dst is a directory
550
+ >>> dst = 'http://example.com/dir'
551
+ >>> # src will be copied to 'http://example.com/dir/file''
552
+ >>> copyfile_from_local(src, dst)
553
+ 'http://example.com/dir/file'
554
+ """
555
+ backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
556
+ return backend.copyfile_from_local(src, dst)
557
+
558
+
559
+ def copytree_from_local(
560
+ src: str | Path,
561
+ dst: str | Path,
562
+ backend_args: dict | None = None,
563
+ backend_key: str | None = None,
564
+ ) -> str | Path:
565
+ """Recursively copy an entire directory tree rooted at src to a directory
566
+ named dst and return the destination directory.
567
+
568
+ Note:
569
+ If the backend is the instance of LocalBackend, it does the same
570
+ thing with :func:`copytree`.
571
+
572
+ Args:
573
+ src (str or Path): A local directory to be copied.
574
+ dst (str or Path): Copy directory to dst.
575
+ backend_args (dict, optional): Arguments to instantiate the
576
+ corresponding backend. Defaults to None.
577
+
578
+ Returns:
579
+ str: The destination directory.
580
+
581
+ Examples:
582
+ >>> src = '/path/of/dir'
583
+ >>> dst = 'http://example.com/dir'
584
+ >>> copyfile_from_local(src, dst)
585
+ 'http://example.com/dir'
586
+ """
587
+ backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
588
+ return backend.copytree_from_local(src, dst)
589
+
590
+
591
+ def copyfile_to_local(
592
+ src: str | Path,
593
+ dst: str | Path,
594
+ dst_type: str, # Choose from ["file", "dir"]
595
+ backend_args: dict | None = None,
596
+ backend_key: str | None = None,
597
+ ) -> str | Path:
598
+ """Copy the file src to local dst and return the destination file.
599
+
600
+ If dst specifies a directory, the file will be copied into dst using
601
+ the base filename from src. If dst specifies a file that already
602
+ exists, it will be replaced.
603
+
604
+ Note:
605
+ If the backend is the instance of LocalBackend, it does the same
606
+ thing with :func:`copyfile`.
607
+
608
+ Args:
609
+ src (str or Path): A file to be copied.
610
+ dst (str or Path): Copy file to local dst.
+ dst_type (str): Whether ``dst`` is a "file" or a "dir".
611
+ backend_args (dict, optional): Arguments to instantiate the
612
+ corresponding backend. Defaults to None.
613
+
614
+ Returns:
615
+ str: If dst specifies a directory, the file will be copied into dst
616
+ using the base filename from src.
617
+
618
+ Examples:
619
+ >>> # dst is a file
620
+ >>> src = 'http://example.com/file'
621
+ >>> dst = '/path/of/file'
622
+ >>> # src will be copied to '/path/of/file'
623
+ >>> copyfile_to_local(src, dst)
624
+ '/path/of/file'
625
+
626
+ >>> # dst is a directory
627
+ >>> dst = '/path/of/dir'
628
+ >>> # src will be copied to '/path/of/dir/file'
629
+ >>> copyfile_to_local(src, dst)
630
+ '/path/of/dir/file'
631
+ """
632
+ assert dst_type in ["file", "dir"]
633
+ Path(dst).parent.mkdir(parents=True, exist_ok=True)
634
+ backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
635
+ return backend.copyfile_to_local(src, dst, dst_type=dst_type)
636
+
637
+
638
+ def copytree_to_local(
639
+ src: str | Path,
640
+ dst: str | Path,
641
+ backend_args: dict | None = None,
642
+ backend_key: str | None = None,
643
+ ) -> str | Path:
644
+ """Recursively copy an entire directory tree rooted at src to a local
645
+ directory named dst and return the destination directory.
646
+
647
+ Note:
648
+ If the backend is the instance of LocalBackend, it does the same
649
+ thing with :func:`copytree`.
650
+
651
+ Args:
652
+ src (str or Path): A directory to be copied.
653
+ dst (str or Path): Copy directory to local dst.
654
+ backend_args (dict, optional): Arguments to instantiate the
655
+ corresponding backend. Defaults to None.
656
+
657
+ Returns:
658
+ str: The destination directory.
659
+
660
+ Examples:
661
+ >>> src = 'http://example.com/dir'
662
+ >>> dst = '/path/of/dir'
663
+ >>> copytree_to_local(src, dst)
664
+ '/path/of/dir'
665
+ """
666
+ Path(dst).parent.mkdir(parents=True, exist_ok=True)
667
+ backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)  # src determines the (possibly remote) backend, as in copyfile_to_local
668
+ return backend.copytree_to_local(src, dst)
669
+
670
+
671
+ def remove(
672
+ filepath: str | Path,
673
+ backend_args: dict | None = None,
674
+ backend_key: str | None = None,
675
+ ) -> None:
676
+ """Remove a file.
677
+
678
+ Args:
679
+ filepath (str, Path): Path to be removed.
680
+ backend_args (dict, optional): Arguments to instantiate the
681
+ corresponding backend. Defaults to None.
682
+
683
+ Raises:
684
+ FileNotFoundError: If filepath does not exist, a FileNotFoundError
685
+ will be raised.
686
+ IsADirectoryError: If filepath is a directory, an IsADirectoryError
687
+ will be raised.
688
+
689
+ Examples:
690
+ >>> filepath = '/path/of/file'
691
+ >>> remove(filepath)
692
+ """
693
+ backend = get_file_backend(
694
+ filepath,
695
+ backend_args=backend_args,
696
+ enable_singleton=True,
697
+ backend_key=backend_key,
698
+ )
699
+ backend.remove(filepath)
700
+
701
+
702
+ def rmtree(
703
+ dir_path: str | Path,
704
+ backend_args: dict | None = None,
705
+ backend_key: str | None = None,
706
+ ) -> None:
707
+ """Recursively delete a directory tree.
708
+
709
+ Args:
710
+ dir_path (str or Path): A directory to be removed.
711
+ backend_args (dict, optional): Arguments to instantiate the
712
+ corresponding backend. Defaults to None.
713
+
714
+ Examples:
715
+ >>> dir_path = '/path/of/dir'
716
+ >>> rmtree(dir_path)
717
+ """
718
+ backend = get_file_backend(
719
+ dir_path,
720
+ backend_args=backend_args,
721
+ enable_singleton=True,
722
+ backend_key=backend_key,
723
+ )
724
+ backend.rmtree(dir_path)
725
+
726
+
727
+ def copy_if_symlink_fails(
728
+ src: str | Path,
729
+ dst: str | Path,
730
+ backend_args: dict | None = None,
731
+ backend_key: str | None = None,
732
+ ) -> bool:
733
+ """Create a symbolic link pointing to src named dst.
734
+
735
+ If creating a symbolic link pointing to src fails, src is copied directly to
736
+ dst instead.
737
+
738
+ Args:
739
+ src (str or Path): Create a symbolic link pointing to src.
740
+ dst (str or Path): Create a symbolic link named dst.
741
+ backend_args (dict, optional): Arguments to instantiate the
742
+ corresponding backend. Defaults to None.
743
+
744
+ Returns:
745
+ bool: Return True if a symbolic link pointing to src is created
746
+ successfully. Otherwise, return False.
747
+
748
+ Examples:
749
+ >>> src = '/path/of/file'
750
+ >>> dst = '/path1/of/file1'
751
+ >>> copy_if_symlink_fails(src, dst)
752
+ True
753
+ >>> src = '/path/of/dir'
754
+ >>> dst = '/path1/of/dir1'
755
+ >>> copy_if_symlink_fails(src, dst)
756
+ True
757
+ """
758
+ backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
759
+ return backend.copy_if_symlink_fails(src, dst)
760
+
761
+
762
+ def list_dir(
763
+ dir_path: str | Path,
764
+ backend_args: dict | None = None,
765
+ backend_key: str | None = None,
766
+ ):
767
+ """List all folders in a directory with a given path.
768
+
769
+ Args:
770
+ dir_path (str | Path): Path of the directory.
771
+
772
+ Examples:
773
+ >>> dir_path = '/path/of/dir'
774
+ >>> for file_path in list_dir(dir_path):
775
+ ... print(file_path)
776
+ """
777
+ if not dir_path.endswith("/"):
778
+ dir_path += "/"
779
+ backend = get_file_backend(
780
+ dir_path,
781
+ backend_args=backend_args,
782
+ enable_singleton=True,
783
+ backend_key=backend_key,
784
+ )
785
+
786
+ return backend.list_dir(dir_path)
787
+
788
+
789
+ def list_dir_or_file(
790
+ dir_path: str | Path,
791
+ list_dir: bool = True,
792
+ list_file: bool = True,
793
+ suffix: str | tuple[str] | None = None,
794
+ recursive: bool = False,
795
+ backend_args: dict | None = None,
796
+ backend_key: str | None = None,
797
+ ) -> Iterator[str]:
798
+ """Scan a directory to find the interested directories or files in
799
+ arbitrary order.
800
+
801
+ Note:
802
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
803
+
804
+ Args:
805
+ dir_path (str or Path): Path of the directory.
806
+ list_dir (bool): List the directories. Defaults to True.
807
+ list_file (bool): List the path of files. Defaults to True.
808
+ suffix (str or tuple[str], optional): File suffix that we are
809
+ interested in. Defaults to None.
810
+ recursive (bool): If set to True, recursively scan the directory.
811
+ Defaults to False.
812
+ backend_args (dict, optional): Arguments to instantiate the
813
+ corresponding backend. Defaults to None.
814
+
815
+ Yields:
816
+ Iterable[str]: A relative path to ``dir_path``.
817
+
818
+ Examples:
819
+ >>> dir_path = '/path/of/dir'
822
+ >>> # list those files and directories in current directory
823
+ >>> for file_path in list_dir_or_file(dir_path):
824
+ ... print(file_path)
825
+ >>> # only list files
826
+ >>> for file_path in list_dir_or_file(dir_path, list_dir=False):
827
+ ... print(file_path)
828
+ >>> # only list directories
829
+ >>> for file_path in list_dir_or_file(dir_path, list_file=False):
830
+ ... print(file_path)
831
+ >>> # only list files ending with specified suffixes
832
+ >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'):
833
+ ... print(file_path)
834
+ >>> # list all files and directory recursively
835
+ >>> for file_path in list_dir_or_file(dir_path, recursive=True):
836
+ ... print(file_path)
837
+ """
838
+ backend = get_file_backend(
839
+ dir_path,
840
+ backend_args=backend_args,
841
+ enable_singleton=True,
842
+ backend_key=backend_key,
843
+ )
844
+ yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
845
+
846
+
847
+ def load(
848
+ file: str | Path | IO[Any],
849
+ file_format: str | None = None,
850
+ file_client_args: dict | None = None,
851
+ fast_backend: bool = False,
852
+ backend_args: dict | None = None,
853
+ backend_key: str | None = None,
854
+ **kwargs,
855
+ ):
856
+ """Load data from json/yaml/pickle files.
857
+
858
+ This method provides a unified api for loading data from serialized files.
859
+
860
+ ``load`` supports loading data from serialized files those can be storaged
861
+ in different backends.
862
+
863
+ Args:
864
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
865
+ object.
866
+ file_format (str, optional): If not specified, the file format will be
867
+ inferred from the file extension, otherwise use the specified one.
868
+ Currently supported formats include "json", "yaml/yml" and
869
+ "pickle/pkl".
870
+ file_client_args (dict, optional): Arguments to instantiate a
871
+ FileClient. See :class:`mmengine.fileio.FileClient` for details.
872
+ Defaults to None. It will be deprecated in future. Please use
873
+ ``backend_args`` instead.
874
+ fast_backend (bool): Whether to use the backend's faster multiprocess path if available. Defaults to False.
875
+ backend_args (dict, optional): Arguments to instantiate the
876
+ prefix of uri corresponding backend. Defaults to None.
877
+ New in v0.2.0.
878
+
879
+ Examples:
880
+ >>> load('/path/of/your/file') # file is stored on disk
881
+ >>> load('https://path/of/your/file') # file is stored on the Internet
882
+
883
+ Returns:
884
+ The content from the file.
885
+ """
886
+ if isinstance(file, Path):
887
+ file = str(file)
888
+ if file_format is None and isinstance(file, str):
889
+ file_format = file.split(".")[-1]
+ elif file_format is None:
+ raise ValueError("file_format must be specified since file is not a filepath str")
890
+ # convert file_format to lower case
891
+ file_format = file_format.lower()
892
+ if file_format not in file_handlers:
893
+ raise TypeError(f"Unsupported format: {file_format}")
894
+
895
+ if file_client_args is not None:
896
+ warnings.warn( # noqa: B028
897
+ '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
898
+ DeprecationWarning,
899
+ )
900
+ if backend_args is not None:
901
+ raise ValueError('"file_client_args and "backend_args" cannot be set at the same time.')
902
+
903
+ handler = file_handlers[file_format]
904
+ if isinstance(file, str):
905
+ if file_client_args is not None:
906
+ file_client = FileClient.infer_client(file_client_args, file)
907
+ file_backend = file_client
908
+ else:
909
+ file_backend = get_file_backend(
910
+ file,
911
+ backend_args=backend_args,
912
+ backend_key=backend_key,
913
+ enable_singleton=True,
914
+ )
915
+
916
+ if handler.str_like:
917
+ with StringIO(file_backend.get_text(file)) as f:
918
+ obj = handler.load_from_fileobj(f, **kwargs)
919
+ else:
920
+ if fast_backend:
921
+ if hasattr(file_backend, "fast_get"):
922
+ with BytesIO(file_backend.fast_get(file)) as f:
923
+ obj = handler.load_from_fileobj(f, **kwargs)
924
+ else:
925
+ warnings.warn( # noqa: B028
926
+ f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get"
927
+ )
928
+ with BytesIO(file_backend.get(file)) as f:
929
+ obj = handler.load_from_fileobj(f, **kwargs)
930
+ else:
931
+ with BytesIO(file_backend.get(file)) as f:
932
+ obj = handler.load_from_fileobj(f, **kwargs)
933
+ elif hasattr(file, "read"):
934
+ obj = handler.load_from_fileobj(file, **kwargs)
935
+ else:
936
+ raise TypeError('"file" must be a filepath str or a file-object')
937
+ return obj
938
+
939
+
940
+ def dump(
941
+ obj: Any,
942
+ file: str | Path | IO[Any] | None = None,
943
+ file_format: str | None = None,
944
+ file_client_args: dict | None = None,
945
+ fast_backend: bool = False,
946
+ backend_args: dict | None = None,
947
+ backend_key: str | None = None,
948
+ **kwargs,
949
+ ):
950
+ """Dump data to json/yaml/pickle strings or files.
951
+
952
+ This method provides a unified API for dumping data as strings or to files,
953
+ and also supports custom arguments for each file format.
954
+
955
+ ``dump`` supports dumping data as strings or to files which can be saved to
956
+ different backends.
957
+
958
+ Args:
959
+ obj (any): The python object to be dumped.
960
+ file (str or :obj:`Path` or file-like object, optional): If not
961
+ specified, then the object is dumped to a str, otherwise to a file
962
+ specified by the filename or file-like object.
963
+ file_format (str, optional): Same as :func:`load`.
964
+ file_client_args (dict, optional): Arguments to instantiate a
965
+ FileClient. See :class:`mmengine.fileio.FileClient` for details.
966
+ Defaults to None. It will be deprecated in future. Please use
967
+ ``backend_args`` instead.
968
+ fast_backend (bool): Whether to use the backend's faster multiprocess path if available. Defaults to False.
969
+ backend_args (dict, optional): Arguments to instantiate the
970
+ prefix of uri corresponding backend. Defaults to None.
971
+ New in v0.2.0.
972
+ backend_key (str, optional): The key to register the backend. Defaults to None.
973
+
974
+ Examples:
975
+ >>> dump('hello world', '/path/of/your/file') # disk
976
+ >>> dump('hello world', 'http://path/of/your/file') # http
977
+
978
+ Returns:
979
+ str or None: The dumped string when ``file`` is None; otherwise None.
980
+ """
981
+ if isinstance(file, Path):
982
+ file = str(file)
983
+ if file_format is None:
984
+ if isinstance(file, str):
985
+ file_format = file.split(".")[-1]
986
+ elif file is None:
987
+ raise ValueError("file_format must be specified since file is None")
988
+ # convert file_format to lower case
989
+ file_format = file_format.lower()
990
+ if file_format not in file_handlers:
991
+ raise TypeError(f"Unsupported format: {file_format}")
992
+
993
+ if file_client_args is not None:
994
+ warnings.warn( # noqa: B028
995
+ '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
996
+ DeprecationWarning,
997
+ )
998
+ if backend_args is not None:
999
+ raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
1000
+
1001
+ handler = file_handlers[file_format]
1002
+ if file is None:
1003
+ return handler.dump_to_str(obj, **kwargs)
1004
+ elif isinstance(file, str):
1005
+ if file_client_args is not None:
1006
+ file_client = FileClient.infer_client(file_client_args, file)
1007
+ file_backend = file_client
1008
+ else:
1009
+ file_backend = get_file_backend(
1010
+ file,
1011
+ backend_args=backend_args,
1012
+ backend_key=backend_key,
1013
+ enable_singleton=True,
1014
+ )
1015
+
1016
+ if handler.str_like:
1017
+ with StringIO() as f:
1018
+ handler.dump_to_fileobj(obj, f, **kwargs)
1019
+ file_backend.put_text(f.getvalue(), file)
1020
+ else:
1021
+ with BytesIO() as f:
1022
+ handler.dump_to_fileobj(obj, f, **kwargs)
1023
+ if fast_backend:
1024
+ if hasattr(file_backend, "fast_put"):
1025
+ file_backend.fast_put(f, file)
1026
+ else:
1027
+ warnings.warn("fast_backend is not supported by the backend, fallback to normal put") # noqa: B028
1028
+ file_backend.put(f.getvalue(), file)  # put() expects bytes, not the BytesIO object
1029
+ else:
1030
+ file_backend.put(f.getvalue(), file)
1031
+ elif hasattr(file, "write"):
1032
+ handler.dump_to_fileobj(obj, file, **kwargs)
1033
+ else:
1034
+ raise TypeError('"file" must be a filename str or a file-object')
imaginaire/utils/easy_io/file_client.py ADDED
@@ -0,0 +1,448 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ from collections.abc import Generator, Iterator
18
+ from contextlib import contextmanager
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ from imaginaire.utils.easy_io.backends import BaseStorageBackend, HTTPBackend, LocalBackend
23
+
24
+
25
+ def is_filepath(filepath):
26
+ return isinstance(filepath, (str, Path))
27
+
28
+
29
+ class HardDiskBackend(LocalBackend):
30
+ """Raw hard disks storage backend."""
31
+
32
+ @property
33
+ def name(self):
34
+ return self.__class__.__name__
35
+
36
+
37
+ class FileClient:
38
+ """A general file client to access files in different backends.
39
+
40
+ The client loads a file or text in a specified backend from its path
41
+ and returns it as a binary or text file. There are two ways to choose a
42
+ backend: the name of the backend and the prefix of the path. Although both
43
+ can be used to choose a storage backend, ``backend`` has a higher priority;
44
+ that is, if both are set, the storage backend will be chosen by the
45
+ ``backend`` argument. If both are ``None``, the disk backend will be chosen.
46
+ Note that other backend accessors can also be registered with a given name,
47
+ prefixes, and backend class. In addition, the singleton pattern is used to
48
+ avoid repeated object creation: if the arguments are the same, the same
49
+ object will be returned.
50
+
51
+ Warning:
52
+ `FileClient` will be deprecated in future. Please use io functions
53
+ in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
54
+
55
+ Args:
56
+ backend (str, optional): The storage backend type. Options are "disk"
57
+ and "http". Defaults to None.
58
+ prefix (str, optional): The prefix of the registered storage backend.
59
+ Options are "http", "https". Defaults to None.
60
+
61
+ Examples:
62
+ >>> # only set backend
63
+ >>> file_client = FileClient(backend='disk')
64
+ >>> # only set prefix
65
+ >>> file_client = FileClient(prefix='http')
66
+ >>> # set both backend and prefix but use backend to choose client
67
+ >>> file_client = FileClient(backend='http', prefix='http')
68
+ >>> # if the arguments are the same, the same object is returned
69
+ >>> file_client1 = FileClient(backend='disk')
70
+ >>> file_client1 is file_client
71
+ True
72
+
73
+ Attributes:
74
+ client (:obj:`BaseStorageBackend`): The backend object.
75
+ """
76
+
77
+ _backends = { # noqa: RUF012
78
+ "disk": HardDiskBackend,
79
+ "http": HTTPBackend,
80
+ }
81
+
82
+ _prefix_to_backends: dict = { # noqa: RUF012
83
+ "http": HTTPBackend,
84
+ "https": HTTPBackend,
85
+ }
86
+
87
+ _instances: dict = {} # noqa: RUF012
88
+
89
+ client: Any
90
+
91
+ def __new__(cls, backend=None, prefix=None, **kwargs):
92
+ if backend is None and prefix is None:
93
+ backend = "disk"
94
+ if backend is not None and backend not in cls._backends:
95
+ raise ValueError(
96
+ f"Backend {backend} is not supported. Currently supported ones are {list(cls._backends.keys())}"
97
+ )
98
+ if prefix is not None and prefix not in cls._prefix_to_backends:
99
+ raise ValueError(
100
+ f"prefix {prefix} is not supported. Currently supported ones are {list(cls._prefix_to_backends.keys())}"
101
+ )
102
+
103
+ # concatenate the arguments to a unique key for determining whether
104
+ # objects with the same arguments were created
105
+ arg_key = f"{backend}:{prefix}"
106
+ for key, value in kwargs.items():
107
+ arg_key += f":{key}:{value}"
108
+
109
+ # if a backend was overridden, it will create a new object
110
+ if arg_key in cls._instances:
111
+             _instance = cls._instances[arg_key]
+         else:
+             # create a new object and put it into cls._instances
+             _instance = super().__new__(cls)
+             if backend is not None:
+                 _instance.client = cls._backends[backend](**kwargs)
+             else:
+                 _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+             cls._instances[arg_key] = _instance
+
+         return _instance
+
+     @property
+     def name(self):
+         return self.client.name
+
+     @property
+     def allow_symlink(self):
+         return self.client.allow_symlink
+
+     @staticmethod
+     def parse_uri_prefix(uri: str | Path) -> str | None:
+         """Parse the prefix of a uri.
+
+         Args:
+             uri (str | Path): Uri to be parsed that contains the file prefix.
+
+         Examples:
+             >>> FileClient.parse_uri_prefix('http://path/of/your/file')
+             'http'
+
+         Returns:
+             str | None: Return the prefix of uri if the uri contains '://' else
+             ``None``.
+         """
+         assert is_filepath(uri)
+         uri = str(uri)
+         if "://" not in uri:
+             return None
+         else:
+             prefix, _ = uri.split("://")
+             return prefix
+
+     @classmethod
+     def infer_client(
+         cls,
+         file_client_args: dict | None = None,
+         uri: str | Path | None = None,
+     ) -> "FileClient":
+         """Infer a suitable file client based on the URI and arguments.
+
+         Args:
+             file_client_args (dict, optional): Arguments to instantiate a
+                 FileClient. Defaults to None.
+             uri (str | Path, optional): Uri to be parsed that contains the file
+                 prefix. Defaults to None.
+
+         Examples:
+             >>> uri = 'http://path/of/your/file'
+             >>> file_client = FileClient.infer_client(uri=uri)
+             >>> file_client_args = {'backend': 'disk'}
+             >>> file_client = FileClient.infer_client(file_client_args)
+
+         Returns:
+             FileClient: Instantiated FileClient object.
+         """
+         assert file_client_args is not None or uri is not None
+         if file_client_args is None:
+             file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
+             return cls(prefix=file_prefix)
+         else:
+             return cls(**file_client_args)
+
+     @classmethod
+     def _register_backend(cls, name, backend, force=False, prefixes=None):
+         if not isinstance(name, str):
+             raise TypeError(f"the backend name should be a string, but got {type(name)}")
+         if not inspect.isclass(backend):
+             raise TypeError(f"backend should be a class but got {type(backend)}")
+         if not issubclass(backend, BaseStorageBackend):
+             raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")
+         if not force and name in cls._backends:
+             raise KeyError(
+                 f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
+             )
+
+         if name in cls._backends and force:
+             for arg_key, instance in list(cls._instances.items()):
+                 if isinstance(instance.client, cls._backends[name]):
+                     cls._instances.pop(arg_key)
+         cls._backends[name] = backend
+
+         if prefixes is not None:
+             if isinstance(prefixes, str):
+                 prefixes = [prefixes]
+             else:
+                 assert isinstance(prefixes, (list, tuple))
+             for prefix in prefixes:
+                 if prefix not in cls._prefix_to_backends:
+                     cls._prefix_to_backends[prefix] = backend
+                 elif force:
+                     # Evict cached instances built on the overridden backend,
+                     # then point the prefix at the new backend.
+                     overridden_backend = cls._prefix_to_backends[prefix]
+                     for arg_key, instance in list(cls._instances.items()):
+                         if isinstance(instance.client, overridden_backend):
+                             cls._instances.pop(arg_key)
+                     cls._prefix_to_backends[prefix] = backend
+                 else:
+                     raise KeyError(
+                         f"{prefix} is already registered as a storage backend,"
+                         ' add "force=True" if you want to override it'
+                     )
+
+     @classmethod
+     def register_backend(cls, name, backend=None, force=False, prefixes=None):
+         """Register a backend to FileClient.
+
+         This method can be used as a normal class method or a decorator.
+
+         .. code-block:: python
+
+             class NewBackend(BaseStorageBackend):
+
+                 def get(self, filepath):
+                     return filepath
+
+                 def get_text(self, filepath):
+                     return filepath
+
+             FileClient.register_backend('new', NewBackend)
+
+         or
+
+         .. code-block:: python
+
+             @FileClient.register_backend('new')
+             class NewBackend(BaseStorageBackend):
+
+                 def get(self, filepath):
+                     return filepath
+
+                 def get_text(self, filepath):
+                     return filepath
+
+         Args:
+             name (str): The name of the registered backend.
+             backend (class, optional): The backend class to be registered,
+                 which must be a subclass of :class:`BaseStorageBackend`.
+                 When this method is used as a decorator, backend is None.
+                 Defaults to None.
+             force (bool, optional): Whether to override the backend if the name
+                 has already been registered. Defaults to False.
+             prefixes (str or list[str] or tuple[str], optional): The prefixes
+                 of the registered storage backend. Defaults to None.
+         """
+         if backend is not None:
+             cls._register_backend(name, backend, force=force, prefixes=prefixes)
+             return
+
+         def _register(backend_cls):
+             cls._register_backend(name, backend_cls, force=force, prefixes=prefixes)
+             return backend_cls
+
+         return _register
+
+     def get(self, filepath: str | Path) -> bytes | memoryview:
+         """Read data from a given ``filepath`` with 'rb' mode.
+
+         Note:
+             There are two types of return values for ``get``, one is ``bytes``
+             and the other is ``memoryview``. The advantage of using memoryview
+             is that you can avoid copying, and if you want to convert it to
+             ``bytes``, you can use ``.tobytes()``.
+
+         Args:
+             filepath (str or Path): Path to read data.
+
+         Returns:
+             bytes | memoryview: Expected bytes object or a memory view of the
+             bytes object.
+         """
+         return self.client.get(filepath)
+
+     def get_text(self, filepath: str | Path, encoding="utf-8") -> str:
+         """Read data from a given ``filepath`` with 'r' mode.
+
+         Args:
+             filepath (str or Path): Path to read data.
+             encoding (str): The encoding format used to open the ``filepath``.
+                 Defaults to 'utf-8'.
+
+         Returns:
+             str: Expected text reading from ``filepath``.
+         """
+         return self.client.get_text(filepath, encoding)
+
+     def put(self, obj: bytes, filepath: str | Path) -> None:
+         """Write data to a given ``filepath`` with 'wb' mode.
+
+         Note:
+             ``put`` should create a directory if the directory of ``filepath``
+             does not exist.
+
+         Args:
+             obj (bytes): Data to be written.
+             filepath (str or Path): Path to write data.
+         """
+         self.client.put(obj, filepath)
+
+     def put_text(self, obj: str, filepath: str | Path) -> None:
+         """Write data to a given ``filepath`` with 'w' mode.
+
+         Note:
+             ``put_text`` should create a directory if the directory of
+             ``filepath`` does not exist.
+
+         Args:
+             obj (str): Data to be written.
+             filepath (str or Path): Path to write data.
+         """
+         self.client.put_text(obj, filepath)
+
+     def remove(self, filepath: str | Path) -> None:
+         """Remove a file.
+
+         Args:
+             filepath (str, Path): Path to be removed.
+         """
+         self.client.remove(filepath)
+
+     def exists(self, filepath: str | Path) -> bool:
+         """Check whether a file path exists.
+
+         Args:
+             filepath (str or Path): Path to be checked whether exists.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+         """
+         return self.client.exists(filepath)
+
+     def isdir(self, filepath: str | Path) -> bool:
+         """Check whether a file path is a directory.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it is a
+                 directory.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` points to a directory,
+             ``False`` otherwise.
+         """
+         return self.client.isdir(filepath)
+
+     def isfile(self, filepath: str | Path) -> bool:
+         """Check whether a file path is a file.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it is a file.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` points to a file, ``False``
+             otherwise.
+         """
+         return self.client.isfile(filepath)
+
+     def join_path(self, filepath: str | Path, *filepaths: str | Path) -> str:
+         r"""Concatenate all file paths.
+
+         Join one or more filepath components intelligently. The return value
+         is the concatenation of filepath and any members of \*filepaths.
+
+         Args:
+             filepath (str or Path): Path to be concatenated.
+
+         Returns:
+             str: The result of concatenation.
+         """
+         return self.client.join_path(filepath, *filepaths)
+
+     @contextmanager
+     def get_local_path(self, filepath: str | Path) -> Generator[str | Path, None, None]:
+         """Download data from ``filepath`` and write the data to a local path.
+
+         ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+         It can be called with a ``with`` statement, and when exiting the
+         ``with`` statement, the temporary path will be released.
+
+         Note:
+             If the ``filepath`` is a local path, just return itself.
+
+         .. warning::
+             ``get_local_path`` is an experimental interface that may change in
+             the future.
+
+         Args:
+             filepath (str or Path): Path to be read data.
+
+         Examples:
+             >>> file_client = FileClient(prefix='http')
+             >>> with file_client.get_local_path('http://example.com/abc.jpg') as path:
+             ...     # do something here
+
+         Yields:
+             Iterable[str]: Only yield one path.
+         """
+         with self.client.get_local_path(str(filepath)) as local_path:
+             yield local_path
+
+     def list_dir_or_file(  # pylint: disable=too-many-arguments
+         self,
+         dir_path: str | Path,
+         list_dir: bool = True,
+         list_file: bool = True,
+         suffix: str | tuple[str] | None = None,
+         recursive: bool = False,
+     ) -> Iterator[str]:
+         """Scan a directory to find the interested directories or files in
+         arbitrary order.
+
+         Note:
+             :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+         Args:
+             dir_path (str | Path): Path of the directory.
+             list_dir (bool): List the directories. Defaults to True.
+             list_file (bool): List the path of files. Defaults to True.
+             suffix (str or tuple[str], optional): File suffix
+                 that we are interested in. Defaults to None.
+             recursive (bool): If set to True, recursively scan the
+                 directory. Defaults to False.
+
+         Yields:
+             Iterable[str]: A relative path to ``dir_path``.
+         """
+         yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
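A minimal sketch (not part of this commit) of how the registry and prefix inference above combine; `MemoryBackend` is hypothetical and exists only for illustration:

# Hypothetical in-memory backend, registered under the 'mem://' prefix.
class MemoryBackend(BaseStorageBackend):
    _store = {"mem://bucket/key": b"hello"}

    def get(self, filepath):
        return self._store[str(filepath)]

    def get_text(self, filepath, encoding="utf-8"):
        return self.get(filepath).decode(encoding)

FileClient.register_backend("memory", MemoryBackend, prefixes="mem")
client = FileClient.infer_client(uri="mem://bucket/key")  # resolved via the 'mem' prefix
assert client.get("mem://bucket/key") == b"hello"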
imaginaire/utils/easy_io/handlers/__init__.py ADDED
@@ -0,0 +1,29 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+ from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
+ from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
+ from imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler
+ from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler
+
+ __all__ = [
+     "BaseFileHandler",
+     "JsonHandler",
+     "PickleHandler",
+     "YamlHandler",
+     "file_handlers",
+     "register_handler",
+ ]
imaginaire/utils/easy_io/handlers/base.py ADDED
@@ -0,0 +1,44 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABCMeta, abstractmethod
+
+
+ class BaseFileHandler(metaclass=ABCMeta):
+     # `str_like` is a flag indicating whether the handler's file object is
+     # str-like or bytes-like. Pickle only processes bytes-like objects,
+     # while json only processes str-like objects. If it is a str-like
+     # object, `StringIO` will be used to process the buffer.
+     str_like = True
+
+     @abstractmethod
+     def load_from_fileobj(self, file, **kwargs):
+         pass
+
+     @abstractmethod
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         pass
+
+     @abstractmethod
+     def dump_to_str(self, obj, **kwargs):
+         pass
+
+     def load_from_path(self, filepath, mode="r", **kwargs):
+         with open(filepath, mode) as f:
+             return self.load_from_fileobj(f, **kwargs)
+
+     def dump_to_path(self, obj, filepath, mode="w", **kwargs):
+         with open(filepath, mode) as f:
+             self.dump_to_fileobj(obj, f, **kwargs)
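For orientation, a minimal (hypothetical) subclass showing which methods are abstract and which come from the base:

class ReprHandler(BaseFileHandler):
    # str_like stays True, so the easy_io layer buffers through StringIO.
    def load_from_fileobj(self, file, **kwargs):
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(repr(obj))

    def dump_to_str(self, obj, **kwargs):
        return repr(obj)

# load_from_path / dump_to_path are inherited and open the file for you.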
imaginaire/utils/easy_io/handlers/byte_handler.py ADDED
@@ -0,0 +1,39 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import IO
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class ByteHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file: IO[bytes], **kwargs):
+         file.seek(0)
+         # extract all bytes and return
+         return file.read()
+
+     def dump_to_fileobj(
+         self,
+         obj: bytes,
+         file: IO[bytes],
+         **kwargs,
+     ):
+         # write all bytes to file
+         file.write(obj)
+
+     def dump_to_str(self, obj, **kwargs):
+         raise NotImplementedError
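A quick round-trip sketch through an in-memory buffer:

from io import BytesIO

handler = ByteHandler()
buf = BytesIO()
handler.dump_to_fileobj(b"\x00\x01\x02", buf)
assert handler.load_from_fileobj(buf) == b"\x00\x01\x02"  # load seeks back to 0 itself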
imaginaire/utils/easy_io/handlers/csv_handler.py ADDED
@@ -0,0 +1,42 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import csv
+ from io import StringIO
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class CsvHandler(BaseFileHandler):
+     def load_from_fileobj(self, file, **kwargs):
+         del kwargs
+         reader = csv.reader(file)
+         return list(reader)
+
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         del kwargs
+         writer = csv.writer(file)
+         if not all(isinstance(row, list) for row in obj):
+             raise ValueError("Each row must be a list")
+         writer.writerows(obj)
+
+     def dump_to_str(self, obj, **kwargs):
+         del kwargs
+         output = StringIO()
+         writer = csv.writer(output)
+         if not all(isinstance(row, list) for row in obj):
+             raise ValueError("Each row must be a list")
+         writer.writerows(obj)
+         return output.getvalue()
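A round-trip sketch; each row is a list of strings on both sides:

from io import StringIO

handler = CsvHandler()
text = handler.dump_to_str([["name", "score"], ["a", "1"]])
assert handler.load_from_fileobj(StringIO(text)) == [["name", "score"], ["a", "1"]]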
imaginaire/utils/easy_io/handlers/gzip_handler.py ADDED
@@ -0,0 +1,33 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import gzip
+ import pickle
+ from io import BytesIO
+ from typing import Any
+
+ from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
+
+
+ class GzipHandler(PickleHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file: BytesIO, **kwargs):
+         with gzip.GzipFile(fileobj=file, mode="rb") as f:
+             return pickle.load(f)
+
+     def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
+         with gzip.GzipFile(fileobj=file, mode="wb") as f:
+             pickle.dump(obj, f)
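Since this subclasses PickleHandler, a round trip is just gzip-compressed pickle (sketch):

from io import BytesIO

handler = GzipHandler()
buf = BytesIO()
handler.dump_to_fileobj({"step": 1}, buf)
buf.seek(0)
assert handler.load_from_fileobj(buf) == {"step": 1}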
imaginaire/utils/easy_io/handlers/imageio_video_handler.py ADDED
@@ -0,0 +1,168 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import IO, Any
+
+ import imageio
+ import imageio.v3 as iio_v3
+ import numpy as np
+ import torch
+
+ from imaginaire.utils import log
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class ImageioVideoHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(
+         self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs
+     ) -> tuple[np.ndarray, dict[str, Any]]:
+         """
+         Load video from a file-like object using imageio.v3 with the specified format and color mode.
+
+         Parameters:
+             file (IO[bytes]): A file-like object containing video data.
+             format (str): Format of the video file (default 'mp4').
+             mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb').
+
+         Returns:
+             tuple: A tuple containing an array of video frames and metadata about the video.
+         """
+         file.seek(0)
+
+         # The plugin argument in v3 replaces the format argument in v2
+         plugin = kwargs.pop("plugin", "pyav")
+
+         # Load all frames at once using the v3 API
+         video_frames = iio_v3.imread(file, plugin=plugin, **kwargs)
+
+         # Handle grayscale conversion if needed
+         if mode == "gray":
+             import cv2
+
+             if len(video_frames.shape) == 4:  # (frames, height, width, channels)
+                 gray_frames = []
+                 for frame in video_frames:
+                     gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+                     gray_frame = np.expand_dims(gray_frame, axis=2)  # Keep dimensions consistent
+                     gray_frames.append(gray_frame)
+                 video_frames = np.array(gray_frames)
+
+         # Extract metadata.
+         # Note: iio_v3.imread doesn't return metadata directly like v2 did,
+         # so we need to extract it separately.
+         file.seek(0)
+         metadata = self._extract_metadata(file, plugin=plugin)
+
+         return video_frames, metadata
+
+     def _extract_metadata(self, file: IO[bytes], plugin: str = "pyav") -> dict[str, Any]:
+         """
+         Extract metadata from a video file.
+
+         Parameters:
+             file (IO[bytes]): File-like object containing video data.
+             plugin (str): Plugin to use for reading.
+
+         Returns:
+             dict: Video metadata.
+         """
+         try:
+             # Read container-level metadata
+             metadata = iio_v3.immeta(file, plugin=plugin)
+
+             # Add some standard fields similar to the v2 metadata format
+             if "fps" not in metadata and "duration" in metadata:
+                 # Read the first frame to get shape information
+                 file.seek(0)
+                 first_frame = iio_v3.imread(file, plugin=plugin, index=0)
+                 metadata["size"] = first_frame.shape[1::-1]  # (width, height)
+                 metadata["source_size"] = metadata["size"]
+
+             # Create a metadata structure consistent with v2
+             metadata["plugin"] = plugin
+             if "codec" not in metadata:
+                 metadata["codec"] = "unknown"
+             if "pix_fmt" not in metadata:
+                 metadata["pix_fmt"] = "unknown"
+
+             # Calculate nframes if possible
+             if "fps" in metadata and "duration" in metadata:
+                 metadata["nframes"] = int(metadata["fps"] * metadata["duration"])
+             else:
+                 metadata["nframes"] = float("inf")
+
+             return metadata
+
+         except Exception:
+             # Fallback to basic metadata
+             return {
+                 "plugin": plugin,
+                 "nframes": float("inf"),
+                 "codec": "unknown",
+                 "fps": 30.0,  # Default values
+                 "duration": 0,
+                 "size": (0, 0),
+             }
+
+     def dump_to_fileobj(
+         self,
+         obj: np.ndarray | torch.Tensor,
+         file: IO[bytes],
+         format: str = "mp4",  # pylint: disable=redefined-builtin
+         fps: int = 17,
+         quality: int = 7,
+         ffmpeg_params=None,
+         **kwargs,
+     ):
+         """
+         Save an array of video frames to a file-like object using imageio.
+
+         Parameters:
+             obj (Union[np.ndarray, torch.Tensor]): An array of frames to be saved as video.
+             file (IO[bytes]): A file-like object to which the video data will be written.
+             format (str): Format of the video file (default 'mp4').
+             fps (int): Frames per second of the output video (default 17).
+             quality (int): Quality of the video (0-10, default 7).
+             ffmpeg_params (list): Additional parameters to pass to ffmpeg.
+         """
+         if isinstance(obj, torch.Tensor):
+             assert obj.dtype == torch.uint8, "Tensor must be of type uint8"
+             obj = obj.cpu().numpy()
+         h, w = obj.shape[1:-1]
+
+         # Default ffmpeg params that ensure width and height are set
+         default_ffmpeg_params = ["-s", f"{w}x{h}"]
+
+         # Use provided ffmpeg_params if any, otherwise use defaults
+         final_ffmpeg_params = ffmpeg_params if ffmpeg_params is not None else default_ffmpeg_params
+
+         mimsave_kwargs = {
+             "fps": fps,
+             "quality": quality,
+             "macro_block_size": 1,
+             "ffmpeg_params": final_ffmpeg_params,
+             "output_params": ["-f", "mp4"],
+         }
+         # Update with any other kwargs
+         mimsave_kwargs.update(kwargs)
+         log.debug(f"mimsave_kwargs: {mimsave_kwargs}")
+
+         imageio.mimsave(file, obj, format, **mimsave_kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         raise NotImplementedError
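A dump/load sketch; this assumes the pyav and ffmpeg imageio plugins are installed:

import numpy as np
from io import BytesIO

handler = ImageioVideoHandler()
frames = np.zeros((17, 64, 64, 3), dtype=np.uint8)  # (T, H, W, C) uint8
buf = BytesIO()
handler.dump_to_fileobj(frames, buf, fps=17)
frames_back, meta = handler.load_from_fileobj(BytesIO(buf.getvalue()))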
imaginaire/utils/easy_io/handlers/json_handler.py ADDED
@@ -0,0 +1,49 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+
+ import numpy as np
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ def set_default(obj):
+     """Set default json values for non-serializable values.
+
+     It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
+     It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
+     etc.) into plain Python built-in numbers.
+     """
+     if isinstance(obj, (set, range)):
+         return list(obj)
+     elif isinstance(obj, np.ndarray):
+         return obj.tolist()
+     elif isinstance(obj, np.generic):
+         return obj.item()
+     raise TypeError(f"{type(obj)} is unsupported for json dump")
+
+
+ class JsonHandler(BaseFileHandler):
+     def load_from_fileobj(self, file):
+         return json.load(file)
+
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         kwargs.setdefault("default", set_default)
+         json.dump(obj, file, **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         kwargs.setdefault("default", set_default)
+         return json.dumps(obj, **kwargs)
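With `set_default` wired in, numpy scalars/arrays and sets serialize without extra work (sketch; set ordering is not guaranteed):

import numpy as np

handler = JsonHandler()
payload = {"ids": np.arange(3), "loss": np.float32(0.25), "tags": {"a"}}
print(handler.dump_to_str(payload))  # {"ids": [0, 1, 2], "loss": 0.25, "tags": ["a"]}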
imaginaire/utils/easy_io/handlers/jsonl_handler.py ADDED
@@ -0,0 +1,80 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ from typing import IO
+
+ import numpy as np
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ def set_default(obj):
+     """Set default json values for non-serializable values.
+
+     It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
+     It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
+     etc.) into plain Python built-in numbers.
+     """
+     if isinstance(obj, (set, range)):
+         return list(obj)
+     elif isinstance(obj, np.ndarray):
+         return obj.tolist()
+     elif isinstance(obj, np.generic):
+         return obj.item()
+     raise TypeError(f"{type(obj)} is unsupported for json dump")
+
+
+ class JsonlHandler(BaseFileHandler):
+     """Handler for JSON lines (JSONL) files."""
+
+     def load_from_fileobj(self, file: IO[bytes]):
+         """Load JSON objects from a newline-delimited JSON (JSONL) file object.
+
+         Returns:
+             A list of Python objects loaded from each JSON line.
+         """
+         data = []
+         for line in file:
+             line = line.strip()
+             if not line:
+                 continue  # skip empty lines if any
+             data.append(json.loads(line))
+         return data
+
+     def dump_to_fileobj(self, obj, file: IO[str], **kwargs):
+         """Dump a list of objects to a newline-delimited JSON (JSONL) file object.
+
+         Args:
+             obj: A list (or iterable) of objects to dump line by line.
+         """
+         kwargs.setdefault("default", set_default)
+         for item in obj:
+             file.write(json.dumps(item, **kwargs) + "\n")
+
+     def dump_to_str(self, obj, **kwargs):
+         """Dump a list of objects to a newline-delimited JSON (JSONL) string."""
+         kwargs.setdefault("default", set_default)
+         lines = [json.dumps(item, **kwargs) for item in obj]
+         return "\n".join(lines)
+
+
+ if __name__ == "__main__":
+     from imaginaire.utils.easy_io import easy_io
+
+     easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl")
+     print(easy_io.load("test.jsonl"))
+     easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl")
+     print(easy_io.load("test.jsonl"))
imaginaire/utils/easy_io/handlers/np_handler.py ADDED
@@ -0,0 +1,89 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from io import BytesIO
+ from typing import IO, Any
+
+ import numpy as np
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class NumpyHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any:
+         """
+         Load a NumPy array from a file-like object.
+
+         Parameters:
+             file (IO[bytes]): The file-like object containing the NumPy array data.
+             **kwargs: Additional keyword arguments passed to `np.load`.
+
+         Returns:
+             numpy.ndarray: The loaded NumPy array.
+         """
+         return np.load(file, **kwargs)
+
+     def load_from_path(self, filepath: str, **kwargs) -> Any:
+         """
+         Load a NumPy array from a file path.
+
+         Parameters:
+             filepath (str): The path to the file to load.
+             **kwargs: Additional keyword arguments passed to `np.load`.
+
+         Returns:
+             numpy.ndarray: The loaded NumPy array.
+         """
+         return super().load_from_path(filepath, mode="rb", **kwargs)
+
+     def dump_to_str(self, obj: np.ndarray, **kwargs) -> bytes:
+         """
+         Serialize a NumPy array to bytes in NumPy's binary ``.npy`` format.
+
+         Parameters:
+             obj (np.ndarray): The NumPy array to serialize.
+             **kwargs: Additional keyword arguments passed to `np.save`.
+
+         Returns:
+             bytes: The serialized array (binary data, despite the method name).
+         """
+         with BytesIO() as f:
+             np.save(f, obj, **kwargs)
+             return f.getvalue()
+
+     def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs):
+         """
+         Dump a NumPy array to a file-like object.
+
+         Parameters:
+             obj (np.ndarray): The NumPy array to dump.
+             file (IO[bytes]): The file-like object to which the array is dumped.
+             **kwargs: Additional keyword arguments passed to `np.save`.
+         """
+         np.save(file, obj, **kwargs)
+
+     def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs):
+         """
+         Dump a NumPy array to a file path.
+
+         Parameters:
+             obj (np.ndarray): The NumPy array to dump.
+             filepath (str): The file path where the array should be saved.
+             **kwargs: Additional keyword arguments passed to `np.save`.
+         """
+         with open(filepath, "wb") as f:
+             np.save(f, obj, **kwargs)
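A buffer round trip in `.npy` format (sketch):

import numpy as np
from io import BytesIO

handler = NumpyHandler()
buf = BytesIO()
handler.dump_to_fileobj(np.eye(2), buf)
buf.seek(0)
assert (handler.load_from_fileobj(buf) == np.eye(2)).all()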
imaginaire/utils/easy_io/handlers/pandas_handler.py ADDED
@@ -0,0 +1,31 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import pandas as pd
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler  # isort:skip
+
+
+ class PandasHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file, **kwargs):
+         return pd.read_csv(file, **kwargs)
+
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         obj.to_csv(file, **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         raise NotImplementedError("PandasHandler does not support dumping to str")
imaginaire/utils/easy_io/handlers/pickle_handler.py ADDED
@@ -0,0 +1,42 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import pickle
+ from io import BytesIO
+ from typing import Any
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class PickleHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file: BytesIO, **kwargs):
+         return pickle.load(file, **kwargs)
+
+     def load_from_path(self, filepath, **kwargs):
+         return super().load_from_path(filepath, mode="rb", **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         kwargs.setdefault("protocol", 2)
+         return pickle.dumps(obj, **kwargs)
+
+     def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
+         kwargs.setdefault("protocol", 2)
+         pickle.dump(obj, file, **kwargs)
+
+     def dump_to_path(self, obj, filepath, **kwargs):
+         with open(filepath, "wb") as f:
+             pickle.dump(obj, f, **kwargs)
imaginaire/utils/easy_io/handlers/pil_handler.py ADDED
@@ -0,0 +1,96 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import IO
+
+ import numpy as np
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+ try:
+     from PIL import Image
+ except ImportError:
+     Image = None
+
+
+ class PILHandler(BaseFileHandler):
+     format: str
+     str_like = False
+
+     def load_from_fileobj(
+         self,
+         file: IO[bytes],
+         fmt: str = "pil",
+         size: int | tuple[int, int] | None = None,
+         **kwargs,
+     ):
+         """
+         Load an image from a file-like object and return it in a specified format.
+
+         Args:
+             file (IO[bytes]): A file-like object containing the image data.
+             fmt (str): The format to convert the image into. Options are \
+                 'numpy', 'np', 'npy' (all return numpy arrays), \
+                 'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor).
+             size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \
+                 or a tuple of (width, height). If specified, the image is resized accordingly.
+             **kwargs: Additional keyword arguments that can be passed to conversion functions.
+
+         Returns:
+             Image data in the format specified by `fmt`.
+
+         Raises:
+             IOError: If the image cannot be loaded or processed.
+             ValueError: If the specified format is unsupported.
+         """
+         try:
+             img = Image.open(file)
+             img.load()  # Explicitly load the image data
+             if size is not None:
+                 if isinstance(size, int):
+                     size = (
+                         size,
+                         size,
+                     )  # create a tuple if only one integer is provided
+                 # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
+                 img = img.resize(size, Image.LANCZOS)
+
+             # Return the image in the requested format
+             if fmt in ["numpy", "np", "npy"]:
+                 return np.array(img, **kwargs)
+             if fmt == "pil":
+                 return img
+             if fmt in ["th", "torch"]:
+                 import torch
+
+                 # Convert to tensor
+                 img_tensor = torch.from_numpy(np.array(img, **kwargs))
+                 # Convert image from HxWxC to CxHxW
+                 if img_tensor.ndim == 3:
+                     img_tensor = img_tensor.permute(2, 0, 1)
+                 return img_tensor
+             raise ValueError(
+                 "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'."
+             )
+         except Exception as e:
+             raise OSError(f"Unable to load image: {e}") from e
+
+     def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs):
+         if "format" not in kwargs:
+             # PIL expects "JPEG" rather than "JPG" and upper-case format names;
+             # only fill in the default so an explicit caller-provided format is honored.
+             kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper()
+         obj.save(file, **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         raise NotImplementedError
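A loading sketch (hypothetical path); `fmt='torch'` yields a CxHxW uint8 tensor:

handler = PILHandler()
handler.format = "png"  # normally set by the registry below
with open("example.png", "rb") as f:  # hypothetical file
    tensor = handler.load_from_fileobj(f, fmt="torch", size=224)  # resized to 224x224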
imaginaire/utils/easy_io/handlers/registry_utils.py ADDED
@@ -0,0 +1,82 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+ from imaginaire.utils.easy_io.handlers.byte_handler import ByteHandler
+ from imaginaire.utils.easy_io.handlers.csv_handler import CsvHandler
+ from imaginaire.utils.easy_io.handlers.gzip_handler import GzipHandler
+ from imaginaire.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler
+ from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
+ from imaginaire.utils.easy_io.handlers.jsonl_handler import JsonlHandler
+ from imaginaire.utils.easy_io.handlers.np_handler import NumpyHandler
+ from imaginaire.utils.easy_io.handlers.pandas_handler import PandasHandler
+ from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
+ from imaginaire.utils.easy_io.handlers.pil_handler import PILHandler
+ from imaginaire.utils.easy_io.handlers.tarfile_handler import TarHandler
+ from imaginaire.utils.easy_io.handlers.torch_handler import TorchHandler
+ from imaginaire.utils.easy_io.handlers.torchjit_handler import TorchJitHandler
+ from imaginaire.utils.easy_io.handlers.txt_handler import TxtHandler
+ from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler
+
+ file_handlers = {
+     "json": JsonHandler(),
+     "yaml": YamlHandler(),
+     "yml": YamlHandler(),
+     "pickle": PickleHandler(),
+     "pkl": PickleHandler(),
+     "tar": TarHandler(),
+     "jit": TorchJitHandler(),
+     "npy": NumpyHandler(),
+     "txt": TxtHandler(),
+     "csv": CsvHandler(),
+     "pandas": PandasHandler(),
+     "gz": GzipHandler(),
+     "jsonl": JsonlHandler(),
+     "byte": ByteHandler(),
+ }
+
+ for torch_type in ["pt", "pth", "ckpt"]:
+     file_handlers[torch_type] = TorchHandler()
+ for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]:
+     file_handlers[img_type] = PILHandler()
+     file_handlers[img_type].format = img_type
+ for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]:
+     file_handlers[video_type] = ImageioVideoHandler()
+
+
+ def _register_handler(handler, file_formats):
+     """Register a handler for some file extensions.
+
+     Args:
+         handler (:obj:`BaseFileHandler`): Handler to be registered.
+         file_formats (str or list[str]): File formats to be handled by this
+             handler.
+     """
+     if not isinstance(handler, BaseFileHandler):
+         raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}")
+     if isinstance(file_formats, str):
+         file_formats = [file_formats]
+     if not all(isinstance(item, str) for item in file_formats):
+         raise TypeError("file_formats must be a str or a list of str")
+     for ext in file_formats:
+         file_handlers[ext] = handler
+
+
+ def register_handler(file_formats, **kwargs):
+     def wrap(cls):
+         _register_handler(cls(**kwargs), file_formats)
+         return cls
+
+     return wrap
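A decorator sketch for wiring a new extension into `file_handlers`; `TomlHandler` is hypothetical, and `tomllib` requires Python 3.11+:

import tomllib  # stdlib on Python >= 3.11

@register_handler("toml")
class TomlHandler(BaseFileHandler):
    def load_from_fileobj(self, file, **kwargs):
        return tomllib.loads(file.read())  # str_like handlers receive text

    def dump_to_fileobj(self, obj, file, **kwargs):
        raise NotImplementedError

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError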
imaginaire/utils/easy_io/handlers/tarfile_handler.py ADDED
@@ -0,0 +1,39 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import tarfile
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class TarHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file, mode="r|*", **kwargs):
+         return tarfile.open(fileobj=file, mode=mode, **kwargs)
+
+     def load_from_path(self, filepath, mode="r|*", **kwargs):
+         return tarfile.open(filepath, mode=mode, **kwargs)
+
+     def dump_to_fileobj(self, obj, file, mode="w", **kwargs):
+         with tarfile.open(fileobj=file, mode=mode) as tar:
+             tar.add(obj, **kwargs)
+
+     def dump_to_path(self, obj, filepath, mode="w", **kwargs):
+         with tarfile.open(filepath, mode=mode) as tar:
+             tar.add(obj, **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         raise NotImplementedError
imaginaire/utils/easy_io/handlers/torch_handler.py ADDED
@@ -0,0 +1,34 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ try:
+     import torch
+ except ImportError:
+     torch = None
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class TorchHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file, **kwargs):
+         return torch.load(file, **kwargs)
+
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         torch.save(obj, file, **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         raise NotImplementedError
imaginaire/utils/easy_io/handlers/torchjit_handler.py ADDED
@@ -0,0 +1,34 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ try:
+     import torch
+ except ImportError:
+     torch = None
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class TorchJitHandler(BaseFileHandler):
+     str_like = False
+
+     def load_from_fileobj(self, file, **kwargs):
+         return torch.jit.load(file, **kwargs)
+
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         torch.jit.save(obj, file, **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         raise NotImplementedError
imaginaire/utils/easy_io/handlers/txt_handler.py ADDED
@@ -0,0 +1,34 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+ class TxtHandler(BaseFileHandler):
+     def load_from_fileobj(self, file, **kwargs):
+         del kwargs
+         return file.read()
+
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         del kwargs
+         if not isinstance(obj, str):
+             obj = str(obj)
+         file.write(obj)
+
+     def dump_to_str(self, obj, **kwargs):
+         del kwargs
+         if not isinstance(obj, str):
+             obj = str(obj)
+         return obj
imaginaire/utils/easy_io/handlers/yaml_handler.py ADDED
@@ -0,0 +1,38 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import yaml
+
+ try:
+     from yaml import CDumper as Dumper  # type: ignore
+     from yaml import CLoader as Loader  # type: ignore
+ except ImportError:
+     from yaml import Dumper, Loader  # type: ignore
+
+ from imaginaire.utils.easy_io.handlers.base import BaseFileHandler  # isort:skip
+
+
+ class YamlHandler(BaseFileHandler):
+     def load_from_fileobj(self, file, **kwargs):
+         kwargs.setdefault("Loader", Loader)
+         return yaml.load(file, **kwargs)
+
+     def dump_to_fileobj(self, obj, file, **kwargs):
+         kwargs.setdefault("Dumper", Dumper)
+         yaml.dump(obj, file, **kwargs)
+
+     def dump_to_str(self, obj, **kwargs):
+         kwargs.setdefault("Dumper", Dumper)
+         return yaml.dump(obj, **kwargs)
imaginaire/utils/ema.py ADDED
@@ -0,0 +1,315 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ try:
24
+ from megatron.core import parallel_state
25
+
26
+ USE_MEGATRON = True
27
+ except ImportError:
28
+ USE_MEGATRON = False
29
+
30
+ from imaginaire.utils import distributed, log
31
+
32
+ if TYPE_CHECKING:
33
+ from imaginaire.model import ImaginaireModel
34
+
35
+
36
+ class FastEmaModelUpdater:
37
+ """
38
+ This class is used to update target model~(EMA) given source model~(regular model) and beta.
39
+ The method interaface mimic :class:`EMAModelTracker` and :class:`PowerEMATracker`.
40
+ Different from two classes, this class does not maintain the EMA model weights as buffers. It expects the user to have two module with same architecture and weights shape.
41
+ The class is proposed to work with FSDP model where above two classes are not working as expected. Besides, it is strange to claim model weights as buffers and do unnecessary name changing in :class:`EMAModelTracker` and :class:`PowerEMATracker`. Moeving forward, we should use this class instead of above two classes.
42
+ """
43
+
44
+ def __init__(self):
45
+ # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite
46
+ self.is_cached = False
47
+
48
+ @torch.no_grad()
49
+ def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None:
50
+ for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters(), strict=False):
51
+ tgt_params.data.copy_(src_params.data)
52
+
53
+ @torch.no_grad()
54
+ def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None:
55
+ target_list = []
56
+ source_list = []
57
+ for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters(), strict=False):
58
+ assert tgt_params.dtype == torch.float32, (
59
+ f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead."
60
+ )
61
+ target_list.append(tgt_params)
62
+ source_list.append(src_params.data)
63
+ torch._foreach_mul_(target_list, beta)
64
+ torch._foreach_add_(target_list, source_list, alpha=1.0 - beta)
65
+
66
+ @torch.no_grad()
67
+ def cache(self, parameters: Any, is_cpu: bool = False) -> None:
68
+ """Save the current parameters for restoring later.
69
+
70
+ Args:
71
+ parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored.
72
+ """
73
+ assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?"
74
+ device = "cpu" if is_cpu else "cuda"
75
+ self.collected_params = [param.clone().to(device) for param in parameters]
76
+ self.is_cached = True
77
+
78
+ @torch.no_grad()
79
+ def restore(self, parameters: Any) -> None:
80
+ """Restore the parameters in self.collected_params.
81
+
82
+ Useful to validate the model with EMA parameters without affecting the
83
+ original optimization process. Store the parameters before copy_to().
84
+ After validation (or model saving), use this to restore the former parameters.
85
+
86
+ Args:
87
+ parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters.
88
+ """
89
+ assert self.is_cached, "EMA cache is not taken yet."
90
+ for c_param, param in zip(self.collected_params, parameters, strict=False):
91
+ param.data.copy_(c_param.data.type_as(param.data))
92
+ self.collected_params = []
93
+ # Release the cache after we call restore
94
+ self.is_cached = False
95
+
96
+
97
+ def get_buffer_name(param_name: str, torch_compile_buffer_renaming: bool = False) -> str:
98
+ """
99
+ This function creates buffer name used by EMA from parameter's name
100
+
101
+ Args:
102
+ param_name (str): Model's parameter name
103
+ Returns:
104
+ buffer_name (str): buffer name to be used for given parameter name
105
+ """
106
+
107
+ buffer_name = param_name.replace(".", "-")
108
+
109
+ if torch_compile_buffer_renaming:
110
+ # torch.compile() adds _orig_mod to state dict names, this way we get original name
111
+ buffer_name = buffer_name.replace("_orig_mod-", "")
112
+
113
+ return buffer_name
114
+
115
+
116
+ class EMAModelTracker(torch.nn.Module):
117
+ """This is a class to track the EMA model weights.
118
+
119
+ The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the
120
+ regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's
121
+ implementation of EMA. There are no optimizable parameters.
122
+
123
+ Attributes:
124
+ collected_params (list): temporarily stores the regular weights while in EMA mode.
125
+ beta (float): EMA decay rate. (default: 0.9999).
126
+ torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used
127
+ """
128
+
129
+ def __init__(self, model: ImaginaireModel, beta: float = 0.9999, torch_compile_buffer_renaming: bool = False):
130
+ """Constructor of the EMA model weight tracker.
131
+
132
+ Args:
133
+ model (ImaginaireModel): The PyTorch model.
134
+ beta (float): EMA decay rate. (default: 0.9999).
135
+ """
136
+ super().__init__()
137
+ self.torch_compile_buffer_renaming: bool = torch_compile_buffer_renaming
138
+ if not 0.0 <= beta <= 1.0:
139
+ raise ValueError("Decay must be between 0 and 1")
140
+ self.beta = beta
141
+ for name, param in model.named_parameters():
142
+ if param.requires_grad:
143
+ buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming)
144
+ self.register_buffer(buffer_name, param.clone().detach().data)
145
+ self.collected_params = []
146
+ # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite
147
+ self.is_cached = False
148
+
149
+ @torch.no_grad()
150
+ def update_average(self, model: ImaginaireModel, iteration: int | None = None) -> None:
151
+ del iteration
152
+ target_list = []
153
+ source_list = []
154
+ ema_buffers = self.state_dict()
155
+ for name, param in model.named_parameters():
156
+ if param.requires_grad:
157
+ buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming)
158
+ buffer = ema_buffers[buffer_name]
159
+ assert buffer.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {buffer.dtype} instead."
160
+ target_list.append(buffer)
161
+ source_list.append(param.data)
162
+ torch._foreach_mul_(target_list, self.beta)
163
+ torch._foreach_add_(target_list, source_list, alpha=1.0 - self.beta)
164
+
165
+ def copy_to(self, model: ImaginaireModel) -> None:
166
+ ema_buffers = self.state_dict()
167
+ for name, param in model.named_parameters():
168
+ if param.requires_grad:
169
+ buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming)
170
+ buffer = ema_buffers[buffer_name]
171
+ param.data.copy_(buffer.data)
172
+
173
+ def cache(self, parameters: Any, is_cpu: bool = False) -> None:
174
+ """Save the current parameters for restoring later.
175
+
176
+ Args:
177
+ parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored.
178
+ """
179
+ assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?"
180
+ device = "cpu" if is_cpu else "cuda"
181
+ self.collected_params = [param.clone().to(device) for param in parameters]
182
+ self.is_cached = True
183
+
184
+ def restore(self, parameters: Any) -> None:
185
+ """Restore the parameters in self.collected_params.
186
+
187
+ Useful to validate the model with EMA parameters without affecting the
188
+ original optimization process. Store the parameters before copy_to().
189
+ After validation (or model saving), use this to restore the former parameters.
190
+
191
+ Args:
192
+ parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters.
193
+ """
194
+ assert self.is_cached, "EMA cache is not taken yet."
195
+ for c_param, param in zip(self.collected_params, parameters, strict=False):
196
+ param.data.copy_(c_param.data.type_as(param.data))
197
+ self.collected_params = []
198
+ # Release the cache after we call restore
199
+ self.is_cached = False
200
+
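Taken together, `cache()`, `copy_to()`, and `restore()` support the usual evaluate-with-EMA-weights pattern. A hedged usage sketch; `ema`, `model`, and `validate` are assumed to exist:

```python
# Hypothetical usage: validate with EMA weights, then put the training weights back.
ema.cache(model.parameters())        # stash the current (regular) weights
ema.copy_to(model)                   # load the EMA weights into the model
validate(model)                      # e.g. run the eval loop
ema.restore(model.parameters())      # restore the regular weights; clears the cache
```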
201
+ @classmethod
202
+ def initialize_multi_rank_ema(
203
+ cls, model: torch.nn.Module, rate: float | list[float], num: int = 1, enabled: bool = True
204
+ ) -> EMAModelTracker | None:
205
+ """
206
+ Class method to initialize a per-rank EMA model tracker with a different rate.
207
+ Each rank will have a different rate based on the given configuration, resulting in different EMA weights.
208
+
209
+ Args:
210
+ model (torch.nn.Module): The neural network model to be tracked.
211
+ rate (Union[float, List[float]]): The decay rate(s) for the EMA. If a list is provided,
212
+ it corresponds to rates for different ranks.
213
+ num (int, optional): The number of leading ranks to consider for different rates.
214
+ Defaults to 1.
215
+ enabled (bool, optional): Flag to enable or disable the creation of the tracker.
216
+ If False, returns None. Defaults to True.
217
+
218
+ Returns:
219
+ Optional[EMAModelTracker]: An instance of EMAModelTracker if enabled, otherwise None.
220
+
221
+ Example:
222
+ >>> model = torch.nn.Linear(10, 2)
223
+ >>> tracker = EMAModelTracker.initialize_multi_rank_ema(model, rate=[0.1, 0.2], num=2)
224
+ >>> print(tracker)
225
+
226
+ Notes:
227
+ If `rate` is a list and the current rank is less than `num`, the rate for the current rank
228
+ is used. If the current rank exceeds `num`, the first rate in the list is used by default.
229
+ """
230
+ if not enabled:
231
+ return None
232
+ if USE_MEGATRON and parallel_state.is_initialized():
233
+ cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True)
234
+ log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
235
+ log.warning("It should not used together with FSDP!")
236
+ else:
237
+ cur_dp_rank = distributed.get_rank()
238
+ log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
239
+ rate = rate if isinstance(rate, list) else [rate]
240
+ num = min(num, len(rate))
241
+ rate = rate[cur_dp_rank] if cur_dp_rank < num else rate[0]
242
+ if cur_dp_rank < num:
243
+ print(f"EMAModelTracker: rank {cur_dp_rank}, rate {rate}")
244
+ return cls(model, rate)
245
+
246
+
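To make the rate selection above concrete, a small sketch of which rate each data-parallel rank would receive under the docstring's example settings (values are illustrative):

```python
# Illustrative rate assignment for rate=[0.1, 0.2], num=2 across four DP ranks.
rates, num = [0.1, 0.2], 2
for cur_dp_rank in range(4):
    rate = rates[cur_dp_rank] if cur_dp_rank < num else rates[0]
    print(cur_dp_rank, rate)   # ranks 0,1 -> 0.1, 0.2; ranks 2,3 fall back to 0.1
```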
247
+ class PowerEMATracker(EMAModelTracker):
248
+ def __init__(self, model: ImaginaireModel, s: float = 0.1, torch_compile_buffer_renaming: bool = False):
249
+ """Constructor of the EMA model weight tracker.
250
+
251
+ Args:
252
+ model (ImaginaireModel): The PyTorch model.
253
+ s (float): EMA decay rate; see the EDM2 paper.
254
+ torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used
255
+ """
256
+ super().__init__(model=model, beta=0.0, torch_compile_buffer_renaming=torch_compile_buffer_renaming)
257
+ self.exp = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max()
258
+
259
+ @torch.no_grad()
260
+ def update_average(self, model: ImaginaireModel, iteration: int | None = None) -> None:
261
+ if iteration == 0:
262
+ beta = 0.0
263
+ else:
264
+ i = iteration + 1
265
+ beta = (1 - 1 / i) ** (self.exp + 1)
266
+ self.beta = beta
267
+
268
+ super().update_average(model, iteration)
269
+
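This schedule matches the power-function EMA from the EDM2 paper: `self.exp` is the largest real root gamma of the cubic `x**3 + 7*x**2 + (16 - s**-2)*x + (12 - s**-2)`, and the per-step decay is `beta_i = (1 - 1/i) ** (gamma + 1)` with `i = iteration + 1`, so beta grows toward 1 as training progresses. A minimal sketch of the schedule:

```python
import numpy as np

# Power-function EMA schedule (EDM2): beta starts at 0 and approaches 1 over training.
s = 0.1
gamma = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max()
for iteration in [0, 1, 10, 1000]:
    beta = 0.0 if iteration == 0 else (1 - 1 / (iteration + 1)) ** (gamma + 1)
    print(iteration, round(float(beta), 6))
```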
270
+ @classmethod
271
+ def initialize_multi_rank_ema(
272
+ cls, model: torch.nn.Module, rate: float, num: int, enabled: bool = True
273
+ ) -> PowerEMATracker | None:
274
+ """
275
+ Class method to initialize a per-rank EMA model tracker with a different rate.
276
+ Each rank will have a different rate based on the given configuration, resulting in different EMA weights.
277
+
278
+ Args:
279
+ model (torch.nn.Module): The neural network model for which the EMA tracker is being set up.
280
+ num (int): The number of ranks for which the rate adjustment is applied. Beyond this, the rate remains unchanged.
281
+ rate (float): The base decay rate for the EMA calculation.
282
+ enabled (bool, optional): Flag to enable or disable the initialization of the tracker. If False, returns None.
283
+ Defaults to True.
284
+
285
+ Returns:
286
+ Optional[PowerEMATracker]: An instance of PowerEMATracker with adjusted rate if enabled, otherwise None.
287
+
288
+ Raises:
289
+ None
290
+
291
+ Example:
292
+ >>> model = torch.nn.Linear(10, 2)
293
+ >>> tracker = PowerEMATracker.initialize_multi_rank_ema(model, num=3, rate=0.99)
294
+ >>> print(tracker)
295
+
296
+ Notes:
297
+ The decay rate is modified by dividing it by 2 raised to the power of the rank for each rank less than `num`.
298
+ If the rank is greater than or equal to `num`, the base rate is used without modification. This approach
299
+ allows higher-ranked processes to use a less aggressive decay, potentially reflecting their delayed synchronization
300
+ in a distributed training scenario.
301
+ """
302
+ if not enabled:
303
+ return None
304
+ if USE_MEGATRON and parallel_state.is_initialized():
305
+ cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True)
306
+ log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
307
+ log.warning("It should not used together with FSDP!")
308
+ else:
309
+ cur_dp_rank = distributed.get_rank()
310
+ log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
311
+
312
+ divider = 2**cur_dp_rank if cur_dp_rank < num else 1
313
+ if cur_dp_rank < num:
314
+ print(f"PowerEMATracker: rank {cur_dp_rank}, rate {rate / divider}")
315
+ return cls(model, rate / divider)
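Unlike the base class's list-indexed rates, this variant halves a single base rate once per rank; an illustrative sketch of the assignment:

```python
# Illustrative per-rank rate adjustment for rate=0.99, num=3 across four DP ranks.
rate, num = 0.99, 3
for cur_dp_rank in range(4):
    divider = 2**cur_dp_rank if cur_dp_rank < num else 1
    print(cur_dp_rank, rate / divider)   # 0.99, 0.495, 0.2475, then 0.99 unchanged
```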
imaginaire/utils/fused_adam.py ADDED
@@ -0,0 +1,398 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from apex.multi_tensor_apply import multi_tensor_applier
18
+
19
+ from imaginaire.utils import distributed, log
20
+
21
+
22
+ class FusedAdam(torch.optim.Optimizer):
23
+ """Implements Adam algorithm.
24
+
25
+ Currently GPU-only. Requires Apex to be installed via
26
+ ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
27
+
28
+ This version of fused Adam implements 2 fusions.
29
+
30
+ * Fusion of the Adam update's elementwise operations
31
+ * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters
32
+ into one or a few kernel launches.
33
+
34
+ :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
35
+ or ``torch.optim.Adam`` with ``adam_w_mode=False``::
36
+
37
+ opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
38
+ ...
39
+ opt.step()
40
+
41
+ :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp,
42
+ you may choose any ``opt_level``::
43
+
44
+ opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
45
+ model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
46
+ ...
47
+ opt.step()
48
+
49
+ In general, ``opt_level="O1"`` is recommended.
50
+
51
+
52
+ .. warning::
53
+ A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``.
54
+ These additional arguments are now deprecated and unnecessary.
55
+
56
+ Adam was proposed in `Adam: A Method for Stochastic Optimization`_.
57
+
58
+ Arguments:
59
+ params (iterable): iterable of parameters to optimize or dicts defining
60
+ parameter groups.
61
+ lr (float, optional): learning rate. (default: 1e-3)
62
+ betas (Tuple[float, float], optional): coefficients used for computing
63
+ running averages of gradient and its square. (default: (0.9, 0.999))
64
+ eps (float, optional): term added to the denominator to improve
65
+ numerical stability. (default: 1e-8)
66
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
67
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
68
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
69
+ (default: False) NOT SUPPORTED in FusedAdam!
70
+ adam_w_mode (boolean, optional): whether to apply decoupled weight decay
71
+ (also known as AdamW) instead of L2 regularization. (default: True)
72
+ capturable (bool, optional): whether to use the version of the optimizer
73
+ that can be used with CUDA Graphs. (default: False)
74
+ master_weights (bool, optional): whether to maintain FP32 master weights
75
+ in the optimizer with FP16 mixed precision training, currently can
76
+ only be used with capturable set to True. (default: False)
77
+
78
+ .. _Adam - A Method for Stochastic Optimization:
79
+ https://arxiv.org/abs/1412.6980
80
+ .. _On the Convergence of Adam and Beyond:
81
+ https://openreview.net/forum?id=ryQu7f-RZ
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ params,
87
+ lr=1e-3,
88
+ bias_correction=True,
89
+ betas=(0.9, 0.999),
90
+ eps=1e-8,
91
+ adam_w_mode=True,
92
+ weight_decay=0.0,
93
+ amsgrad=False,
94
+ capturable=False,
95
+ master_weights=False,
96
+ ):
97
+ if amsgrad:
98
+ raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
99
+ if master_weights and not capturable:
100
+ raise RuntimeError("Master weights is currently only supported with the capturable version.")
101
+ # If the optimizer is capturable then LR should be a tensor (on GPU)
102
+ log.warning(f"FusedAdam master_weights: {master_weights} capturable: {capturable}")
103
+ lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
104
+ defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
105
+ super(FusedAdam, self).__init__(params, defaults) # noqa: UP008
106
+ self.adam_w_mode = 1 if adam_w_mode else 0
107
+
108
+ self.capturable = capturable
109
+ self.master_weights = master_weights
110
+
111
+ self.param_groups_master = None
112
+
113
+ if capturable:
114
+ for idx, group in enumerate(self.param_groups):
115
+ if len(group["params"]) == 0:
116
+ continue
117
+ device = group["params"][0].device
118
+ for item in ["lr"]:
119
+ if isinstance(group[item], float):
120
+ group[item] = torch.tensor(group[item], dtype=torch.float32)
121
+ self.param_groups[idx][item] = group[item].to(device=device)
122
+
123
+ self._step_supports_amp_scaling = True
124
+
125
+ if multi_tensor_applier.available:
126
+ import amp_C
127
+
128
+ # Skip buffer
129
+ self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
130
+ self.multi_tensor_adam = amp_C.multi_tensor_adam
131
+ self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable
132
+ self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master
133
+ else:
134
+ raise RuntimeError("apex.optimizers.FusedAdam requires cuda extensions")
135
+
136
+ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
137
+ """Performs a single optimization step.
138
+
139
+ Arguments:
140
+ closure (callable, optional): A closure that reevaluates the model
141
+ and returns the loss.
142
+
143
+ The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
144
+ """
145
+ if any(p is not None for p in [grads, output_params, scale, grad_norms]):
146
+ raise RuntimeError(
147
+ "FusedAdam has been updated. "
148
+ "Simply initialize it identically to torch.optim.Adam, and call step() with no arguments."
149
+ )
150
+ loss = None
151
+ if closure is not None:
152
+ loss = closure()
153
+
154
+ if self.param_groups_master is None:
155
+ # Create full precision master weights
156
+ self.param_groups_master = []
157
+ for i, pg in enumerate(self.param_groups): # noqa: B007
158
+ param_list = pg["params"]
159
+ self.param_groups_master.append(
160
+ {
161
+ "params": [p.clone().detach().float() if self.master_weights else None for p in param_list],
162
+ }
163
+ )
164
+
165
+ for group, group_master in zip(self.param_groups, self.param_groups_master, strict=False):
166
+ if len(group["params"]) == 0:
167
+ continue
168
+ device = group["params"][0].device
169
+ bias_correction = 1 if group.get("bias_correction") else 0
170
+ beta1, beta2 = group["betas"]
171
+
172
+ # Assume the same step across the whole group to keep things simple;
173
+ # per-parameter steps could be supported by making "step" a tensor or passing a list into the kernel.
174
+ if "step" in group:
175
+ if self.capturable:
176
+ group["step"] = (
177
+ group["step"].to(device=device)
178
+ if isinstance(group["step"], torch.Tensor)
179
+ else torch.tensor(group["step"], dtype=torch.int32, device=device)
180
+ )
181
+ group["step"] += (self._dummy_overflow_buf != 1).to(torch.int)
182
+ else:
183
+ group["step"] += 1
184
+ else:
185
+ group["step"] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device)
186
+
187
+ if self.capturable:
188
+ group["lr"] = (
189
+ group["lr"].to(device=device)
190
+ if isinstance(group["lr"], torch.Tensor)
191
+ else torch.tensor(group["lr"], dtype=torch.float32, device=device)
192
+ )
193
+
194
+ # create lists for multi-tensor apply
195
+ g_16, p_16, m_16, v_16 = [], [], [], []
196
+ g_bf, p_bf, m_bf, v_bf = [], [], [], []
197
+ g_32, p_32, m_32, v_32 = [], [], [], []
198
+ p_16_master = []
199
+ p_32_master = []
200
+ bf16_master = []
201
+
202
+ for p, p_master in zip(group["params"], group_master["params"], strict=False):
203
+ if p.grad is None:
204
+ continue
205
+ if p.grad.data.is_sparse:
206
+ raise RuntimeError(
207
+ "FusedAdam does not support sparse gradients, please consider SparseAdam instead"
208
+ )
209
+
210
+ state = self.state[p]
211
+ # State initialization
212
+ if len(state) == 0:
213
+ # Exponential moving average of gradient values
214
+ state["exp_avg"] = torch.zeros_like(p.data).float()
215
+ # Exponential moving average of squared gradient values
216
+ state["exp_avg_sq"] = torch.zeros_like(p.data).float()
217
+
218
+ if p.dtype == torch.float16:
219
+ if self.master_weights:
220
+ p_16_master.append(p_master.data)
221
+ g_16.append(p.grad.data)
222
+ p_16.append(p.data)
223
+ m_16.append(state["exp_avg"])
224
+ v_16.append(state["exp_avg_sq"])
225
+ elif p.dtype == torch.bfloat16:
226
+ if self.master_weights:
227
+ bf16_master.append(p_master.data)
228
+ g_bf.append(p.grad)
229
+ p_bf.append(p)
230
+ m_bf.append(state["exp_avg"])
231
+ v_bf.append(state["exp_avg_sq"])
232
+ elif p.dtype == torch.float32:
233
+ if self.master_weights:
234
+ p_32_master.append(p_master.data)
235
+ g_32.append(p.grad.data)
236
+ p_32.append(p.data)
237
+ m_32.append(state["exp_avg"])
238
+ v_32.append(state["exp_avg_sq"])
239
+ else:
240
+ raise RuntimeError("FusedAdam only support fp16 and fp32.")
241
+
242
+ # In the capturable case the grad scaler (if any) operates on the GPU,
243
+ # and a different multi_tensor_applier kernel must be called.
244
+ if self.capturable:
245
+ # overflow check of gradients
246
+ found_inf = (
247
+ grad_scaler._check_inf_per_device(self)[device]
248
+ if grad_scaler is not None
249
+ else torch.zeros((1,), device=device)
250
+ )
251
+ self._dummy_overflow_buf.copy_(found_inf)
252
+
253
+ # get unscale scale factor
254
+ scale, inv_scale = None, None
255
+ if grad_scaler:
256
+ scale = grad_scaler._get_scale_async()
257
+ inv_scale = scale.double().reciprocal().float()
258
+ else:
259
+ scale = torch.ones((1,), device=device, dtype=torch.float32)
260
+ inv_scale = torch.ones((1,), device=device, dtype=torch.float32)
261
+
262
+ if len(g_16) > 0:
263
+ multi_tensor_applier(
264
+ (
265
+ self.multi_tensor_adam_capturable_master
266
+ if self.master_weights
267
+ else self.multi_tensor_adam_capturable
268
+ ),
269
+ self._dummy_overflow_buf,
270
+ [g_16, p_16, m_16, v_16, p_16_master] if self.master_weights else [g_16, p_16, m_16, v_16],
271
+ group["lr"],
272
+ beta1,
273
+ beta2,
274
+ group["eps"],
275
+ group["step"],
276
+ self.adam_w_mode,
277
+ bias_correction,
278
+ group["weight_decay"],
279
+ inv_scale,
280
+ )
281
+
282
+ if len(g_bf) > 0:
283
+ multi_tensor_applier(
284
+ (
285
+ self.multi_tensor_adam_capturable_master
286
+ if self.master_weights
287
+ else self.multi_tensor_adam_capturable
288
+ ),
289
+ self._dummy_overflow_buf,
290
+ [g_bf, p_bf, m_bf, v_bf, bf16_master] if self.master_weights else [g_bf, p_bf, m_bf, v_bf],
291
+ group["lr"],
292
+ beta1,
293
+ beta2,
294
+ group["eps"],
295
+ group["step"],
296
+ self.adam_w_mode,
297
+ bias_correction,
298
+ group["weight_decay"],
299
+ inv_scale,
300
+ )
301
+
302
+ if len(g_32) > 0:
303
+ multi_tensor_applier(
304
+ (
305
+ self.multi_tensor_adam_capturable_master
306
+ if self.master_weights
307
+ else self.multi_tensor_adam_capturable
308
+ ),
309
+ self._dummy_overflow_buf,
310
+ [g_32, p_32, m_32, v_32, p_32_master] if self.master_weights else [g_32, p_32, m_32, v_32],
311
+ group["lr"],
312
+ beta1,
313
+ beta2,
314
+ group["eps"],
315
+ group["step"],
316
+ self.adam_w_mode,
317
+ bias_correction,
318
+ group["weight_decay"],
319
+ inv_scale,
320
+ )
321
+ else:
322
+ if len(g_16) > 0:
323
+ multi_tensor_applier(
324
+ self.multi_tensor_adam,
325
+ self._dummy_overflow_buf,
326
+ [g_16, p_16, m_16, v_16],
327
+ group["lr"],
328
+ beta1,
329
+ beta2,
330
+ group["eps"],
331
+ group["step"],
332
+ self.adam_w_mode,
333
+ bias_correction,
334
+ group["weight_decay"],
335
+ )
336
+
337
+ if len(g_bf) > 0:
338
+ multi_tensor_applier(
339
+ self.multi_tensor_adam,
340
+ self._dummy_overflow_buf,
341
+ [g_bf, p_bf, m_bf, v_bf],
342
+ group["lr"],
343
+ beta1,
344
+ beta2,
345
+ group["eps"],
346
+ group["step"],
347
+ self.adam_w_mode,
348
+ bias_correction,
349
+ group["weight_decay"],
350
+ )
351
+
352
+ if len(g_32) > 0:
353
+ multi_tensor_applier(
354
+ self.multi_tensor_adam,
355
+ self._dummy_overflow_buf,
356
+ [g_32, p_32, m_32, v_32],
357
+ group["lr"],
358
+ beta1,
359
+ beta2,
360
+ group["eps"],
361
+ group["step"],
362
+ self.adam_w_mode,
363
+ bias_correction,
364
+ group["weight_decay"],
365
+ )
366
+
367
+ return loss
368
+
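For orientation, the fused kernels batch the standard Adam/AdamW elementwise update across all parameters into one or a few launches. A minimal unfused reference of one AdamW-style step (`adam_w_mode=True`, bias correction enabled) on illustrative tensors; the fused path computes the same arithmetic, though its exact operation ordering may differ:

```python
import torch

# Unfused reference for one AdamW-style step (decoupled weight decay).
lr, beta1, beta2, eps, weight_decay, step = 1e-3, 0.9, 0.999, 1e-8, 0.01, 1
p = torch.randn(4)                     # parameter
g = torch.randn(4)                     # gradient
m = torch.zeros(4)                     # exp_avg
v = torch.zeros(4)                     # exp_avg_sq

m.mul_(beta1).add_(g, alpha=1 - beta1)
v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
m_hat = m / (1 - beta1**step)          # bias correction
v_hat = v / (1 - beta2**step)
p.mul_(1 - lr * weight_decay)          # decoupled weight decay (AdamW)
p.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)
```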
369
+ def load_state_dict(self, state_dict):
370
+ super().load_state_dict(state_dict)
371
+ for group in self.param_groups:
372
+ if self.capturable:
373
+ group["lr"] = (
374
+ group["lr"].cuda()
375
+ if isinstance(group["lr"], torch.Tensor)
376
+ else torch.tensor(group["lr"], dtype=torch.float32).cuda()
377
+ )
378
+
379
+ if "step" in group:
380
+ if self.capturable:
381
+ if distributed.get_rank() == 0:
382
+ step = (
383
+ group["step"].cuda()
384
+ if isinstance(group["step"], torch.Tensor)
385
+ else torch.tensor([group["step"]], dtype=torch.int32).cuda()
386
+ )
387
+ else:
388
+ step = torch.zeros(1, dtype=torch.int32).cuda()
389
+ # make it compatible with FSDP optimizer
390
+ distributed.broadcast(step, 0)
391
+ group["step"] = step
392
+ elif isinstance(group["step"], torch.Tensor):
393
+ group["step"] = group["step"].item()
394
+ for p in group["params"]:
395
+ state = self.state[p]
396
+ if "exp_avg" in state:
397
+ state["exp_avg"] = state["exp_avg"].float()
398
+ state["exp_avg_sq"] = state["exp_avg_sq"].float()