Commit adecc3c · init
Parent: 4d42c48

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitignore +15 -0
- README.md +5 -5
- app.py +175 -0
- imaginaire/__init__.py +14 -0
- imaginaire/callbacks/__init__.py +14 -0
- imaginaire/callbacks/every_n.py +84 -0
- imaginaire/callbacks/manual_gc.py +49 -0
- imaginaire/config.py +410 -0
- imaginaire/lazy_config/__init__.py +73 -0
- imaginaire/lazy_config/file_io.py +24 -0
- imaginaire/lazy_config/instantiate.py +119 -0
- imaginaire/lazy_config/lazy.py +442 -0
- imaginaire/lazy_config/omegaconf_patch.py +65 -0
- imaginaire/lazy_config/registry.py +74 -0
- imaginaire/model.py +137 -0
- imaginaire/trainer.py +322 -0
- imaginaire/utils/__init__.py +14 -0
- imaginaire/utils/callback.py +518 -0
- imaginaire/utils/checkpointer.py +282 -0
- imaginaire/utils/config_helper.py +201 -0
- imaginaire/utils/device.py +39 -0
- imaginaire/utils/distributed.py +444 -0
- imaginaire/utils/easy_io/__init__.py +14 -0
- imaginaire/utils/easy_io/backends/__init__.py +28 -0
- imaginaire/utils/easy_io/backends/base_backend.py +60 -0
- imaginaire/utils/easy_io/backends/http_backend.py +91 -0
- imaginaire/utils/easy_io/backends/local_backend.py +551 -0
- imaginaire/utils/easy_io/backends/registry_utils.py +125 -0
- imaginaire/utils/easy_io/easy_io.py +1034 -0
- imaginaire/utils/easy_io/file_client.py +448 -0
- imaginaire/utils/easy_io/handlers/__init__.py +29 -0
- imaginaire/utils/easy_io/handlers/base.py +44 -0
- imaginaire/utils/easy_io/handlers/byte_handler.py +39 -0
- imaginaire/utils/easy_io/handlers/csv_handler.py +42 -0
- imaginaire/utils/easy_io/handlers/gzip_handler.py +33 -0
- imaginaire/utils/easy_io/handlers/imageio_video_handler.py +168 -0
- imaginaire/utils/easy_io/handlers/json_handler.py +49 -0
- imaginaire/utils/easy_io/handlers/jsonl_handler.py +80 -0
- imaginaire/utils/easy_io/handlers/np_handler.py +89 -0
- imaginaire/utils/easy_io/handlers/pandas_handler.py +31 -0
- imaginaire/utils/easy_io/handlers/pickle_handler.py +42 -0
- imaginaire/utils/easy_io/handlers/pil_handler.py +96 -0
- imaginaire/utils/easy_io/handlers/registry_utils.py +82 -0
- imaginaire/utils/easy_io/handlers/tarfile_handler.py +39 -0
- imaginaire/utils/easy_io/handlers/torch_handler.py +34 -0
- imaginaire/utils/easy_io/handlers/torchjit_handler.py +34 -0
- imaginaire/utils/easy_io/handlers/txt_handler.py +34 -0
- imaginaire/utils/easy_io/handlers/yaml_handler.py +38 -0
- imaginaire/utils/ema.py +315 -0
- imaginaire/utils/fused_adam.py +398 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
+# Python cache and build files
+__pycache__/
+*.py[cod]
+*.pyo
+*.pyd
+
+# Virtual environments
+.venv/
+venv/
+env/
+
+# IDE and misc
+.idea/
+.vscode/
+.DS_Store
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: rCM-Wan 720p
+emoji: 🐠
+colorFrom: green
+colorTo: gray
 sdk: gradio
 sdk_version: 5.49.1
 app_file: app.py
@@ -11,4 +11,4 @@ license: apache-2.0
 short_description: rCM model for Wan2.1
 ---
 
-
+This demo uses the unofficial rCM models for Wan from worstcoder/rcm-Wan.
app.py
ADDED
@@ -0,0 +1,175 @@
+import spaces
+import gradio as gr
+import time
+import requests
+from wan2pt1_t2v_rcm_infer import inference, prepare_models
+from huggingface_hub import hf_hub_download
+import random
+from types import SimpleNamespace
+import gc
+import torch
+from imaginaire.lazy_config import LazyCall as L, LazyDict, instantiate
+from wan2pt1_t2v_rcm_infer import load_dit_model, WanModel
+
+
+import flash_attn
+print("flash_attn version: ", flash_attn.__version__)
+
+WAN2PT1_1PT3B_T2V: LazyDict = L(WanModel)(
+    dim=1536,
+    eps=1e-06,
+    ffn_dim=8960,
+    freq_dim=256,
+    in_dim=16,
+    model_type="t2v",
+    num_heads=12,
+    num_layers=30,
+    out_dim=16,
+    text_len=512,
+)
+
+WAN2PT1_14B_T2V: LazyDict = L(WanModel)(
+    dim=5120,
+    eps=1e-06,
+    ffn_dim=13824,
+    freq_dim=256,
+    in_dim=16,
+    model_type="t2v",
+    num_heads=40,
+    num_layers=40,
+    out_dim=16,
+    text_len=512,
+)
+
+dit_configs = {"1.3B": WAN2PT1_1PT3B_T2V, "14B": WAN2PT1_14B_T2V}
+
+dit_path_14B_720p = hf_hub_download(
+    repo_id="worstcoder/rcm-Wan",
+    filename="rCM_Wan2.1_T2V_14B_720p.pt",
+)
+
+vae_path = hf_hub_download(
+    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+    filename="Wan2.1_VAE.pth"
+)
+
+text_encoder_path = hf_hub_download(
+    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+    filename="models_t5_umt5-xxl-enc-bf16.pth"
+)
+
+net_14B_720p, tokenizer, t5_encoder = prepare_models(dit_path_14B_720p, vae_path, text_encoder_path)
+print("Loaded models")
+gc.collect()
+
+def random_seed():
+    return random.randint(0, 2**32 - 1)
+
+@spaces.GPU(duration=360)
+def generate_videos(prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed):
+    if seed is None:
+        seed = random.randint(0, 2**32 - 1)
+
+    if "480p" in model_size:
+        resolution = "480p"
+    else:
+        resolution = "720p"
+
+    args = SimpleNamespace(
+        prompt=prompt,
+        model_size=model_size,
+        num_steps=num_steps,
+        num_samples=num_samples,
+        sigma_max=sigma_max,
+        num_frames=77,
+        resolution=resolution,
+        aspect_ratio=aspect_ratio,
+        seed=seed,
+    )
+
+    with torch.no_grad():
+        video_list = inference(args, net_14B_720p, tokenizer, t5_encoder)
+
+    if aspect_ratio == "16:9":
+        return video_list, None
+    else:
+        return None, video_list
+
+def update_num_samples(model_choice):
+    if model_choice == "rCM-Wan2.1-T2V-1.3B-480p":
+        options = [1, 2, 3, 4]
+    elif model_choice == "rCM-Wan2.1-T2V-14B-480p":
+        options = [1, 2, 3]
+    else:
+        options = [1, 2, 3]
+    return gr.Dropdown(choices=options, value=options[0], label="num_samples")
+
+def update_sigma_max(model_choice):
+    if "480p" in model_choice:
+        options = [80, 120, 200, 400, 800, 1600]
+    else:
+        options = [120, 200, 400, 800, 1600]
+    return gr.Dropdown(choices=options, value=options[0], label="sigma_max")
+
+with gr.Blocks() as demo:
+    gr.Markdown("## rCM model for Wan")
+
+    examples = [
+        ["A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."],
+        ["A close-up shot captures a steaming hot pot brimming with vegetables and dumplings, set on a rustic wooden table. The camera focuses on the bubbling broth as a woman, dressed in a light, patterned blouse, reaches in with chopsticks to lift a tender leaf of cabbage from the simmering mixture. Steam rises around her as she leans back slightly, her warm smile reflecting satisfaction and joy. Her movements are smooth and deliberate, showcasing her comfort and familiarity with the dining process. The background includes a small bowl of dipping sauce and a clay pot, adding to the cozy, communal dining atmosphere."],
+        ["A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond."]
+    ]
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            with gr.Row():
+                prompt = gr.Textbox(label="Text prompt", placeholder="Text prompt for videos")
+                model_size = gr.Radio(
+                    ["rCM-Wan2.1-T2V-14B-720p"],
+                    value="rCM-Wan2.1-T2V-14B-720p",
+                    label="Model"
+                )
+
+            with gr.Row():
+                num_samples = gr.Dropdown([1, 2], value=1, label="num_samples")
+                aspect_ratio = gr.Radio(["16:9", "9:16"], value="16:9", label="aspect_ratio")
+                sigma_max = gr.Dropdown([120, 200, 400, 800, 1600], value=120, label="sigma_max")
+
+            with gr.Row():
+                num_steps = gr.Slider(1, 4, value=4, step=1, label="num_steps")
+                seed = gr.Number(label="seed", value=random_seed(), interactive=True)
+
+            with gr.Row():
+                regenerate_btn = gr.Button("New Seed")
+                run_btn = gr.Button("Generate Videos")
+
+            with gr.Row():
+                gr.Examples(
+                    examples,
+                    inputs=[prompt],
+                    label="Example prompts"
+                )
+
+        with gr.Column(scale=1):
+            video_16_9 = gr.Video(label="Videos 16:9", width=832)
+            video_9_16 = gr.Video(label="Videos 9:16", width=480, visible=False)
+
+    def show_video(aspect):
+        if aspect == "16:9":
+            return gr.update(visible=True), gr.update(visible=False, value=None)
+        else:
+            return gr.update(visible=False, value=None), gr.update(visible=True)
+
+    model_size.change(fn=update_num_samples, inputs=model_size, outputs=num_samples)
+    model_size.change(fn=update_sigma_max, inputs=model_size, outputs=sigma_max)
+
+    aspect_ratio.change(show_video, inputs=aspect_ratio, outputs=[video_16_9, video_9_16])
+    regenerate_btn.click(fn=random_seed, outputs=seed)
+
+    run_btn.click(
+        fn=generate_videos,
+        inputs=[prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed],
+        outputs=[video_16_9, video_9_16],
+    )
+
+demo.launch()
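The `WAN2PT1_*` definitions above are lazy configs: `L(WanModel)(...)` records the constructor and its keyword arguments in a `DictConfig` without building anything. A minimal sketch of how such a config becomes a live module, assuming `WanModel` is importable as in app.py (the `num_layers` override is purely illustrative):

    from imaginaire.lazy_config import LazyCall as L, instantiate
    from wan2pt1_t2v_rcm_infer import WanModel

    cfg = L(WanModel)(dim=1536, num_layers=30)  # records the call; builds nothing yet
    cfg.num_layers = 32                         # lazy configs stay editable
    model = instantiate(cfg)                    # WanModel(...) is only called here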
imaginaire/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
imaginaire/callbacks/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
imaginaire/callbacks/every_n.py
ADDED
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+
+import torch
+
+from imaginaire.model import ImaginaireModel
+from imaginaire.trainer import ImaginaireTrainer
+from imaginaire.utils import distributed, log
+from imaginaire.utils.callback import Callback
+
+
+class EveryN(Callback):
+    def __init__(
+        self,
+        every_n: int | None = None,
+        step_size: int = 1,
+        barrier_after_run: bool = True,
+        run_at_start: bool = False,
+    ) -> None:
+        """Constructor for `EveryN`.
+
+        Args:
+            every_n (int): Frequency with which callback is run during training.
+            step_size (int): Size of iteration step count. Default 1.
+            barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts.
+            run_at_start (bool): Whether to run at the beginning of training. Default False.
+        """
+        self.every_n = every_n
+        if self.every_n == 0:
+            log.warning(
+                f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once at the beginning of training; calls from on_training_step_end will be skipped."
+            )
+
+        self.step_size = step_size
+        self.barrier_after_run = barrier_after_run
+        self.run_at_start = run_at_start
+
+    def on_training_step_end(
+        self,
+        model: ImaginaireModel,
+        data_batch: dict[str, torch.Tensor],
+        output_batch: dict[str, torch.Tensor],
+        loss: torch.Tensor,
+        iteration: int = 0,
+    ) -> None:
+        # every_n = 0 is a special case which means every_n_impl will be called only once at the beginning of training
+        if self.every_n != 0:
+            trainer = self.trainer
+            global_step = iteration // self.step_size
+            should_run = (iteration == 1 and self.run_at_start) or (
+                global_step % self.every_n == 0
+            )  # (self.every_n - 1)
+            if should_run:
+                log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}")
+                self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration)
+                log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}")
+                # add necessary barrier to avoid timeout
+                if self.barrier_after_run:
+                    distributed.barrier()
+
+    @abstractmethod
+    def every_n_impl(
+        self,
+        trainer: ImaginaireTrainer,
+        model: ImaginaireModel,
+        data_batch: dict[str, torch.Tensor],
+        output_batch: dict[str, torch.Tensor],
+        loss: torch.Tensor,
+        iteration: int,
+    ) -> None: ...
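For reference, a minimal sketch of how `EveryN` is meant to be subclassed: only `every_n_impl` needs to be implemented, and the base class handles the cadence and the optional distributed barrier. The `LossLogger` name and logging body are illustrative, not part of the repo:

    from imaginaire.callbacks.every_n import EveryN
    from imaginaire.utils import log

    class LossLogger(EveryN):
        # Hypothetical callback: log the training loss every `every_n` steps.
        def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
            log.info(f"iteration {iteration}: loss = {loss.item():.4f}")

    loss_logger = LossLogger(every_n=100)  # fires every 100 training steps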
imaginaire/callbacks/manual_gc.py
ADDED
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+
+from imaginaire.callbacks.every_n import EveryN
+from imaginaire.utils import log
+
+
+class ManualGarbageCollection(EveryN):
+    """
+    Disable automatic gc and manually trigger garbage collection every N iterations.
+    This is very useful for large-scale training, where it reduces GPU sync time
+    and can yield up to a 50% speedup.
+
+    Note that this callback only disables gc in the main process; automatic gc
+    stays enabled in subprocesses.
+
+    We only disable gc after `warm_up` iterations to avoid disabling it in
+    subprocesses, such as dataloader workers, which can cause OOM.
+    """
+
+    def __init__(self, *args, warm_up: int = 5, **kwargs):
+        kwargs["barrier_after_run"] = False
+        super().__init__(*args, **kwargs)
+
+        self.counter = 0
+        self.warm = warm_up
+
+    def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
+        del trainer, model, data_batch, output_batch, loss
+        self.counter += 1
+        if self.counter < self.warm:
+            return
+        if self.counter == self.warm:
+            gc.disable()
+            log.critical("Garbage collection disabled")
+
+        gc.collect(1)
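Note that `gc.collect(1)` only collects generations 0 and 1, so once automatic gc is disabled, full generation-2 collections never run: long-lived cycles are traded for predictable step times. A sketch of registering the callback through the lazy-config system; the `manual_gc` name and the `every_n=1000` cadence are illustrative:

    from imaginaire.callbacks.manual_gc import ManualGarbageCollection
    from imaginaire.lazy_config import LazyCall as L

    # Would sit alongside the default ema/progress_bar entries in TrainerConfig.callbacks.
    manual_gc = L(ManualGarbageCollection)(every_n=1000)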
imaginaire/config.py
ADDED
@@ -0,0 +1,410 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Training config system for Imaginaire4"""
+
+from __future__ import annotations
+
+import os
+from typing import Any, TypeVar
+
+import attrs
+import torch
+import torch.utils.data
+
+from imaginaire.model import ImaginaireModel
+
+try:
+    from megatron.core import ModelParallelConfig
+
+    USE_MEGATRON = True
+except ImportError:
+    USE_MEGATRON = False
+    print("Megatron-core is not installed.")
+
+import builtins
+
+from imaginaire.lazy_config import LazyCall as L
+from imaginaire.lazy_config import LazyDict
+from imaginaire.utils import callback, distributed
+from imaginaire.utils.misc import Color
+
+T = TypeVar("T")
+
+
+def _is_attrs_instance(obj: object) -> bool:
+    """
+    Helper function to check if an object is an instance of an attrs-defined class.
+
+    Args:
+        obj: The object to check.
+
+    Returns:
+        bool: True if the object is an instance of an attrs-defined class, False otherwise.
+    """
+    return hasattr(obj, "__attrs_attrs__")
+
+
+def make_freezable(cls: T) -> T:
+    """
+    A decorator that adds the capability to freeze instances of an attrs-defined class.
+
+    NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
+    to hack on a "_is_frozen" attribute.
+
+    This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
+    Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
+    any attrs-defined objects that are attributes of the class.
+
+    Usage:
+        @make_freezable
+        @attrs.define(slots=False)
+        class MyClass:
+            attribute1: int
+            attribute2: str
+
+        obj = MyClass(1, 'a')
+        obj.freeze()  # Freeze the instance
+        obj.attribute1 = 2  # Raises AttributeError
+
+    Args:
+        cls: The class to be decorated.
+
+    Returns:
+        The decorated class with added freezing capability.
+    """
+
+    if not hasattr(cls, "__dict__"):
+        raise TypeError(
+            "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
+            "class was defined with `@attrs.define(slots=False)`"
+        )
+
+    original_setattr = cls.__setattr__
+
+    def setattr_override(self, key, value) -> None:
+        """
+        Override __setattr__ to allow modifications during initialization
+        and prevent modifications once the instance is frozen.
+        """
+        if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
+            raise AttributeError("Cannot modify frozen instance")
+        original_setattr(self, key, value)  # type: ignore
+
+    cls.__setattr__ = setattr_override  # type: ignore
+
+    def freeze(self: object) -> None:
+        """
+        Freeze the instance and all its attrs-defined attributes.
+        """
+        for _, value in attrs.asdict(self, recurse=False).items():
+            if _is_attrs_instance(value) and hasattr(value, "freeze"):
+                value.freeze()
+        self._is_frozen = True  # type: ignore
+
+    cls.freeze = freeze  # type: ignore
+
+    return cls
+
+
+def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
+    """
+    Recursively pretty prints attrs objects with color.
+    """
+
+    assert attrs.has(obj.__class__)
+
+    lines: list[str] = []
+    for attribute in attrs.fields(obj.__class__):
+        value = getattr(obj, attribute.name)
+        if attrs.has(value.__class__):
+            if use_color:
+                lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
+            else:
+                lines.append(" " * indent + "* " + attribute.name + ":")
+            lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
+        else:
+            if use_color:
+                lines.append(
+                    " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
+                )
+            else:
+                lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
+    return "\n".join(lines)
+
+
+def pretty_print_overrides(overrides: list[str] | None = None, use_color: bool = False) -> str:
+    """
+    Pretty prints overrides.
+    """
+
+    lines: list[str] = []
+    lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
+    for override in overrides:
+        if override == "--":
+            continue
+        if override.startswith("~"):
+            attribute_name = override[1:]
+            attribute_value = None
+        else:
+            attribute_name, attribute_value = override.split("=")
+        if use_color:
+            lines.append("  " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
+        else:
+            lines.append("  " + "* " + attribute_name + ": " + str(attribute_value))
+
+    return "\n".join(lines)
+
+
+@make_freezable
+@attrs.define(slots=False)  # slots=False is required for make_freezable. See the make_freezable notes for more info.
+class ObjectStoreConfig:
+    # Whether the file I/O is from object store instead of local disk.
+    enabled: bool = False
+    # Path to the object store credentials file.
+    credentials: str = ""
+    # Object store bucket to read from / write to the objects.
+    bucket: str = ""
+
+
+@make_freezable
+@attrs.define(slots=False)
+class JobConfig:
+    # Project name.
+    project: str = ""
+    # Experiment name.
+    group: str = ""
+    # Run/job name.
+    name: str = ""
+
+    @property
+    def path(self) -> str:
+        return f"{self.project}/{self.group}/{self.name}"
+
+    @property
+    def path_local(self) -> str:
+        local_root = os.environ.get("IMAGINAIRE_OUTPUT_ROOT", "checkpoints")
+        return f"{local_root}/{self.path}"
+
+
+@make_freezable
+@attrs.define(slots=False)
+class EMAConfig:
+    # Enable tracking a set of exponential moving average (EMA) weights.
+    enabled: bool = False
+    # EMA decay rate.
+    beta: float = 0.9999
+    # Enable removing "_orig_mod-" from buffer names that is added by torch.compile
+    torch_compile_buffer_renaming: bool = False
+
+
+@make_freezable
+@attrs.define(slots=False)
+class PowerEMAConfig:
+    # Enable tracking a set of exponential moving average (EMA) weights.
+    enabled: bool = False
+    # EDM2 paper EMA decay rate.
+    s: float = 0.1
+    # Enable removing "_orig_mod-" from buffer names that is added by torch.compile
+    torch_compile_buffer_renaming: bool = False
+
+
+@make_freezable
+@attrs.define(slots=False)
+class DDPConfig:
+    # Traverse the computation graph to find parameters that don't receive gradients.
+    find_unused_parameters: bool = False
+    # Set to True if the computation graph does not change during the whole training loop.
+    static_graph: bool = True
+    # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere.
+    broadcast_buffers: bool = True
+
+
+@make_freezable
+@attrs.define(slots=False)
+class CuDNNConfig:
+    # Set to True for better reproducibility of the results (only using deterministic cudnn functions).
+    deterministic: bool = False
+    # If set to True, cudnn will benchmark several algorithms and pick the fastest one.
+    benchmark: bool = True
+
+
+@make_freezable
+@attrs.define(slots=False)
+class JITConfig:
+    # Enable exporting a JIT compiled model.
+    enabled: bool = False
+    # Input tensor shape, for example input.
+    input_shape: list[int] | None = None
+    # Device to compile onto.
+    device: str = "cuda"
+    # Data type to compile onto.
+    dtype: str = "bfloat16"
+    # Strict mode for PyTorch JIT.
+    strict: bool = True
+
+
+@make_freezable
+@attrs.define(slots=False)
+class CheckpointConfig:
+    # possible checkpoint class
+    type: dict | None = None
+    # for dcp, whether to use async mode
+    dcp_async_mode_enabled: bool = False
+    # Save the checkpoint every N iterations.
+    save_iter: int = 999999999
+    # Path of model weights to resume the checkpoint from.
+    load_path: str = ""
+    # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
+    load_training_state: bool = False
+    # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored.
+    only_load_scheduler_state: bool = False
+    # Load state_dict to the models in strict mode.
+    strict_resume: bool = True
+    # Configs for JIT compiling EMA model.
+    jit: JITConfig = attrs.field(factory=JITConfig)
+    # Print detailed information during checkpoint saving/loading.
+    verbose: bool = True
+    # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"]
+    keys_not_to_resume: list[str] = []  # noqa: RUF008
+    # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer).
+    broadcast_via_filesystem: bool = False
+    load_ema_to_reg: bool = False
+    # In dcp planner, skip the weight shape check, load weights into the model even if the weight shape is different
+    dcp_allow_mismatched_size: bool = False
+
+
+@make_freezable
+@attrs.define(slots=False)
+class NVTXConfig:
+    """Config for NVTX ranges used in the main training loop.
+
+    See tutorials/nanogpt for more details on how to integrate profiling into your model."""
+
+    # Enable the NVTX ranges.
+    enabled: bool = False
+    # Synchronize everything in each NVTX range.
+    cuda_synchronize: bool = False
+
+
+@make_freezable
+@attrs.define(slots=False)
+class Profiling:
+    enable_profiling: bool = False
+    enable_memory_snapshot: bool = False
+    profile_freq: int = 1
+    first_n_rank: int = 8  # -1 means all ranks, n means the first n ranks dump profiling info
+    record_shape: bool = True
+    profile_memory: bool = True
+    with_stack: bool = True
+    with_modules: bool = True
+
+
+@make_freezable
+@attrs.define(slots=False)
+class TrainerConfig:
+    from imaginaire.trainer import ImaginaireTrainer
+
+    type: builtins.type[ImaginaireTrainer] = ImaginaireTrainer
+    # Set the callback class.
+    # Defaults to the callbacks below.
+    callbacks: LazyDict[dict[str, callback.Callback]] = LazyDict(  # noqa: RUF009
+        dict(
+            ema=L(callback.EMAModelCallback)(),
+            progress_bar=L(callback.ProgressBarCallback)(),
+        )
+    )
+    # distributed parallelism strategy
+    distributed_parallelism: str = "ddp"
+    # Distributed data parallel configs.
+    ddp: DDPConfig = attrs.field(factory=DDPConfig)
+    # cuDNN configs.
+    cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
+    # Set the random seed.
+    seed: int = 0
+    # Gradient scaler arguments (for torch.amp.GradScaler).
+    grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
+    # Maximum number of iterations to train the model.
+    max_iter: int = 999999999
+    # Maximum number of iterations to validate the model. If None, validate on the entire dataset.
+    max_val_iter: int | None = None
+    # How often we log the training stats.
+    logging_iter: int = 100
+    # Whether we want to run the validation routines.
+    run_validation: bool = True
+    # How often we evaluate on the validation set.
+    validation_iter: int = 999999999
+    # Kill the process after N seconds since the last iteration (usually means dead job).
+    timeout_period: int = 999999999
+    # Tensor memory organization format.
+    memory_format: torch.memory_format = torch.preserve_format
+    # Gradient accumulation (update step every N iterations).
+    grad_accum_iter: int = 1
+    # Profiling config
+    profiling: Profiling = attrs.field(factory=Profiling)
+
+
+@make_freezable
+@attrs.define(slots=False)
+class Config:
+    """Config for an imaginaire4 job.
+
+    See /README.md/Configuration System for more info.
+    """
+
+    # Model configs.
+    model: LazyDict[ImaginaireModel]
+    # Optimizer configs.
+    optimizer: LazyDict[torch.optim.Optimizer]
+    # Scheduler configs.
+    scheduler: LazyDict[torch.optim.lr_scheduler.LRScheduler]
+    # Training data configs.
+    dataloader_train: LazyDict[torch.utils.data.DataLoader]
+    # Validation data configs.
+    dataloader_val: LazyDict[torch.utils.data.DataLoader]
+
+    # Training job configs.
+    job: JobConfig = attrs.field(factory=JobConfig)
+
+    # Trainer configs.
+    trainer: TrainerConfig = attrs.field(factory=TrainerConfig)
+
+    if USE_MEGATRON:
+        # Megatron-Core configs
+        model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
+    else:
+        model_parallel: None = None
+
+    # Checkpointer configs.
+    checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)
+
+    def pretty_print(self, use_color: bool = False) -> str:
+        return _pretty_print_attrs_instance(self, 0, use_color)
+
+    def to_dict(self) -> dict[str, Any]:
+        return attrs.asdict(self)
+
+    def validate(self) -> None:
+        """Validate that the config has all required fields."""
+
+        # broadcast job.name across all ranks to make sure it is consistent;
+        # otherwise, unaligned job names lead to unaligned paths to save checkpoints
+        job_name_tensor = torch.ByteTensor(bytearray(self.job.name, "utf-8")).cuda()
+        distributed.broadcast(job_name_tensor, 0)
+        self.job.name = job_name_tensor.cpu().numpy().tobytes().decode("utf-8")
+
+        assert self.job.project != ""
+        assert self.job.group != ""
+        assert self.job.name != ""
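A short sketch of the freeze semantics these configs rely on: attributes are freely assignable until `freeze()` is called, after which the patched `__setattr__` rejects mutation (the project/group/name values are illustrative):

    from imaginaire.config import JobConfig

    job = JobConfig(project="demo", group="exp", name="run0")
    job.name = "run1"      # fine: not frozen yet
    job.freeze()
    try:
        job.name = "run2"  # rejected once frozen
    except AttributeError as e:
        print(e)           # Cannot modify frozen instance
    print(job.path)        # demo/exp/run1 (reads still work)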
imaginaire/lazy_config/__init__.py
ADDED
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from omegaconf import OmegaConf
+
+from imaginaire.lazy_config.instantiate import instantiate
+from imaginaire.lazy_config.lazy import LazyCall, LazyConfig, LazyDict
+from imaginaire.lazy_config.omegaconf_patch import to_object
+
+OmegaConf.to_object = to_object
+
+PLACEHOLDER = None
+
+__all__ = ["PLACEHOLDER", "LazyCall", "LazyConfig", "LazyDict", "instantiate"]
+
+
+DOC_BUILDING = os.getenv("_DOC_BUILDING", False)  # set in docs/conf.py
+
+
+def fixup_module_metadata(module_name, namespace, keys=None):
+    """
+    Fix the __qualname__ of module members to be their exported api name, so
+    when they are referenced in docs, sphinx can find them. Reference:
+    https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
+    """
+    if not DOC_BUILDING:
+        return
+    seen_ids = set()
+
+    def fix_one(qualname, name, obj):
+        # avoid infinite recursion (relevant when using
+        # typing.Generic, for example)
+        if id(obj) in seen_ids:
+            return
+        seen_ids.add(id(obj))
+
+        mod = getattr(obj, "__module__", None)
+        if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
+            obj.__module__ = module_name
+            # Modules, unlike everything else in Python, put fully-qualified
+            # names into their __name__ attribute. We check for "." to avoid
+            # rewriting these.
+            if hasattr(obj, "__name__") and "." not in obj.__name__:
+                obj.__name__ = name
+                obj.__qualname__ = qualname
+            if isinstance(obj, type):
+                for attr_name, attr_value in obj.__dict__.items():
+                    fix_one(objname + "." + attr_name, attr_name, attr_value)
+
+    if keys is None:
+        keys = namespace.keys()
+    for objname in keys:
+        if not objname.startswith("_"):
+            obj = namespace[objname]
+            fix_one(objname, objname, obj)
+
+
+fixup_module_metadata(__name__, globals(), __all__)
+del fixup_module_metadata
imaginaire/lazy_config/file_io.py
ADDED
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
+from iopath.common.file_io import PathManager as PathManagerBase
+
+__all__ = ["PathHandler", "PathManager"]
+
+
+PathManager = PathManagerBase()
+PathManager.register_handler(HTTPURLHandler())
+PathManager.register_handler(OneDrivePathHandler())
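`PathManager` here is a stock iopath path manager with HTTP and OneDrive handlers registered, so remote URLs can be read through the same API as local files. A small sketch (the URL is a placeholder):

    from imaginaire.lazy_config.file_io import PathManager

    # The HTTP handler downloads the resource to a local cache, then opens that copy.
    with PathManager.open("https://example.com/config.yaml", "r") as f:
        text = f.read()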
imaginaire/lazy_config/instantiate.py
ADDED
@@ -0,0 +1,119 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections.abc as abc
+import dataclasses
+from typing import Any
+
+import attrs
+
+from imaginaire.lazy_config.registry import _convert_target_to_string, locate
+from imaginaire.utils import log
+
+__all__ = ["dump_dataclass", "instantiate"]
+
+
+def is_dataclass_or_attrs(target):
+    return dataclasses.is_dataclass(target) or attrs.has(target)
+
+
+def dump_dataclass(obj: Any):
+    """
+    Dump a dataclass recursively into a dict that can be later instantiated.
+
+    Args:
+        obj: a dataclass object
+
+    Returns:
+        dict
+    """
+    assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), (
+        "dump_dataclass() requires an instance of a dataclass."
+    )
+    ret = {"_target_": _convert_target_to_string(type(obj))}
+    for f in dataclasses.fields(obj):
+        v = getattr(obj, f.name)
+        if dataclasses.is_dataclass(v):
+            v = dump_dataclass(v)
+        if isinstance(v, (list, tuple)):
+            v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
+        ret[f.name] = v
+    return ret
+
+
+def instantiate(cfg, *args, **kwargs):
+    """
+    Recursively instantiate objects defined in dictionaries by
+    "_target_" and arguments.
+
+    Args:
+        cfg: a dict-like object with "_target_" that defines the caller, and
+            other keys that define the arguments
+        args: Optional positional parameters pass-through.
+        kwargs: Optional named parameters pass-through.
+
+    Returns:
+        object instantiated by cfg
+    """
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+
+    if isinstance(cfg, ListConfig):
+        lst = [instantiate(x) for x in cfg]
+        return ListConfig(lst, flags={"allow_objects": True})
+    if isinstance(cfg, list):
+        # Specialize for list, because many classes take
+        # list[objects] as arguments, such as ResNet, DatasetMapper
+        return [instantiate(x) for x in cfg]
+
+    # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
+    # instantiate it to the actual dataclass.
+    if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
+        return OmegaConf.to_object(cfg)
+
+    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
+        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
+        # but faster: https://github.com/facebookresearch/hydra/issues/1200
+        is_recursive = getattr(cfg, "_recursive_", True)
+        if is_recursive:
+            cfg = {k: instantiate(v) for k, v in cfg.items()}
+        else:
+            cfg = {k: v for k, v in cfg.items()}
+        # pop the _recursive_ key to avoid passing it as a parameter
+        if "_recursive_" in cfg:
+            cfg.pop("_recursive_")
+        cls = cfg.pop("_target_")
+        cls = instantiate(cls)
+
+        if isinstance(cls, str):
+            cls_name = cls
+            cls = locate(cls_name)
+            assert cls is not None, cls_name
+        else:
+            try:
+                cls_name = cls.__module__ + "." + cls.__qualname__
+            except Exception:
+                # target could be anything, so the above could fail
+                cls_name = str(cls)
+        assert callable(cls), f"_target_ {cls} does not define a callable object"
+        try:
+            # override config with kwargs
+            instantiate_kwargs = {}
+            instantiate_kwargs.update(cfg)
+            instantiate_kwargs.update(kwargs)
+            return cls(*args, **instantiate_kwargs)
+        except TypeError:
+            log.error(f"Error when instantiating {cls_name}!")
+            raise
+    return cfg  # return as-is if don't know what to do
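A minimal sketch of what `instantiate` accepts: any mapping with a `_target_` key, where the target may be a callable or a dotted import string resolved via `locate`. Nested values are built first because `_recursive_` defaults to True:

    from imaginaire.lazy_config.instantiate import instantiate

    cfg = {
        "_target_": "torch.nn.Conv2d",  # dotted string, resolved with locate()
        "in_channels": 3,
        "out_channels": 16,
        "kernel_size": 3,
    }
    conv = instantiate(cfg)  # equivalent to torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)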
imaginaire/lazy_config/lazy.py
ADDED
@@ -0,0 +1,442 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+import builtins
+import collections.abc as abc
+import importlib
+import inspect
+import logging
+import os
+import pickle
+import uuid
+from collections import OrderedDict
+from contextlib import contextmanager
+from copy import deepcopy
+from dataclasses import is_dataclass
+from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast
+
+import attrs
+import yaml
+from omegaconf import DictConfig, ListConfig, OmegaConf
+
+from imaginaire.utils import log
+
+try:
+    import dill as dill_pickle
+except ImportError:
+    dill_pickle = None
+
+try:
+    import cloudpickle
+except ImportError:
+    cloudpickle = None
+
+from imaginaire.lazy_config.file_io import PathManager
+from imaginaire.lazy_config.registry import _convert_target_to_string
+
+__all__ = ["LazyCall", "LazyConfig", "LazyDict"]
+
+T = TypeVar("T")
+
+
+def sort_dict(d: dict[str, Any]) -> OrderedDict[str, Any]:
+    return OrderedDict(sorted(d.items(), key=lambda x: x[0]))
+
+
+def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
+    return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
+
+
+def sort_recursive(obj: dict[str, Any] | list[Any] | Any) -> OrderedDict[str, Any] | list[Any] | Any:
+    if isinstance(obj, dict):
+        return sort_dict({k: sort_recursive(v) for k, v in obj.items()})
+    elif isinstance(obj, list):
+        return [sort_recursive(item) for item in obj]
+    return obj
+
+
+yaml.add_representer(OrderedDict, dict_representer)
+
+OmegaConf.register_new_resolver("add", lambda *vals: sum(vals))
+OmegaConf.register_new_resolver("subtract", lambda *vals: vals[0] - sum(vals[1:]))
+
+
+def get_default_params(cls_or_func):
+    if callable(cls_or_func):
+        # inspect signature for function
+        signature = inspect.signature(cls_or_func)
+    else:
+        # inspect signature for class
+        signature = inspect.signature(cls_or_func.__init__)
+    params = signature.parameters
+    default_params = {
+        name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty
+    }
+    return default_params
+
+
+if TYPE_CHECKING:
+    # Have `LazyDict[T]` behave as `T`, so that attribute access works. Ideally, it
+    # would be a subclass of `T`, but this doesn't seem to be possible in the type
+    # system yet.
+    LazyDict: TypeAlias = T
+else:
+    LazyDict = DictConfig
+
+
+class LazyCall(Generic[T]):
+    """
+    Wrap a callable so that when it's called, the call will not be executed,
+    but returns a dict that describes the call.
+
+    LazyCall object has to be called with only keyword arguments. Positional
+    arguments are not yet supported.
+
+    Examples:
+    ::
+        from detectron2.config import instantiate, LazyCall
+
+        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
+        layer_cfg.out_channels = 64  # can edit it afterwards
+        layer = instantiate(layer_cfg)
+    """
+
+    def __init__(self, target: type[T]):
+        if not (callable(target) or isinstance(target, (str, abc.Mapping))):
+            raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
+        self._target = target
+
+    def __call__(self, **kwargs) -> LazyDict[T]:
+        if is_dataclass(self._target) or attrs.has(self._target):
+            # omegaconf object cannot hold dataclass type
+            # https://github.com/omry/omegaconf/issues/784
+            target = _convert_target_to_string(self._target)
+        else:
+            target = self._target
+        kwargs["_target_"] = target
+
+        _final_params = get_default_params(self._target)
+        _final_params.update(kwargs)
+
+        return cast(LazyDict[T], DictConfig(content=_final_params, flags={"allow_objects": True}))
+
+
+def _visit_dict_config(cfg, func):
def _visit_dict_config(cfg, func):
|
| 138 |
+
"""
|
| 139 |
+
Apply func recursively to all DictConfig in cfg.
|
| 140 |
+
"""
|
| 141 |
+
if isinstance(cfg, DictConfig):
|
| 142 |
+
func(cfg)
|
| 143 |
+
for v in cfg.values():
|
| 144 |
+
_visit_dict_config(v, func)
|
| 145 |
+
elif isinstance(cfg, ListConfig):
|
| 146 |
+
for v in cfg:
|
| 147 |
+
_visit_dict_config(v, func)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _validate_py_syntax(filename):
|
| 151 |
+
# see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
|
| 152 |
+
with PathManager.open(filename, "r") as f:
|
| 153 |
+
content = f.read()
|
| 154 |
+
try:
|
| 155 |
+
ast.parse(content)
|
| 156 |
+
except SyntaxError as e:
|
| 157 |
+
raise SyntaxError(f"Config file {filename} has syntax error!") from e
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _cast_to_config(obj):
|
| 161 |
+
# if given a dict, return DictConfig instead
|
| 162 |
+
if isinstance(obj, dict):
|
| 163 |
+
return DictConfig(obj, flags={"allow_objects": True})
|
| 164 |
+
return obj
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
|
| 168 |
+
"""
|
| 169 |
+
A namespace to put all imported config into.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _random_package_name(filename):
|
| 174 |
+
# generate a random package name when loading config files
|
| 175 |
+
return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@contextmanager
|
| 179 |
+
def _patch_import():
|
| 180 |
+
"""
|
| 181 |
+
Enhance relative import statements in config files, so that they:
|
| 182 |
+
1. locate files purely based on relative location, regardless of packages.
|
| 183 |
+
e.g. you can import file without having __init__
|
| 184 |
+
2. do not cache modules globally; modifications of module states has no side effect
|
| 185 |
+
3. support other storage system through PathManager, so config files can be in the cloud
|
| 186 |
+
4. imported dict are turned into omegaconf.DictConfig automatically
|
| 187 |
+
"""
|
| 188 |
+
old_import = builtins.__import__
|
| 189 |
+
|
| 190 |
+
def find_relative_file(original_file, relative_import_path, level):
|
| 191 |
+
# NOTE: "from . import x" is not handled. Because then it's unclear
|
| 192 |
+
# if such import should produce `x` as a python module or DictConfig.
|
| 193 |
+
# This can be discussed further if needed.
|
| 194 |
+
relative_import_err = """
|
| 195 |
+
Relative import of directories is not allowed within config files.
|
| 196 |
+
Within a config file, relative import can only import other config files.
|
| 197 |
+
""".replace("\n", " ")
|
| 198 |
+
if not len(relative_import_path):
|
| 199 |
+
raise ImportError(relative_import_err)
|
| 200 |
+
|
| 201 |
+
cur_file = os.path.dirname(original_file)
|
| 202 |
+
for _ in range(level - 1):
|
| 203 |
+
cur_file = os.path.dirname(cur_file)
|
| 204 |
+
cur_name = relative_import_path.lstrip(".")
|
| 205 |
+
for part in cur_name.split("."):
|
| 206 |
+
cur_file = os.path.join(cur_file, part)
|
| 207 |
+
if not cur_file.endswith(".py"):
|
| 208 |
+
cur_file += ".py"
|
| 209 |
+
if not PathManager.isfile(cur_file):
|
| 210 |
+
cur_file_no_suffix = cur_file[: -len(".py")]
|
| 211 |
+
if PathManager.isdir(cur_file_no_suffix):
|
| 212 |
+
raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
|
| 213 |
+
else:
|
| 214 |
+
raise ImportError(
|
| 215 |
+
f"Cannot import name {relative_import_path} from {original_file}: {cur_file} does not exist."
|
| 216 |
+
)
|
| 217 |
+
return cur_file
|
| 218 |
+
|
| 219 |
+
def new_import(name, globals=None, locals=None, fromlist=(), level=0):
|
| 220 |
+
if (
|
| 221 |
+
# Only deal with relative imports inside config files
|
| 222 |
+
level != 0 and globals is not None and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
|
| 223 |
+
):
|
| 224 |
+
cur_file = find_relative_file(globals["__file__"], name, level)
|
| 225 |
+
_validate_py_syntax(cur_file)
|
| 226 |
+
spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
|
| 227 |
+
module = importlib.util.module_from_spec(spec)
|
| 228 |
+
module.__file__ = cur_file
|
| 229 |
+
with PathManager.open(cur_file) as f:
|
| 230 |
+
content = f.read()
|
| 231 |
+
exec(compile(content, cur_file, "exec"), module.__dict__)
|
| 232 |
+
for name in fromlist: # turn imported dict into DictConfig automatically
|
| 233 |
+
val = _cast_to_config(module.__dict__[name])
|
| 234 |
+
module.__dict__[name] = val
|
| 235 |
+
return module
|
| 236 |
+
return old_import(name, globals, locals, fromlist=fromlist, level=level)
|
| 237 |
+
|
| 238 |
+
builtins.__import__ = new_import
|
| 239 |
+
yield new_import
|
| 240 |
+
builtins.__import__ = old_import
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class LazyConfig:
|
| 244 |
+
"""
|
| 245 |
+
Provide methods to save, load, and overrides an omegaconf config object
|
| 246 |
+
which may contain definition of lazily-constructed objects.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
@staticmethod
|
| 250 |
+
def load_rel(filename: str, keys: None | str | tuple[str, ...] = None):
|
| 251 |
+
"""
|
| 252 |
+
Similar to :meth:`load()`, but load path relative to the caller's
|
| 253 |
+
source file.
|
| 254 |
+
|
| 255 |
+
This has the same functionality as a relative import, except that this method
|
| 256 |
+
accepts filename as a string, so more characters are allowed in the filename.
|
| 257 |
+
"""
|
| 258 |
+
caller_frame = inspect.stack()[1]
|
| 259 |
+
caller_fname = caller_frame[0].f_code.co_filename
|
| 260 |
+
assert caller_fname != "<string>", "load_rel Unable to find caller"
|
| 261 |
+
caller_dir = os.path.dirname(caller_fname)
|
| 262 |
+
filename = os.path.join(caller_dir, filename)
|
| 263 |
+
return LazyConfig.load(filename, keys)
|
| 264 |
+
|
| 265 |
+
@staticmethod
|
| 266 |
+
def load(filename: str, keys: None | str | tuple[str, ...] = None):
|
| 267 |
+
"""
|
| 268 |
+
Load a config file.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
filename: absolute path or relative path w.r.t. the current working directory
|
| 272 |
+
keys: keys to load and return. If not given, return all keys
|
| 273 |
+
(whose values are config objects) in a dict.
|
| 274 |
+
"""
|
| 275 |
+
has_keys = keys is not None
|
| 276 |
+
filename = filename.replace("/./", "/") # redundant
|
| 277 |
+
if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
|
| 278 |
+
raise ValueError(f"Config file {filename} has to be a python or yaml file.")
|
| 279 |
+
if filename.endswith(".py"):
|
| 280 |
+
_validate_py_syntax(filename)
|
| 281 |
+
|
| 282 |
+
with _patch_import():
|
| 283 |
+
# Record the filename
|
| 284 |
+
module_namespace = {
|
| 285 |
+
"__file__": filename,
|
| 286 |
+
"__package__": _random_package_name(filename),
|
| 287 |
+
}
|
| 288 |
+
with PathManager.open(filename) as f:
|
| 289 |
+
content = f.read()
|
| 290 |
+
# Compile first with filename to:
|
| 291 |
+
# 1. make filename appears in stacktrace
|
| 292 |
+
# 2. make load_rel able to find its parent's (possibly remote) location
|
| 293 |
+
exec(compile(content, filename, "exec"), module_namespace)
|
| 294 |
+
|
| 295 |
+
ret = module_namespace
|
| 296 |
+
else:
|
| 297 |
+
with PathManager.open(filename) as f:
|
| 298 |
+
obj = yaml.unsafe_load(f)
|
| 299 |
+
ret = OmegaConf.create(obj, flags={"allow_objects": True})
|
| 300 |
+
|
| 301 |
+
if has_keys:
|
| 302 |
+
if isinstance(keys, str):
|
| 303 |
+
return _cast_to_config(ret[keys])
|
| 304 |
+
else:
|
| 305 |
+
return tuple(_cast_to_config(ret[a]) for a in keys)
|
| 306 |
+
else:
|
| 307 |
+
if filename.endswith(".py"):
|
| 308 |
+
# when not specified, only load those that are config objects
|
| 309 |
+
ret = DictConfig(
|
| 310 |
+
{
|
| 311 |
+
name: _cast_to_config(value)
|
| 312 |
+
for name, value in ret.items()
|
| 313 |
+
if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_")
|
| 314 |
+
},
|
| 315 |
+
flags={"allow_objects": True},
|
| 316 |
+
)
|
| 317 |
+
return ret
|
| 318 |
+
|
| 319 |
+
@staticmethod
|
| 320 |
+
def save_pkl(cfg, filename: str) -> str:
|
| 321 |
+
"""
|
| 322 |
+
Saves a Config object to a file using pickle serialization. This method is typically used
|
| 323 |
+
when the configuration object contains complex objects, such as lambdas, that are not supported by
|
| 324 |
+
simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration
|
| 325 |
+
object before serialization to ensure that the original object remains unmodified.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
cfg: A Config object to be serialized and saved.
|
| 329 |
+
filename: The path and name of the file where the configuration should be saved. The function
|
| 330 |
+
assumes the file extension indicates a pickle format (e.g., .pkl).
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
str: The filename to which the configuration was saved. This can be used to verify the file location
|
| 334 |
+
or log the outcome.
|
| 335 |
+
|
| 336 |
+
Notes:
|
| 337 |
+
- The function logs a warning if the configuration is successfully saved using pickle.
|
| 338 |
+
- If saving fails, an error is logged with the exception details.
|
| 339 |
+
"""
|
| 340 |
+
try:
|
| 341 |
+
cfg = deepcopy(cfg)
|
| 342 |
+
except Exception:
|
| 343 |
+
pass
|
| 344 |
+
|
| 345 |
+
try:
|
| 346 |
+
with PathManager.open(filename, "wb") as f:
|
| 347 |
+
pickle.dump(cfg, f)
|
| 348 |
+
log.warning(f"Config is saved using pickle at {filename}.")
|
| 349 |
+
except Exception as e:
|
| 350 |
+
log.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead")
|
| 351 |
+
if dill_pickle:
|
| 352 |
+
try:
|
| 353 |
+
with PathManager.open(filename, "wb") as f:
|
| 354 |
+
pickle.dump(dill_pickle.dumps(cfg, recurse=True), f)
|
| 355 |
+
log.warning(f"Config is saved using dill at {filename}.")
|
| 356 |
+
except Exception as e:
|
| 357 |
+
log.error(f"Failed to save config to {filename}: {e}.")
|
| 358 |
+
if cloudpickle:
|
| 359 |
+
try:
|
| 360 |
+
with PathManager.open(filename, "wb") as f:
|
| 361 |
+
pickle.dump(cloudpickle.dumps(cfg), f)
|
| 362 |
+
log.warning(f"Config is saved using cloudpickle at {filename}.")
|
| 363 |
+
except Exception as e:
|
| 364 |
+
log.error(f"Failed to save config to {filename}: {e}.")
|
| 365 |
+
else:
|
| 366 |
+
log.error("cloudpickle is not available. Cannot save the config.")
|
| 367 |
+
raise e
|
| 368 |
+
|
| 369 |
+
return filename
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def save_yaml(cfg, filename: str) -> str:
|
| 373 |
+
"""
|
| 374 |
+
Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types.
|
| 378 |
+
filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'.
|
| 379 |
+
|
| 380 |
+
Returns:
|
| 381 |
+
str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome.
|
| 382 |
+
|
| 383 |
+
Notes:
|
| 384 |
+
- The function logs a warning if the configuration is successfully saved using YAML.
|
| 385 |
+
- If saving fails, an error is logged with the exception details.
|
| 386 |
+
"""
|
| 387 |
+
logger = logging.getLogger(__name__)
|
| 388 |
+
try:
|
| 389 |
+
cfg = deepcopy(cfg)
|
| 390 |
+
except Exception:
|
| 391 |
+
pass
|
| 392 |
+
|
| 393 |
+
# Define a function to check if an item is serializable to YAML
|
| 394 |
+
def is_serializable(item):
|
| 395 |
+
try:
|
| 396 |
+
OmegaConf.to_yaml(item)
|
| 397 |
+
return True
|
| 398 |
+
except Exception as e:
|
| 399 |
+
return False
|
| 400 |
+
|
| 401 |
+
# Function to convert unserializable items to strings
|
| 402 |
+
def serialize_config(config):
|
| 403 |
+
if isinstance(config, DictConfig):
|
| 404 |
+
for key, value in config.items():
|
| 405 |
+
if isinstance(value, (DictConfig, ListConfig)):
|
| 406 |
+
try:
|
| 407 |
+
if "_target_" in value:
|
| 408 |
+
default_params = get_default_params(value["_target_"])
|
| 409 |
+
for default_key, default_v in default_params.items():
|
| 410 |
+
if default_key not in value:
|
| 411 |
+
value[default_key] = default_v
|
| 412 |
+
except Exception as e:
|
| 413 |
+
log.error(f"Failed to add default argument values: {e}")
|
| 414 |
+
|
| 415 |
+
serialize_config(value)
|
| 416 |
+
else:
|
| 417 |
+
if not is_serializable(value) and value is not None:
|
| 418 |
+
config[key] = str(value)
|
| 419 |
+
elif isinstance(config, ListConfig):
|
| 420 |
+
for i, item in enumerate(config):
|
| 421 |
+
if isinstance(item, (DictConfig, ListConfig)):
|
| 422 |
+
serialize_config(item)
|
| 423 |
+
else:
|
| 424 |
+
if not is_serializable(item) and item is not None:
|
| 425 |
+
config[i] = str(item)
|
| 426 |
+
else:
|
| 427 |
+
raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
|
| 428 |
+
return config
|
| 429 |
+
|
| 430 |
+
# Convert Config object to a DictConfig object.
|
| 431 |
+
config_dict = attrs.asdict(cfg)
|
| 432 |
+
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
|
| 433 |
+
|
| 434 |
+
# Serialize the DictConfig object by converting non-serializable objects to strings.
|
| 435 |
+
config_omegaconf = serialize_config(config_omegaconf)
|
| 436 |
+
|
| 437 |
+
config_dict: dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)
|
| 438 |
+
sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict)
|
| 439 |
+
with open(filename, "w") as f:
|
| 440 |
+
yaml.dump(sorted_config, f, default_flow_style=False)
|
| 441 |
+
log.warning(f"Config is saved using omegaconf at {filename}.")
|
| 442 |
+
return filename
|
imaginaire/lazy_config/omegaconf_patch.py
ADDED
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from omegaconf import OmegaConf
from omegaconf.base import DictKeyType, SCMode
from omegaconf.dictconfig import DictConfig  # pragma: no cover


def to_object(cfg: Any) -> dict[DictKeyType, Any] | list[Any] | None | str | Any:
    """
    Converts an OmegaConf configuration object to a native Python container (dict or list), unless
    the configuration is specifically created by LazyCall, in which case the original configuration
    is returned directly.

    This function serves as a modification of the original `to_object` method from OmegaConf,
    preventing DictConfig objects created by LazyCall from being automatically converted to Python
    dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
    structure and behavior.

    Differences from OmegaConf's original `to_object`:
    - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.

    Reference:
    - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595

    Args:
        cfg (Any): The OmegaConf configuration object to convert.

    Returns:
        Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
        `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.

    Examples:
        >>> cfg = DictConfig({"key": "value", "_target_": "Model"})
        >>> to_object(cfg)
        DictConfig({"key": "value", "_target_": "Model"})

        >>> cfg = DictConfig({"list": [1, 2, 3]})
        >>> to_object(cfg)
        {'list': [1, 2, 3]}
    """
    if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
        return cfg

    return OmegaConf.to_container(
        cfg=cfg,
        resolve=True,
        throw_on_missing=True,
        enum_to_str=False,
        structured_config_mode=SCMode.INSTANTIATE,
    )
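
A small illustration of the pass-through behavior described in the docstring above; a sketch that only assumes omegaconf itself:

from omegaconf import DictConfig

from imaginaire.lazy_config.omegaconf_patch import to_object

lazy_cfg = DictConfig({"_target_": "torch.nn.Conv2d", "in_channels": 3})
plain_cfg = DictConfig({"lr": 1e-4, "betas": [0.9, 0.999]})

assert to_object(lazy_cfg) is lazy_cfg  # LazyCall-style configs are returned unchanged
assert to_object(plain_cfg) == {"lr": 1e-4, "betas": [0.9, 0.999]}  # others become native containers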
imaginaire/lazy_config/registry.py
ADDED
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pydoc
from typing import Any

from fvcore.common.registry import Registry  # for backward compatibility.

"""
``Registry`` and `locate` provide ways to map a string (typically found
in config files) to callable objects.
"""

__all__ = ["Registry", "locate"]


def _convert_target_to_string(t: Any) -> str:
    """
    Inverse of ``locate()``.

    Args:
        t: any object with ``__module__`` and ``__qualname__``
    """
    module, qualname = t.__module__, t.__qualname__

    # Compress the path to this object, e.g. ``module.submodule._impl.class``
    # may become ``module.submodule.class``, if the latter also resolves to the same
    # object. This simplifies the string, and also is less affected by moving the
    # class implementation.
    module_parts = module.split(".")
    for k in range(1, len(module_parts)):
        prefix = ".".join(module_parts[:k])
        candidate = f"{prefix}.{qualname}"
        try:
            if locate(candidate) is t:
                return candidate
        except ImportError:
            pass
    return f"{module}.{qualname}"


def locate(name: str) -> Any:
    """
    Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
    such as "module.submodule.class_name".

    Raise Exception if it cannot be found.
    """
    obj = pydoc.locate(name)

    # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
    # by pydoc.locate. Try a private function from hydra.
    if obj is None:
        try:
            # from hydra.utils import get_method - will print many errors
            from hydra.utils import _locate
        except ImportError as e:
            raise ImportError(f"Cannot dynamically locate object {name}!") from e
        else:
            obj = _locate(name)  # it raises if fails

    return obj
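
To see the two directions of the mapping, here is a sketch; torch.optim.AdamW is just an illustrative target, and the round-trip claim assumes hydra is installed for the fallback path:

import torch

from imaginaire.lazy_config.registry import _convert_target_to_string, locate

opt_cls = locate("torch.optim.AdamW")  # string -> object
assert opt_cls is torch.optim.AdamW

name = _convert_target_to_string(torch.optim.AdamW)  # object -> (possibly shortened) string
assert locate(name) is torch.optim.AdamW  # the round trip resolves to the same object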
imaginaire/model.py
ADDED
@@ -0,0 +1,137 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from imaginaire.lazy_config import LazyDict, instantiate


class ImaginaireModel(torch.nn.Module):
    """The base model class of Imaginaire. It inherits from torch.nn.Module.

    All models in Imaginaire should inherit ImaginaireModel. It should include the implementations for all the
    computation graphs. All inheriting child classes should implement the following methods:
    - training_step(): The training step of the model, including the loss computation.
    - validation_step(): The validation step of the model, including the loss computation.
    - forward(): The computation graph for model inference.
    The following methods have default implementations in ImaginaireModel:
    - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
    """

    def __init__(self) -> None:
        super().__init__()

    def init_optimizer_scheduler(
        self,
        optimizer_config: LazyDict[torch.optim.Optimizer],
        scheduler_config: LazyDict[torch.optim.lr_scheduler.LRScheduler],
    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Creates the optimizer and scheduler for the model.

        Args:
            optimizer_config (LazyDict): The lazy config describing the optimizer.
            scheduler_config (LazyDict): The lazy config describing the scheduler.

        Returns:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
        """
        optimizer_config.params = self.parameters()
        optimizer = instantiate(optimizer_config)
        scheduler_config.optimizer = optimizer
        scheduler = instantiate(scheduler_config)
        return optimizer, scheduler

    def training_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The training step of the model, including the loss computation.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
            loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
        """
        raise NotImplementedError

    @torch.no_grad()
    def validation_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The validation step of the model, including the loss computation.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
            loss (torch.Tensor): The total loss (weighted sum of various losses).
        """
        raise NotImplementedError

    @torch.inference_mode()
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """The computation graph for model inference.

        Args:
            *args: Whatever you decide to pass into the forward method.
            **kwargs: Keyword arguments are also possible.

        Return:
            Your model's output.
        """
        raise NotImplementedError

    def on_model_init_start(self, set_barrier=False) -> None:
        return

    def on_model_init_end(self, set_barrier=False) -> None:
        return

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        """The model preparation before the training is launched.

        Args:
            memory_format (torch.memory_format): Memory format of the model.
        """
        pass

    def on_before_zero_grad(
        self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
    ) -> None:
        """Hook before zero_grad() is called.

        Args:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            iteration (int): Current iteration number.
        """
        pass

    def on_after_backward(self, iteration: int = 0) -> None:
        """Hook after loss.backward() is called.

        This method is called immediately after the backward pass, allowing for custom operations
        or modifications to be performed on the gradients before the optimizer step.

        Args:
            iteration (int): Current iteration number.
        """
        pass
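
A hypothetical minimal subclass, sketching the three overrides the class docstring above asks for; the network, keys ("x", "y"), and shapes are purely illustrative:

import torch

from imaginaire.model import ImaginaireModel


class ToyRegressionModel(ImaginaireModel):
    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(16, 1)

    def training_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        pred = self.net(data_batch["x"])
        loss = torch.nn.functional.mse_loss(pred, data_batch["y"])
        return {"pred": pred.detach()}, loss  # (auxiliary outputs, total loss)

    @torch.no_grad()
    def validation_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        pred = self.net(data_batch["x"])
        loss = torch.nn.functional.mse_loss(pred, data_batch["y"])
        return {"pred": pred}, loss

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)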
imaginaire/trainer.py
ADDED
@@ -0,0 +1,322 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import os
import signal

import torch
import torch.distributed as dist
import torch.utils.data

from imaginaire.utils.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

try:
    from megatron.core import parallel_state

    USE_MEGATRON = True
except ImportError:
    USE_MEGATRON = False
    print("Megatron-core is not installed.")


from imaginaire.lazy_config import LazyConfig, instantiate
from imaginaire.model import ImaginaireModel
from imaginaire.utils import callback, distributed, log, misc
from imaginaire.utils.checkpointer import Checkpointer


class ImaginaireTrainer:
    """The base trainer class of Imaginaire.

    All trainers in Imaginaire should inherit ImaginaireTrainer. It contains the basic functionality for model training
    (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA),
    mixed-precision training (fp16/bf16).

    Attributes:
        checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
        training_timer (misc.Timer): Timer object to time code blocks and functions.
    """

    def __init__(self, config):
        """Constructor of the trainer.

        Args:
            config (Config): The config object for the Imaginaire codebase.
        """
        super().__init__()
        self.config = config
        # Set up the distributed computing environment.
        with misc.timer("init_distributed"):
            distributed.init()
        # Set up parallel states.
        if hasattr(config.model, "context_parallel_size"):
            if config.model_parallel.context_parallel_size > 1:
                raise ValueError(
                    "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
                    "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
                )
            else:
                log.critical(
                    "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
                )
                config.model_parallel.context_parallel_size = config.model.context_parallel_size
        if USE_MEGATRON:
            if (
                "create_gloo_process_groups"
                in inspect.signature(parallel_state.initialize_model_parallel).parameters
            ):
                parallel_state.initialize_model_parallel(
                    pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
                    tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
                    context_parallel_size=config.model_parallel.context_parallel_size,
                    create_gloo_process_groups=False,
                )
            else:
                parallel_state.initialize_model_parallel(
                    pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
                    tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
                    context_parallel_size=config.model_parallel.context_parallel_size,
                )
            # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
            # It is not part of the original `parallel_state` API, so we need to set it manually.
            parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
            if parallel_state.sequence_parallel:
                os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

        # Create the local job directory, save the config file, and pipe to a local log.
        if distributed.is_rank0():
            os.makedirs(config.job.path_local, exist_ok=True)
            # Save the config as .pkl for reproducibility.
            LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
            # Save the config as .yaml for reading or parsing experiment hyperparameters.
            LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
        dist.barrier()
        log.init_loguru_file(f"{config.job.path_local}/stdout.log")
        if distributed.is_rank0():
            # Print important environment variables and the effective config.
            log.info("Config:\n" + config.pretty_print(use_color=True))
            misc.print_environ_variables(["TORCH_HOME", "IMAGINAIRE_OUTPUT_ROOT"])
        # Set the random seed. If multi-GPU, different ranks are set with different seeds.
        misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
        # Initialize cuDNN.
        torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
        torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
        # Floating-point precision settings.
        torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
        # Initialize the callback functions.
        self.callbacks = callback.CallBackGroup(config=config, trainer=self)
        # Initialize the model checkpointer.
        if config.checkpoint.type is None:
            self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            self.checkpointer: Checkpointer = instantiate(
                config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
            )
        # Initialize the timer for speed benchmarking.
        self.training_timer = misc.TrainingTimer()
        # Send a TimeoutError if a training step takes over timeout_period seconds.
        signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period))  # type: ignore

    def train(
        self,
        model: ImaginaireModel,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Args:
            model (ImaginaireModel): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
        """
        # Leaving this for backward compatibility for now, but we can think about moving this to model.on_train_start for all models.
        model = model.to("cuda", memory_format=self.config.trainer.memory_format)  # type: ignore
        model.on_train_start(self.config.trainer.memory_format)

        # Initialize the optimizer, scheduler, and grad_scaler.
        self.callbacks.on_optimizer_init_start()
        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
        self.callbacks.on_optimizer_init_end()
        # Load the model checkpoint and get the starting iteration number.
        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
        grad_accum_iter = 0
        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        if self.config.trainer.distributed_parallelism == "ddp":
            # Create a DDP model wrapper.
            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
        elif self.config.trainer.distributed_parallelism == "fsdp":
            model_ddp = model
        else:
            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        log.info("Starting training...")
        self.callbacks.on_train_start(model, iteration=iteration)
        # Initial validation.
        if self.config.trainer.run_validation and iteration == 0:
            self.validate(model, dataloader_val, iteration=iteration)
            log.info("Initial validation done.")
        _end_training = False
        with (
            maybe_enable_profiling(self.config, global_step=iteration) as torch_profiler,
            maybe_enable_memory_snapshot(self.config, global_step=iteration) as memory_profiler,
        ):
            while True:
                dataloader_train_iter = iter(dataloader_train)
                while True:
                    self.callbacks.on_before_dataloading(iteration)
                    try:
                        with self.training_timer("dataloader_train"):
                            data_batch = next(dataloader_train_iter)
                    except StopIteration:
                        break
                    finally:
                        self.callbacks.on_after_dataloading(iteration)
                    # If max_iter is reached, exit the training loop.
                    if iteration >= self.config.trainer.max_iter:
                        _end_training = True
                        break
                    # Move all tensors in the data batch to GPU device.
                    data_batch = misc.to(data_batch, device="cuda")
                    # The actual training step.
                    self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                    self.callbacks.on_training_step_batch_start(model, data_batch, iteration=iteration)
                    if not model.training:
                        model_ddp.train()
                    assert model_ddp.training, "model_ddp is not in training mode."
                    assert model.training, "model is not in training mode."
                    output_batch, loss, grad_accum_iter = self.training_step(
                        model_ddp,
                        optimizer,
                        scheduler,
                        grad_scaler,
                        data_batch,
                        iteration=iteration,
                        grad_accum_iter=grad_accum_iter,
                    )
                    self.callbacks.on_training_step_batch_end(
                        model, data_batch, output_batch, loss, iteration=iteration
                    )
                    # If the gradients are still being accumulated, continue to load the next training batch.
                    if grad_accum_iter != 0:
                        continue
                    # Do the following when an actual optimizer (update) step has been made.
                    iteration += 1
                    # Save checkpoint.
                    if iteration % self.config.checkpoint.save_iter == 0:
                        self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
                    self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
                    # Validation.
                    if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                        self.validate(model, dataloader_val, iteration=iteration)
                    # This iteration is successful; reset the timeout signal.
                    signal.alarm(self.config.trainer.timeout_period)
                    if torch_profiler:
                        torch_profiler.step()
                    if memory_profiler:
                        memory_profiler.step()
                if _end_training:
                    break
        log.success("Done with training.")
        if iteration % self.config.checkpoint.save_iter != 0:
            self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        distributed.barrier()
        self.callbacks.on_app_end()

    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
        """The training step.

        Args:
            model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or the bare
                module, depending on whether distributed training is enabled or not.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.
            grad_accum_iter (int): Number of gradient accumulation iterations.

        Returns:
            output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
            loss (torch.Tensor): The total loss of the training data batch.
            grad_accum_iter (int): The updated number of accumulated gradient iterations.
        """
        # Only let DDP sync gradient at the last iteration of the gradient accumulation window
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
            self.callbacks.on_before_forward(iteration=iteration)
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_after_forward(iteration=iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
                loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
                loss_scaled.backward()
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_after_backward()
                else:
                    model_ddp.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)
        grad_accum_iter += 1
        if grad_accum_iter == self.config.trainer.grad_accum_iter:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                else:
                    model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter

    @torch.no_grad()
    def validate(self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Args:
            model (ImaginaireModel): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        log.info(f"Validating at iteration {iteration}...")
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
        # Evaluate on the full validation set.
        with model.pipe.ema_scope(context="Validation", is_cpu=False):
            for val_iter, data_batch in enumerate(dataloader_val):
                if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)
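
End to end, the intended call pattern looks roughly like the sketch below. In this codebase the config is normally an attrs-based Config object (see imaginaire/config.py in this commit), so loading it straight from a Python file via LazyConfig, and the config.model / config.dataloader_train / config.dataloader_val keys, are assumptions for illustration only:

from imaginaire.lazy_config import LazyConfig, instantiate
from imaginaire.trainer import ImaginaireTrainer

config = LazyConfig.load("configs/my_experiment.py")  # hypothetical config file
trainer = ImaginaireTrainer(config)  # sets up distributed state, logging, checkpointing, callbacks

model = instantiate(config.model)  # assumed: a LazyCall of an ImaginaireModel subclass
dataloader_train = instantiate(config.dataloader_train)  # assumed config keys
dataloader_val = instantiate(config.dataloader_val)

trainer.train(model, dataloader_train, dataloader_val)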
imaginaire/utils/__init__.py
ADDED
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
imaginaire/utils/callback.py
ADDED
@@ -0,0 +1,518 @@
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
import warnings
|
| 20 |
+
from collections.abc import Callable
|
| 21 |
+
from typing import TYPE_CHECKING, Any
|
| 22 |
+
|
| 23 |
+
import omegaconf
|
| 24 |
+
import torch
|
| 25 |
+
import torch.utils.data
|
| 26 |
+
import tqdm
|
| 27 |
+
|
| 28 |
+
from imaginaire.lazy_config import instantiate
|
| 29 |
+
from imaginaire.utils import distributed, log
|
| 30 |
+
from imaginaire.utils.misc import get_local_tensor_if_DTensor
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from megatron.core import parallel_state
|
| 34 |
+
except ImportError:
|
| 35 |
+
parallel_state = None
|
| 36 |
+
print("Megatron-core is not installed.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if TYPE_CHECKING:
|
| 40 |
+
from imaginaire.config import Config
|
| 41 |
+
from imaginaire.model import ImaginaireModel
|
| 42 |
+
from imaginaire.trainer import ImaginaireTrainer
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CallBackGroup:
|
| 46 |
+
"""A class for hosting a collection of callback objects.
|
| 47 |
+
|
| 48 |
+
It is used to execute callback functions of multiple callback objects with the same method name.
|
| 49 |
+
When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
|
| 50 |
+
self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.
|
| 51 |
+
|
| 52 |
+
Attributes:
|
| 53 |
+
_callbacks (list[Callback]): List of callback objects.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, config: Config, trainer: ImaginaireTrainer) -> None:
|
| 57 |
+
"""Initializes the list of callback objects.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
config (Config): The config object for the Imaginaire codebase.
|
| 61 |
+
trainer (ImaginaireTrainer): The main trainer.
|
| 62 |
+
"""
|
| 63 |
+
self._callbacks = []
|
| 64 |
+
callback_configs = config.trainer.callbacks
|
| 65 |
+
if callback_configs:
|
| 66 |
+
if isinstance(callback_configs, (list, omegaconf.listconfig.ListConfig)):
|
| 67 |
+
warnings.warn(
|
| 68 |
+
"The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
|
| 69 |
+
"Please update your code",
|
| 70 |
+
DeprecationWarning,
|
| 71 |
+
stacklevel=2,
|
| 72 |
+
)
|
| 73 |
+
callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
|
| 74 |
+
for callback_name, current_callback_cfg in callback_configs.items():
|
| 75 |
+
if "_target_" not in current_callback_cfg:
|
| 76 |
+
log.critical(
|
| 77 |
+
f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}"
|
| 78 |
+
)
|
| 79 |
+
continue
|
| 80 |
+
log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}")
|
| 81 |
+
_callback = instantiate(current_callback_cfg)
|
| 82 |
+
assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
|
| 83 |
+
_callback.config = config
|
| 84 |
+
_callback.trainer = trainer
|
| 85 |
+
self._callbacks.append(_callback)
|
| 86 |
+
|
| 87 |
+
def __getattr__(self, method_name: str) -> Callable:
|
| 88 |
+
"""Loops through the callback objects to call the corresponding callback function.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
method_name (str): Callback method name.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def multi_callback_wrapper(*args, **kwargs) -> None:
|
| 95 |
+
for callback in self._callbacks:
|
| 96 |
+
assert hasattr(callback, method_name)
|
| 97 |
+
method = getattr(callback, method_name)
|
| 98 |
+
assert callable(method)
|
| 99 |
+
_ = method(*args, **kwargs)
|
| 100 |
+
|
| 101 |
+
return multi_callback_wrapper
|
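# --- Illustrative sketch (not part of the original file): CallBackGroup defines no
# on_train_start of its own, so __getattr__ returns a wrapper that fans the call out
# to every registered callback. LoggingCallback is hypothetical and duck-typed here
# only because the real Callback base class is defined just below.
class LoggingCallback:
    def on_train_start(self, model, iteration=0):
        print(f"train start at iteration {iteration}")

group = CallBackGroup.__new__(CallBackGroup)  # bypass the config-driven __init__ for the sketch
group._callbacks = [LoggingCallback(), LoggingCallback()]
group.on_train_start(model=None, iteration=0)  # dispatched twice, once per callback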
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Callback:
|
| 105 |
+
"""The base class for all callbacks.
|
| 106 |
+
|
| 107 |
+
All callbacks should inherit from this class and adhere to the established method names and signatures.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
def __init__(self, config: Config | None = None, trainer: ImaginaireTrainer | None = None):
|
| 111 |
+
"""Initializes a Callback object.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
config (Optional[Config]): The configuration object for the Imaginaire codebase, if available.
|
| 115 |
+
trainer (Optional[ImaginaireTrainer]): The main trainer handling the training loop, if available.
|
| 116 |
+
|
| 117 |
+
Notes:
|
| 118 |
+
The config and trainer parameters are optional to maintain backward compatibility.
|
| 119 |
+
In future releases, these parameters will be removed. Upon using these parameters, a deprecation
|
| 120 |
+
warning will be issued.
|
| 121 |
+
|
| 122 |
+
"""
|
| 123 |
+
if config is not None or trainer is not None:
|
| 124 |
+
warnings.warn(
|
| 125 |
+
"The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. "
|
| 126 |
+
"Please update your code to create Callback instances without these parameters.",
|
| 127 |
+
DeprecationWarning,
|
| 128 |
+
stacklevel=2,
|
| 129 |
+
)
|
| 130 |
+
del config, trainer
|
| 131 |
+
|
| 132 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 133 |
+
pass
|
| 134 |
+
|
| 135 |
+
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 136 |
+
"""
|
| 137 |
+
Called before the training step, for each batch. This is paired with on_training_step_end() but note that
|
| 138 |
+
when using gradient accumulation, while on_training_step_end() is only called when the optimizer is updated,
|
| 139 |
+
this function is called for every batch.
|
| 140 |
+
Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
|
| 141 |
+
for every batch, albeit with the same iteration number.
|
| 142 |
+
"""
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
def on_training_step_batch_start(
|
| 146 |
+
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 147 |
+
) -> None:
|
| 148 |
+
"""
|
| 149 |
+
Called before the training step, for each batch, similarly to on_training_step_start(). This function is paired with
|
| 150 |
+
on_training_step_batch_end(), and both functions are called for every batch even when using gradient accumulation.
|
| 151 |
+
Note that the iteration is only updated when the optimizer is updated, and therefore it may be the same for multiple invocations.
|
| 152 |
+
"""
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
def on_before_forward(self, iteration: int = 0) -> None:
|
| 156 |
+
pass
|
| 157 |
+
|
| 158 |
+
def on_after_forward(self, iteration: int = 0) -> None:
|
| 159 |
+
pass
|
| 160 |
+
|
| 161 |
+
def on_before_backward(
|
| 162 |
+
self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
|
| 163 |
+
) -> None:
|
| 164 |
+
pass
|
| 165 |
+
|
| 166 |
+
def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
def on_before_dataloading(self, iteration: int = 0) -> None:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
def on_after_dataloading(self, iteration: int = 0) -> None:
|
| 173 |
+
pass
|
| 174 |
+
|
| 175 |
+
def on_optimizer_init_start(self) -> None:
|
| 176 |
+
pass
|
| 177 |
+
|
| 178 |
+
def on_optimizer_init_end(self) -> None:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
def on_before_optimizer_step(
|
| 182 |
+
self,
|
| 183 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 184 |
+
optimizer: torch.optim.Optimizer,
|
| 185 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 186 |
+
grad_scaler: torch.amp.GradScaler,
|
| 187 |
+
iteration: int = 0,
|
| 188 |
+
) -> None:
|
| 189 |
+
pass
|
| 190 |
+
|
| 191 |
+
def on_before_zero_grad(
|
| 192 |
+
self,
|
| 193 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 194 |
+
optimizer: torch.optim.Optimizer,
|
| 195 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 196 |
+
iteration: int = 0,
|
| 197 |
+
) -> None:
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
def on_training_step_batch_end(
|
| 201 |
+
self,
|
| 202 |
+
model: ImaginaireModel,
|
| 203 |
+
data_batch: dict[str, torch.Tensor],
|
| 204 |
+
output_batch: dict[str, torch.Tensor],
|
| 205 |
+
loss: torch.Tensor,
|
| 206 |
+
iteration: int = 0,
|
| 207 |
+
) -> None:
|
| 208 |
+
"""
|
| 209 |
+
Called at the end of a training step for every batch even when using gradient accumulation.
|
| 210 |
+
This is paired with on_training_step_batch_start(). Note that the iteration is only updated when the optimizer is updated,
|
| 211 |
+
and therefore it may be the same for multiple batches.
|
| 212 |
+
"""
|
| 213 |
+
pass
|
| 214 |
+
|
| 215 |
+
def on_training_step_end(
|
| 216 |
+
self,
|
| 217 |
+
model: ImaginaireModel,
|
| 218 |
+
data_batch: dict[str, torch.Tensor],
|
| 219 |
+
output_batch: dict[str, torch.Tensor],
|
| 220 |
+
loss: torch.Tensor,
|
| 221 |
+
iteration: int = 0,
|
| 222 |
+
) -> None:
|
| 223 |
+
"""
|
| 224 |
+
Called at the end of a training step, but note that when using gradient accumulation, this is only called
|
| 225 |
+
when the optimizer is updated and the iteration is incremented, whereas on_training_step_start is called for every batch.
|
| 226 |
+
Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
|
| 227 |
+
for every batch.
|
| 228 |
+
"""
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
def on_validation_start(
|
| 232 |
+
self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
|
| 233 |
+
) -> None:
|
| 234 |
+
pass
|
| 235 |
+
|
| 236 |
+
def on_validation_step_start(
|
| 237 |
+
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 238 |
+
) -> None:
|
| 239 |
+
pass
|
| 240 |
+
|
| 241 |
+
def on_validation_step_end(
|
| 242 |
+
self,
|
| 243 |
+
model: ImaginaireModel,
|
| 244 |
+
data_batch: dict[str, torch.Tensor],
|
| 245 |
+
output_batch: dict[str, torch.Tensor],
|
| 246 |
+
loss: torch.Tensor,
|
| 247 |
+
iteration: int = 0,
|
| 248 |
+
) -> None:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 252 |
+
pass
|
| 253 |
+
|
| 254 |
+
def on_load_checkpoint_start(self, model: ImaginaireModel) -> None:
|
| 255 |
+
pass
|
| 256 |
+
|
| 257 |
+
def on_load_checkpoint_end(
|
| 258 |
+
self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: str | None = None
|
| 259 |
+
) -> None:
|
| 260 |
+
pass
|
| 261 |
+
|
| 262 |
+
def on_load_checkpoint(self, model: ImaginaireModel, state_dict: dict[str, Any]) -> None:
|
| 263 |
+
pass
|
| 264 |
+
|
| 265 |
+
def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 266 |
+
"""
|
| 267 |
+
Called when checkpoint saving is about to start.
|
| 268 |
+
"""
|
| 269 |
+
pass
|
| 270 |
+
|
| 271 |
+
def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 272 |
+
"""
|
| 273 |
+
Called when the synchronous part of checkpointing is finished. This function can be used
|
| 274 |
+
along with on_save_checkpoint_start() to measure the exposed (synchronous) checkpoint time.
|
| 275 |
+
Note that for asynchronous checkpointing, the save may still be ongoing, so this function
|
| 276 |
+
does not mean the checkpoint is finished in the asynchronous case; use on_save_checkpoint_success()
|
| 277 |
+
for that.
|
| 278 |
+
"""
|
| 279 |
+
pass
|
| 280 |
+
|
| 281 |
+
def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None:
|
| 282 |
+
"""
|
| 283 |
+
Called when checkpoint saving is fully finished, and succeeded. Not called if checkpoint failed.
|
| 284 |
+
For synchronous checkpoint, it is called at the same time as on_save_checkpoint_end(), but for asynchronous
|
| 285 |
+
checkpoint, it is called after the asynchronous part has also finished. For checkpointers with out-of-process
|
| 286 |
+
checkpointing, this function is called as soon as the notification is received from the checkpointer process,
|
| 287 |
+
which may not be immediately after the checkpoint has completed but later on. Therefore, if you need to measure
|
| 288 |
+
the full checkpoint duration for the asynchronous part, use the elapsed_time parameter; do not measure it directly,
|
| 289 |
+
as this would be a significant overestimate.
|
| 290 |
+
"""
|
| 291 |
+
pass
|
| 292 |
+
|
| 293 |
+
def on_save_checkpoint(self, model: ImaginaireModel, state_dict: dict[str, Any]) -> None:
|
| 294 |
+
pass
|
| 295 |
+
|
| 296 |
+
def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 297 |
+
pass
|
| 298 |
+
|
| 299 |
+
def on_app_end(self) -> None:
|
| 300 |
+
pass
|
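# --- Illustrative sketch (not part of the original file): the hook ordering a trainer
# is expected to follow for one optimizer update with gradient accumulation over
# several micro-batches. Per the docstrings above, the per-batch hooks fire for every
# micro-batch, while on_training_step_end fires once per optimizer update; `iteration`
# stays constant across all micro-batches of the same update.
def run_one_update_sketch(callbacks, model, micro_batches, iteration):
    for data in micro_batches:
        callbacks.on_training_step_start(model, data, iteration)
        callbacks.on_training_step_batch_start(model, data, iteration)
        output, loss = {}, torch.zeros(())  # stand-ins for the real forward/backward
        callbacks.on_training_step_batch_end(model, data, output, loss, iteration)
    # Fires once, when the optimizer actually steps; the trainer then increments `iteration`.
    callbacks.on_training_step_end(model, micro_batches[-1], output, loss, iteration)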
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class EMAModelCallback(Callback):
|
| 304 |
+
"""The callback class for tracking EMA model weights."""
|
| 305 |
+
|
| 306 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 307 |
+
# Set up the EMA model weight tracker.
|
| 308 |
+
if model.config.ema.enabled:
|
| 309 |
+
assert hasattr(model, "ema"), "EMA should be initialized from ImaginaireModel"
|
| 310 |
+
# EMA model must be kept in FP32 precision.
|
| 311 |
+
model.ema = model.ema.to(dtype=torch.float32)
|
| 312 |
+
else:
|
| 313 |
+
assert not hasattr(model, "ema"), "There should be no EMA initialized."
|
| 314 |
+
|
| 315 |
+
def on_training_step_end(
|
| 316 |
+
self,
|
| 317 |
+
model: ImaginaireModel,
|
| 318 |
+
data_batch: dict[str, torch.Tensor],
|
| 319 |
+
output_batch: dict[str, torch.Tensor],
|
| 320 |
+
loss: torch.Tensor,
|
| 321 |
+
iteration: int = 0,
|
| 322 |
+
) -> None:
|
| 323 |
+
# Update the EMA model with the new regular weights.
|
| 324 |
+
if model.config.ema.enabled:
|
| 325 |
+
model.ema.update_average(model, iteration)
|
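# --- Background sketch (illustrative; the real update lives in imaginaire.utils.ema and
# is configured via model.config.ema). update_average above amounts to an exponential
# moving average over the regular weights; the beta value here is made up:
@torch.no_grad()
def ema_update_sketch(ema_params, params, beta=0.9999):
    for e, p in zip(ema_params, params, strict=False):
        e.mul_(beta).add_(p.float(), alpha=1.0 - beta)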
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class ProgressBarCallback(Callback):
|
| 329 |
+
"""The callback class for visualizing the training/validation progress bar in the console."""
|
| 330 |
+
|
| 331 |
+
@distributed.rank0_only
|
| 332 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 333 |
+
self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
|
| 334 |
+
|
| 335 |
+
@distributed.rank0_only
|
| 336 |
+
def on_training_step_end(
|
| 337 |
+
self,
|
| 338 |
+
model: ImaginaireModel,
|
| 339 |
+
data_batch: dict[str, torch.Tensor],
|
| 340 |
+
output_batch: dict[str, torch.Tensor],
|
| 341 |
+
loss: torch.Tensor,
|
| 342 |
+
iteration: int = 0,
|
| 343 |
+
) -> None:
|
| 344 |
+
self.train_pbar.update()
|
| 345 |
+
|
| 346 |
+
@distributed.rank0_only
|
| 347 |
+
def on_validation_start(
|
| 348 |
+
self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
|
| 349 |
+
) -> None:
|
| 350 |
+
if self.config.trainer.max_val_iter is not None:
|
| 351 |
+
num_iter = self.config.trainer.max_val_iter
|
| 352 |
+
else:
|
| 353 |
+
num_iter = len(dataloader_val)
|
| 354 |
+
assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}"
|
| 355 |
+
self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False)
|
| 356 |
+
|
| 357 |
+
@distributed.rank0_only
|
| 358 |
+
def on_validation_step_end(
|
| 359 |
+
self,
|
| 360 |
+
model: ImaginaireModel,
|
| 361 |
+
data_batch: dict[str, torch.Tensor],
|
| 362 |
+
output_batch: dict[str, torch.Tensor],
|
| 363 |
+
loss: torch.Tensor,
|
| 364 |
+
iteration: int = 0,
|
| 365 |
+
) -> None:
|
| 366 |
+
self.val_pbar.update()
|
| 367 |
+
|
| 368 |
+
@distributed.rank0_only
|
| 369 |
+
def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 370 |
+
self.val_pbar.close()
|
| 371 |
+
|
| 372 |
+
@distributed.rank0_only
|
| 373 |
+
def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 374 |
+
self.trainer.checkpointer.finalize()
|
| 375 |
+
self.train_pbar.close()
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class IterationLoggerCallback(Callback):
|
| 379 |
+
"""The callback class for visualizing the training/validation progress bar in the console."""
|
| 380 |
+
|
| 381 |
+
@distributed.rank0_only
|
| 382 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 384 |
+
self.start_iteration_time = time.time()
|
| 385 |
+
self.elapsed_iteration_time = 0
|
| 386 |
+
|
| 387 |
+
@distributed.rank0_only
|
| 388 |
+
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 389 |
+
self.start_iteration_time = time.time()
|
| 390 |
+
|
| 391 |
+
@distributed.rank0_only
|
| 392 |
+
def on_training_step_end(
|
| 393 |
+
self,
|
| 394 |
+
model: ImaginaireModel,
|
| 395 |
+
data_batch: dict[str, torch.Tensor],
|
| 396 |
+
output_batch: dict[str, torch.Tensor],
|
| 397 |
+
loss: torch.Tensor,
|
| 398 |
+
iteration: int = 0,
|
| 399 |
+
) -> None:
|
| 400 |
+
self.elapsed_iteration_time += time.time() - self.start_iteration_time
|
| 401 |
+
|
| 402 |
+
if iteration % self.config.trainer.logging_iter == 0:
|
| 403 |
+
avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter
|
| 404 |
+
log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}")
|
| 405 |
+
|
| 406 |
+
self.elapsed_iteration_time = 0
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class LowPrecisionCallback(Callback):
|
| 410 |
+
"""The callback class handling low precision training
|
| 411 |
+
|
| 412 |
+
A config option with a non-primitive type is difficult to override,
|
| 413 |
+
so the callback gets the precision from model.precision instead.
|
| 414 |
+
It is also automatically disabled when training in fp32.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int):
|
| 418 |
+
self.update_iter = update_iter
|
| 419 |
+
|
| 420 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 421 |
+
assert model.precision in [
|
| 422 |
+
torch.bfloat16,
|
| 423 |
+
torch.float16,
|
| 424 |
+
torch.half,
|
| 425 |
+
], "LowPrecisionCallback must use a low precision dtype."
|
| 426 |
+
self.precision_type = model.precision
|
| 427 |
+
|
| 428 |
+
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 429 |
+
for k, v in data.items():
|
| 430 |
+
if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
|
| 431 |
+
data[k] = v.to(dtype=self.precision_type)
|
| 432 |
+
|
| 433 |
+
def on_validation_step_start(
|
| 434 |
+
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 435 |
+
) -> None:
|
| 436 |
+
for k, v in data.items():
|
| 437 |
+
if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
|
| 438 |
+
data[k] = v.to(dtype=self.precision_type)
|
| 439 |
+
|
| 440 |
+
def on_before_zero_grad(
|
| 441 |
+
self,
|
| 442 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 443 |
+
optimizer: torch.optim.Optimizer,
|
| 444 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 445 |
+
iteration: int = 0,
|
| 446 |
+
) -> None:
|
| 447 |
+
if iteration % self.update_iter == 0:
|
| 448 |
+
if getattr(optimizer, "master_weights", False):
|
| 449 |
+
params, master_params = [], []
|
| 450 |
+
for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master, strict=False):
|
| 451 |
+
for p, p_master in zip(group["params"], group_master["params"], strict=False):
|
| 452 |
+
params.append(get_local_tensor_if_DTensor(p.data))
|
| 453 |
+
master_params.append(p_master.data)
|
| 454 |
+
torch._foreach_copy_(params, master_params)
|
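# --- Quick check (illustrative; keys and shapes are made up) of what the cast in
# on_training_step_start does to a data batch: floating-point tensors are cast to the
# low-precision dtype while integer tensors are left untouched.
cb = LowPrecisionCallback(config=None, trainer=None, update_iter=1)
cb.precision_type = torch.bfloat16  # normally set in on_train_start from model.precision
data = {"video": torch.randn(1, 3, 8, 64, 64), "labels": torch.zeros(1, dtype=torch.long)}
cb.on_training_step_start(model=None, data=data)
assert data["video"].dtype == torch.bfloat16 and data["labels"].dtype == torch.long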
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class NVTXCallback(Callback):
|
| 458 |
+
"""The callback for creating NVTX ranges"""
|
| 459 |
+
|
| 460 |
+
def __init__(
|
| 461 |
+
self,
|
| 462 |
+
synchronize: bool = False,
|
| 463 |
+
config: Config | None = None,
|
| 464 |
+
trainer: ImaginaireTrainer | None = None,
|
| 465 |
+
):
|
| 466 |
+
super().__init__(config, trainer)
|
| 467 |
+
self.synchronize = synchronize
|
| 468 |
+
|
| 469 |
+
def on_before_forward(self, iteration: int = 0) -> None:
|
| 470 |
+
if self.synchronize:
|
| 471 |
+
torch.cuda.synchronize()
|
| 472 |
+
torch.cuda.nvtx.range_push("forward")
|
| 473 |
+
|
| 474 |
+
def on_after_forward(self, iteration: int = 0) -> None:
|
| 475 |
+
if self.synchronize:
|
| 476 |
+
torch.cuda.synchronize()
|
| 477 |
+
torch.cuda.nvtx.range_pop()
|
| 478 |
+
|
| 479 |
+
def on_before_backward(
|
| 480 |
+
self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
|
| 481 |
+
) -> None:
|
| 482 |
+
if self.synchronize:
|
| 483 |
+
torch.cuda.synchronize()
|
| 484 |
+
torch.cuda.nvtx.range_push("backward")
|
| 485 |
+
|
| 486 |
+
def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
|
| 487 |
+
if self.synchronize:
|
| 488 |
+
torch.cuda.synchronize()
|
| 489 |
+
torch.cuda.nvtx.range_pop()
|
| 490 |
+
|
| 491 |
+
def on_before_optimizer_step(
|
| 492 |
+
self,
|
| 493 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 494 |
+
optimizer: torch.optim.Optimizer,
|
| 495 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 496 |
+
grad_scaler: torch.amp.GradScaler,
|
| 497 |
+
iteration: int = 0,
|
| 498 |
+
) -> None:
|
| 499 |
+
if self.synchronize:
|
| 500 |
+
torch.cuda.synchronize()
|
| 501 |
+
torch.cuda.nvtx.range_push("optimizer_step")
|
| 502 |
+
|
| 503 |
+
def on_before_zero_grad(
|
| 504 |
+
self,
|
| 505 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 506 |
+
optimizer: torch.optim.Optimizer,
|
| 507 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 508 |
+
iteration: int = 0,
|
| 509 |
+
) -> None:
|
| 510 |
+
if self.synchronize:
|
| 511 |
+
torch.cuda.synchronize()
|
| 512 |
+
torch.cuda.nvtx.range_pop()
|
| 513 |
+
|
| 514 |
+
def on_before_dataloading(self, iteration: int = 0) -> None:
|
| 515 |
+
torch.cuda.nvtx.range_push("dataloading")
|
| 516 |
+
|
| 517 |
+
def on_after_dataloading(self, iteration: int = 0) -> None:
|
| 518 |
+
torch.cuda.nvtx.range_pop()
|
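# --- Usage note (illustrative): the same NVTX API the callback wraps can bracket any
# custom region; profiling with Nsight Systems (e.g. `nsys profile -t cuda,nvtx ...`)
# shows the "dataloading"/"forward"/"backward"/"optimizer_step" ranges on the timeline.
# Setting synchronize=True trades throughput for range boundaries that reflect the
# actual GPU work rather than just the host-side launches.
torch.cuda.nvtx.range_push("custom_region")
...  # work to profile
torch.cuda.nvtx.range_pop()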
imaginaire/utils/checkpointer.py
ADDED
|
@@ -0,0 +1,282 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import threading
|
| 20 |
+
from typing import TYPE_CHECKING, NamedTuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
from imaginaire.model import ImaginaireModel
|
| 27 |
+
from imaginaire.utils import callback, distributed, log, misc
|
| 28 |
+
from imaginaire.utils.parallelism import ModelWrapper
|
| 29 |
+
|
| 30 |
+
if TYPE_CHECKING:
|
| 31 |
+
from imaginaire.config import CheckpointConfig, JobConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Checkpointer:
|
| 35 |
+
"""The checkpointer class. Supports checkpoint saving/loading to local disk."""
|
| 36 |
+
|
| 37 |
+
def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
|
| 38 |
+
"""Constructor of the checkpointer.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
config_checkpoint (CheckpointConfig): The config object for the checkpointer.
config_job (JobConfig): The job config, used to derive the local checkpoint directory.
callbacks (callback.CallBackGroup): The callback group invoked around checkpoint events.
|
| 42 |
+
"""
|
| 43 |
+
# Set the callback functions.
|
| 44 |
+
self.callbacks = callbacks
|
| 45 |
+
self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
|
| 46 |
+
self.strict_resume = config_checkpoint.strict_resume
|
| 47 |
+
self.load_path = config_checkpoint.load_path or None
|
| 48 |
+
self.load_training_state = config_checkpoint.load_training_state
|
| 49 |
+
self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
|
| 50 |
+
self.save_thread = None
|
| 51 |
+
|
| 52 |
+
def save(
|
| 53 |
+
self,
|
| 54 |
+
model: ImaginaireModel,
|
| 55 |
+
optimizer: torch.optim.Optimizer,
|
| 56 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 57 |
+
grad_scaler: torch.amp.GradScaler,
|
| 58 |
+
iteration: int,
|
| 59 |
+
) -> None:
|
| 60 |
+
"""Save network weights, optimizer parameters, scheduler parameters to a checkpoint.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
model (ImaginaireModel): The PyTorch model.
|
| 64 |
+
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 65 |
+
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 66 |
+
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
|
| 67 |
+
iteration (int): Current iteration number.
|
| 68 |
+
"""
|
| 69 |
+
self.callbacks.on_save_checkpoint_start(model, iteration)
|
| 70 |
+
|
| 71 |
+
checkpoint_file = f"iter_{iteration:09}.pt"
|
| 72 |
+
|
| 73 |
+
if distributed.get_rank() == 0:
|
| 74 |
+
state_dict = dict(
|
| 75 |
+
model=model.state_dict(),
|
| 76 |
+
optimizer=optimizer.state_dict(),
|
| 77 |
+
scheduler=scheduler.state_dict(),
|
| 78 |
+
grad_scaler=grad_scaler.state_dict(),
|
| 79 |
+
iteration=iteration,
|
| 80 |
+
)
|
| 81 |
+
state_dict = misc.to(state_dict, device="cpu")
|
| 82 |
+
self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
|
| 83 |
+
# Wait for previous saver thread to end.
|
| 84 |
+
if self.save_thread:
|
| 85 |
+
self.save_thread.join()
|
| 86 |
+
# Run the checkpoint saver in a separate thread.
|
| 87 |
+
self.save_thread = threading.Thread(
|
| 88 |
+
target=self._save_worker_local,
|
| 89 |
+
daemon=False,
|
| 90 |
+
args=(state_dict, checkpoint_file, distributed.get_rank()),
|
| 91 |
+
)
|
| 92 |
+
self.save_thread.start()
|
| 93 |
+
|
| 94 |
+
# Note: checkpoints are saved on a separate thread, so timings measured around this callback are not accurate.
|
| 95 |
+
# Please check the logs from on_save_checkpoint_success() for accurate timing.
|
| 96 |
+
self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)
|
| 97 |
+
|
| 98 |
+
@misc.timer("checkpoint saving (local)")
|
| 99 |
+
def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None:
|
| 100 |
+
"""Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
|
| 104 |
+
checkpoint_file (str): The file name of the model checkpoint.
|
| 105 |
+
rank (int): GPU device (default: 0).
|
| 106 |
+
"""
|
| 107 |
+
checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
|
| 108 |
+
os.makedirs(self.checkpoint_dir_local, exist_ok=True)
|
| 109 |
+
try:
|
| 110 |
+
torch.save(state_dict, checkpoint_path)
|
| 111 |
+
if rank == 0:
|
| 112 |
+
self._write_latest_checkpoint_file(checkpoint_file)
|
| 113 |
+
log.success(f"Saved checkpoint (local): {checkpoint_path}")
|
| 114 |
+
iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
|
| 115 |
+
self.callbacks.on_save_checkpoint_success(iteration=iteration)
|
| 116 |
+
except Exception as e:
|
| 117 |
+
log.exception(f"Checkpoint failed to save (local): {e}")
|
| 118 |
+
|
| 119 |
+
@misc.timer("checkpoint loading")
|
| 120 |
+
def load(
|
| 121 |
+
self,
|
| 122 |
+
model: ImaginaireModel,
|
| 123 |
+
optimizer: torch.optim.Optimizer | None = None,
|
| 124 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
|
| 125 |
+
grad_scaler: torch.amp.GradScaler | None = None,
|
| 126 |
+
) -> int:
|
| 127 |
+
"""Load network weights and optimizer states from a checkpoint in a single process.
|
| 128 |
+
|
| 129 |
+
The priority of the checkpoint loading logic is:
|
| 130 |
+
1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
|
| 131 |
+
2. If no latest checkpoint is found, load the model weights specified by config_checkpoint.load_path.
|
| 132 |
+
- This is typically used for inference mode.
|
| 133 |
+
- If config_checkpoint.load_training_state is True, also load the optimizer, scheduler, and grad scaler states.
|
| 134 |
+
3. If none of the above, randomly initialize the model parameters and train from scratch.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
model (ImaginaireModel): The PyTorch model.
|
| 138 |
+
optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
|
| 139 |
+
scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
|
| 140 |
+
grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
iteration (int): the iteration number to start/resume from.
|
| 144 |
+
"""
|
| 145 |
+
self.callbacks.on_load_checkpoint_start(model)
|
| 146 |
+
|
| 147 |
+
latest_checkpoint_file = self._read_latest_checkpoint_file()
|
| 148 |
+
if latest_checkpoint_file is not None:
|
| 149 |
+
# 1. Resume training from latest_checkpoint.txt under the same name.
|
| 150 |
+
checkpoint_dir = self.checkpoint_dir_local
|
| 151 |
+
checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
|
| 152 |
+
resume = True
|
| 153 |
+
only_resume_scheduler = True
|
| 154 |
+
else:
|
| 155 |
+
if self.load_path:
|
| 156 |
+
# 2. Load the module weights specified by config_checkpoint.path.
|
| 157 |
+
checkpoint_path = self.load_path
|
| 158 |
+
resume = self.load_training_state
|
| 159 |
+
only_resume_scheduler = self.only_load_scheduler_state
|
| 160 |
+
else:
|
| 161 |
+
# 3. Randomly initialize the model parameters and train from scratch.
|
| 162 |
+
checkpoint_path = None
|
| 163 |
+
resume = False
|
| 164 |
+
only_resume_scheduler = False
|
| 165 |
+
# Load checkpoint.
|
| 166 |
+
if checkpoint_path is not None:
|
| 167 |
+
self._check_checkpoint_exists(checkpoint_path)
|
| 168 |
+
log.info(f"Loading checkpoint (local): {checkpoint_path}")
|
| 169 |
+
state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
|
| 170 |
+
log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
|
| 171 |
+
self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
|
| 172 |
+
# Load the state dicts.
|
| 173 |
+
log.info("- Loading the model...")
|
| 174 |
+
model.load_state_dict(state_dict["model"], strict=self.strict_resume)
|
| 175 |
+
if resume or only_resume_scheduler:
|
| 176 |
+
iteration = state_dict["iteration"]
|
| 177 |
+
assert scheduler
|
| 178 |
+
log.info("- Loading the scheduler...")
|
| 179 |
+
scheduler.load_state_dict(state_dict["scheduler"])
|
| 180 |
+
scheduler.last_epoch = iteration
|
| 181 |
+
else:
|
| 182 |
+
iteration = 0
|
| 183 |
+
if resume:
|
| 184 |
+
assert optimizer
|
| 185 |
+
log.info("- Loading the optimizer...")
|
| 186 |
+
optimizer.load_state_dict(state_dict["optimizer"])
|
| 187 |
+
log.info("- Loading the gradient scaler...")
|
| 188 |
+
grad_scaler.load_state_dict(state_dict["grad_scaler"])
|
| 189 |
+
log.success(f"Done with loading the checkpoint (iteration {iteration}).")
|
| 190 |
+
else:
|
| 191 |
+
log.success("Done with loading the checkpoint.")
|
| 192 |
+
else:
|
| 193 |
+
# Checkpoint not found and not specified. We will train everything from scratch.
|
| 194 |
+
iteration = 0
|
| 195 |
+
log.info("Training from scratch.")
|
| 196 |
+
torch.cuda.empty_cache()
|
| 197 |
+
|
| 198 |
+
self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)
|
| 199 |
+
|
| 200 |
+
return iteration
|
| 201 |
+
|
| 202 |
+
def _read_latest_checkpoint_file(self) -> str | None:
|
| 203 |
+
"""Get the file name of the latest saved checkpoint. If it doesn't exist, return None.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
checkpoint_file (str | None): file name of the latest saved checkpoint.
|
| 207 |
+
"""
|
| 208 |
+
checkpoint_file = None
|
| 209 |
+
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
|
| 210 |
+
if os.path.isfile(latest_path):
|
| 211 |
+
with open(latest_path) as f:
    checkpoint_file = f.read().strip()
|
| 212 |
+
return checkpoint_file
|
| 213 |
+
|
| 214 |
+
def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
|
| 215 |
+
"""Track the file name of the latest saved checkpoint.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
checkpoint_file (str): file name of the latest saved checkpoint.
|
| 219 |
+
"""
|
| 220 |
+
content = f"{checkpoint_file}\n"
|
| 221 |
+
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
|
| 222 |
+
with open(latest_path, "w") as file:
|
| 223 |
+
file.write(content)
|
| 224 |
+
|
| 225 |
+
def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
|
| 226 |
+
"""If the file checkpoint_path does not exist, raise an error.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
checkpoint_path (str): full path to the checkpoint.
|
| 230 |
+
"""
|
| 231 |
+
if not os.path.exists(checkpoint_path):
|
| 232 |
+
raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
|
| 233 |
+
|
| 234 |
+
def finalize(self) -> None:
|
| 235 |
+
"""Finalize the checkpointer."""
|
| 236 |
+
if self.save_thread:
|
| 237 |
+
self.save_thread.join()
|
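# --- Minimal usage sketch (the config/model/optimizer wiring is assumed, not shown
# here; `save_every` is hypothetical): the trainer calls load() once at startup and
# save() periodically.
checkpointer = Checkpointer(config.checkpoint, config.job, callbacks)
start_iteration = checkpointer.load(model, optimizer, scheduler, grad_scaler)
for iteration in range(start_iteration, config.trainer.max_iter):
    ...  # one training step
    if iteration % save_every == 0:
        checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration)
checkpointer.finalize()  # join the background save thread before exiting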
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class _IncompatibleKeys(
|
| 241 |
+
NamedTuple(
|
| 242 |
+
"IncompatibleKeys",
|
| 243 |
+
[
|
| 244 |
+
("missing_keys", list[str]),
|
| 245 |
+
("unexpected_keys", list[str]),
|
| 246 |
+
("incorrect_shapes", list[tuple[str, tuple[int], tuple[int]]]),
|
| 247 |
+
],
|
| 248 |
+
)
|
| 249 |
+
):
|
| 250 |
+
pass
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def load_checkpoint(
|
| 254 |
+
model_parts: list[nn.Module],
|
| 255 |
+
ckpt_dir,
|
| 256 |
+
model_ckpt_key_map: dict[str, str] = {}, # noqa: B006
|
| 257 |
+
):
|
| 258 |
+
log.info(f"Loading checkpoint from {ckpt_dir}.")
|
| 259 |
+
|
| 260 |
+
_model_wrapper = ModelWrapper(model_parts)
|
| 261 |
+
state_dict = _model_wrapper.state_dict()
|
| 262 |
+
# remove _extra_state
|
| 263 |
+
state_dict = {k: v for k, v in state_dict.items() if not k.endswith("._extra_state")}
|
| 264 |
+
|
| 265 |
+
# remap keys if needed
|
| 266 |
+
if model_ckpt_key_map:
|
| 267 |
+
for model_key, checkpoint_key in model_ckpt_key_map.items():
|
| 268 |
+
state_dict[checkpoint_key] = state_dict.pop(model_key)
|
| 269 |
+
log.info(f"Re-mapping {model_key} to {checkpoint_key}")
|
| 270 |
+
|
| 271 |
+
fs_storage_reader = dist.checkpoint.FileSystemReader(ckpt_dir)
|
| 272 |
+
dist.checkpoint.load(state_dict=state_dict, storage_reader=fs_storage_reader)
|
| 273 |
+
|
| 274 |
+
# inverse the remapping if needed
|
| 275 |
+
if model_ckpt_key_map:
|
| 276 |
+
for model_key, checkpoint_key in model_ckpt_key_map.items():
|
| 277 |
+
state_dict[model_key] = state_dict.pop(checkpoint_key)
|
| 278 |
+
log.info(f"Inverse re-mapping {checkpoint_key} to {model_key}")
|
| 279 |
+
|
| 280 |
+
_model_wrapper.load_state_dict(state_dict)
|
| 281 |
+
|
| 282 |
+
log.info(f"Finished loading checkpoint from {ckpt_dir}.")
|
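# --- Hypothetical remapping example (names and path are made up): the model calls a
# weight "net.head.weight" while the distributed checkpoint stored it as "head.weight";
# model_ckpt_key_map bridges the two during loading, and the mapping is inverted afterwards.
load_checkpoint(
    model_parts=[my_model],
    ckpt_dir="checkpoints/iter_000002000",
    model_ckpt_key_map={"net.head.weight": "head.weight"},
)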
imaginaire/utils/config_helper.py
ADDED
|
@@ -0,0 +1,201 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import importlib
|
| 17 |
+
import os
|
| 18 |
+
import pkgutil
|
| 19 |
+
import sys
|
| 20 |
+
from dataclasses import fields as dataclass_fields
|
| 21 |
+
from dataclasses import is_dataclass
|
| 22 |
+
from typing import Any
|
| 23 |
+
|
| 24 |
+
import attr
|
| 25 |
+
import attrs
|
| 26 |
+
from hydra import compose, initialize
|
| 27 |
+
from hydra.core.config_store import ConfigStore
|
| 28 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 29 |
+
from omegaconf import DictConfig, OmegaConf
|
| 30 |
+
|
| 31 |
+
from imaginaire.config import Config
|
| 32 |
+
from imaginaire.utils import log
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def is_attrs_or_dataclass(obj) -> bool:
|
| 36 |
+
"""
|
| 37 |
+
Check if the object is an instance of an attrs class or a dataclass.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
obj: The object to check.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
|
| 44 |
+
"""
|
| 45 |
+
return is_dataclass(obj) or attr.has(type(obj))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_fields(obj):
|
| 49 |
+
"""
|
| 50 |
+
Get the fields of an attrs class or a dataclass.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
list: A list of field names.
|
| 57 |
+
|
| 58 |
+
Raises:
|
| 59 |
+
ValueError: If the object is neither an attrs class nor a dataclass.
|
| 60 |
+
"""
|
| 61 |
+
if is_dataclass(obj):
|
| 62 |
+
return [field.name for field in dataclass_fields(obj)]
|
| 63 |
+
elif attr.has(type(obj)):
|
| 64 |
+
return [field.name for field in attr.fields(type(obj))]
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError("The object is neither an attrs class nor a dataclass.")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def override(config: Config, overrides: list[str] | None = None) -> Config:
|
| 70 |
+
"""
|
| 71 |
+
:param config: the instance of class `Config` (usually from `make_config`)
|
| 72 |
+
:param overrides: list of overrides for config
|
| 73 |
+
:return: the composed instance of class `Config`
|
| 74 |
+
"""
|
| 77 |
+
|
| 78 |
+
# Convert Config object to a DictConfig object
|
| 79 |
+
config_dict = attrs.asdict(config)
|
| 80 |
+
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
|
| 81 |
+
# Enforce "--" separator between the script arguments and overriding configs.
|
| 82 |
+
if overrides:
|
| 83 |
+
if overrides[0] != "--":
|
| 84 |
+
raise ValueError('Hydra config overrides must be separated with a "--" token.')
|
| 85 |
+
overrides = overrides[1:]
|
| 86 |
+
# Use Hydra to handle overrides
|
| 87 |
+
cs = ConfigStore.instance()
|
| 88 |
+
cs.store(name="config", node=config_omegaconf)
|
| 89 |
+
if not GlobalHydra().is_initialized():
|
| 90 |
+
with initialize(version_base=None):
|
| 91 |
+
config_omegaconf = compose(config_name="config", overrides=overrides)
|
| 92 |
+
OmegaConf.resolve(config_omegaconf)
|
| 93 |
+
else:
|
| 94 |
+
config_omegaconf = compose(config_name="config", overrides=overrides)
|
| 95 |
+
OmegaConf.resolve(config_omegaconf)
|
| 96 |
+
|
| 97 |
+
def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
|
| 98 |
+
"""
|
| 99 |
+
Construct an instance of the same type as ref_instance from the provided dictionary, or pass primitive/unstructured data through unchanged.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
ref_instance: The reference instance to determine the type and fields when needed
|
| 103 |
+
kwargs: A dictionary of keyword arguments for constructing the new instance, or primitive/unstructured data to pass through.
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Any: A new instance of the same type as ref_instance constructed from kwargs, or the input data passed through unchanged.
|
| 107 |
+
|
| 108 |
+
Raises:
|
| 109 |
+
AssertionError: If the fields do not match or if extra keys are found.
|
| 110 |
+
Exception: If there is an error constructing the new instance.
|
| 111 |
+
"""
|
| 112 |
+
is_type = is_attrs_or_dataclass(ref_instance)
|
| 113 |
+
if not is_type:
|
| 114 |
+
return kwargs
|
| 115 |
+
else:
|
| 116 |
+
ref_fields = set(get_fields(ref_instance))
|
| 117 |
+
assert isinstance(kwargs, (dict, DictConfig)), (
|
| 118 |
+
"kwargs must be a dictionary or a DictConfig"
|
| 119 |
+
)
|
| 120 |
+
keys = set(kwargs.keys())
|
| 121 |
+
|
| 122 |
+
# ref_fields must equal to or include all keys
|
| 123 |
+
extra_keys = keys - ref_fields
|
| 124 |
+
assert ref_fields == keys or keys.issubset(ref_fields), (
|
| 125 |
+
f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
resolved_kwargs: dict[str, Any] = {}
|
| 129 |
+
for f in keys:
|
| 130 |
+
resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
|
| 131 |
+
try:
|
| 132 |
+
new_instance = type(ref_instance)(**resolved_kwargs)
|
| 133 |
+
except Exception as e:
|
| 134 |
+
log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
|
| 135 |
+
log.error(e)
|
| 136 |
+
raise e
|
| 137 |
+
return new_instance
|
| 138 |
+
|
| 139 |
+
config = config_from_dict(config, config_omegaconf)
|
| 140 |
+
|
| 141 |
+
return config
|
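# --- Usage sketch (`make_config` comes from the surrounding project, as mentioned in
# the docstring above; the override keys are made up). Note the mandatory "--"
# separator before the Hydra-style overrides:
config = make_config()
config = override(config, ["--", "trainer.max_iter=1000", "job.name=debug_run"])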
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_config_module(config_file: str) -> str:
|
| 145 |
+
if not config_file.endswith(".py"):
|
| 146 |
+
log.error("Config file cannot be specified as module.")
|
| 147 |
+
log.error("Please provide the path to the Python config file (relative to the Imaginaire4 root).")
|
| 148 |
+
assert os.path.isfile(config_file), f"Imaginaire4 config file ({config_file}) not found."
|
| 149 |
+
# Convert to importable module format.
|
| 150 |
+
config_module = config_file.replace("/", ".").replace(".py", "")
|
| 151 |
+
return config_module
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
|
| 155 |
+
"""
|
| 156 |
+
Import all modules from the specified package path recursively.
|
| 157 |
+
|
| 158 |
+
This function is typically used in conjunction with Hydra to ensure that all modules
|
| 159 |
+
within a specified package are imported, which is necessary for registering configurations.
|
| 160 |
+
|
| 161 |
+
Example usage:
|
| 162 |
+
```python
|
| 163 |
+
import_all_modules_from_package("projects.cosmos.diffusion.v1.config.experiment", reload=True, skip_underscore=False)
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
package_path (str): The dotted path to the package from which to import all modules.
|
| 168 |
+
reload (bool): Flag to determine whether to reload modules if they're already imported.
|
| 169 |
+
skip_underscore (bool): If True, skips importing modules that start with an underscore.
|
| 170 |
+
"""
|
| 171 |
+
log.critical(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
|
| 172 |
+
package = importlib.import_module(package_path)
|
| 173 |
+
package_directory = package.__path__
|
| 174 |
+
|
| 175 |
+
def import_modules_recursively(directory: str, prefix: str) -> None:
|
| 176 |
+
"""
|
| 177 |
+
Recursively imports or reloads all modules in the given directory.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
directory (str): The file system path to the current package directory.
|
| 181 |
+
prefix (str): The module prefix (e.g., 'projects.cosmos.diffusion.v1.config').
|
| 182 |
+
"""
|
| 183 |
+
for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
|
| 184 |
+
if skip_underscore and module_name.startswith("_"):
|
| 185 |
+
log.debug(f"Skipping module {module_name} as it starts with an underscore")
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
full_module_name = f"{prefix}.{module_name}"
|
| 189 |
+
log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
|
| 190 |
+
|
| 191 |
+
if full_module_name in sys.modules and reload:
|
| 192 |
+
importlib.reload(sys.modules[full_module_name])
|
| 193 |
+
else:
|
| 194 |
+
importlib.import_module(full_module_name)
|
| 195 |
+
|
| 196 |
+
if is_pkg:
|
| 197 |
+
sub_package_directory = os.path.join(directory, module_name)
|
| 198 |
+
import_modules_recursively(sub_package_directory, full_module_name)
|
| 199 |
+
|
| 200 |
+
for directory in package_directory:
|
| 201 |
+
import_modules_recursively(directory, package_path)
|
imaginaire/utils/device.py
ADDED
|
@@ -0,0 +1,39 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import pynvml
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Device:
|
| 23 |
+
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore
|
| 24 |
+
|
| 25 |
+
def __init__(self, device_idx: int):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
|
| 28 |
+
|
| 29 |
+
def get_name(self) -> str:
|
| 30 |
+
return pynvml.nvmlDeviceGetName(self.handle)
|
| 31 |
+
|
| 32 |
+
def get_cpu_affinity(self) -> list[int]:
|
| 33 |
+
affinity_string = ""
|
| 34 |
+
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
|
| 35 |
+
# assume nvml returns list of 64 bit ints
|
| 36 |
+
affinity_string = f"{j:064b}" + affinity_string
|
| 37 |
+
affinity_list = [int(x) for x in affinity_string]
|
| 38 |
+
affinity_list.reverse() # so core 0 is in 0th element of list
|
| 39 |
+
return [i for i, e in enumerate(affinity_list) if e != 0]
|
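# --- Worked example of the bitmask decoding above (the mask value is made up): if
# NVML returned a single 64-bit affinity element 0b1011, cores 0, 1 and 3 are allowed.
mask = [0b1011]
decoded = "".join(f"{j:064b}" for j in reversed(mask))
cores = [i for i, bit in enumerate(reversed(decoded)) if bit == "1"]
assert cores == [0, 1, 3]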
imaginaire/utils/distributed.py
ADDED
|
@@ -0,0 +1,444 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import collections
|
| 19 |
+
import collections.abc
|
| 20 |
+
import ctypes
|
| 21 |
+
import functools
|
| 22 |
+
import os
|
| 23 |
+
from collections.abc import Callable, Container
|
| 24 |
+
from contextlib import contextmanager
|
| 25 |
+
from datetime import timedelta
|
| 26 |
+
from typing import TYPE_CHECKING, Any
|
| 27 |
+
|
| 28 |
+
import pynvml
|
| 29 |
+
import torch
|
| 30 |
+
import torch.distributed as dist
|
| 31 |
+
from torch.distributed import get_process_group_ranks
|
| 32 |
+
|
| 33 |
+
from imaginaire.utils.device import Device
|
| 34 |
+
|
| 35 |
+
if dist.is_available():
|
| 36 |
+
from torch.distributed.distributed_c10d import _get_default_group
|
| 37 |
+
from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes
|
| 38 |
+
|
| 39 |
+
from imaginaire.utils import log
|
| 40 |
+
|
| 41 |
+
if TYPE_CHECKING:
|
| 42 |
+
from imaginaire.config import DDPConfig
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
from megatron.core import parallel_state
|
| 46 |
+
except ImportError:
|
| 47 |
+
print("Megatron-core is not installed.")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def init() -> int | None:
|
| 51 |
+
"""Initialize distributed training."""
|
| 52 |
+
if dist.is_initialized():
|
| 53 |
+
return torch.cuda.current_device()
|
| 54 |
+
|
| 55 |
+
# Set GPU affinity.
|
| 56 |
+
pynvml.nvmlInit()
|
| 57 |
+
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
| 58 |
+
try:
|
| 59 |
+
device = Device(local_rank)
|
| 60 |
+
os.sched_setaffinity(0, device.get_cpu_affinity())
|
| 61 |
+
except (OSError, pynvml.NVMLError) as e:
|
| 62 |
+
log.warning(f"Failed to set device affinity: {e}")
|
| 63 |
+
# Set up NCCL communication.
|
| 64 |
+
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
|
| 65 |
+
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
| 66 |
+
if dist.is_available():
|
| 67 |
+
torch.cuda.set_device(local_rank)
|
| 68 |
+
# Get the timeout value from environment variable
|
| 69 |
+
timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
|
| 70 |
+
# Convert the timeout to an integer (if it isn't already) and then to a timedelta
|
| 71 |
+
timeout_timedelta = timedelta(seconds=int(timeout_seconds))
|
| 72 |
+
dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
|
| 73 |
+
log.info(
|
| 74 |
+
f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
|
| 75 |
+
rank0_only=False,
|
| 76 |
+
)
|
| 77 |
+
# Increase the L2 fetch granularity for faster speed.
|
| 78 |
+
_libcudart = ctypes.CDLL("libcudart.so")
|
| 79 |
+
# Set device limit on the current device.
|
| 80 |
+
p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
|
| 81 |
+
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
|
| 82 |
+
_libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
|
| 83 |
+
log.info(f"Training with {get_world_size()} GPUs.")
|
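# --- Entry-point sketch: init() returns early when torch.distributed is already
# initialized, so it is safe to call unconditionally at the top of a training script
# launched with e.g. `torchrun --nproc_per_node=8 train.py`.
init()
log.info(f"rank {get_rank()} of {get_world_size()} ready", rank0_only=False)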
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_rank(group: dist.ProcessGroup | None = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker.
    """
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank(group)
    return rank


def get_world_size(group: dist.ProcessGroup | None = None) -> int:
    """Get the world size, i.e. how many GPUs are available in this job.

    Returns:
        world_size (int): The total number of GPUs available in this job.
    """
    world_size = 1
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size(group)
    return world_size


def is_rank0() -> bool:
    """Check if the current process is the master GPU.

    Returns:
        (bool): True if this function is called from the master GPU, else False.
    """
    return get_rank() == 0


def is_local_rank0() -> bool:
    """Check if the current process is the local master GPU in the current node.

    Returns:
        (bool): True if this function is called from the local master GPU, else False.
    """
    return torch.cuda.current_device() == 0


def rank0_only(func: Callable) -> Callable:
    """Apply this function only to the master GPU.

    Example usage:
        @rank0_only
        def func(x):
            return x + 3

    Args:
        func (Callable): a function.

    Returns:
        (Callable): A function wrapper executing the function only on the master GPU.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_rank0():
            return func(*args, **kwargs)
        else:
            return None

    return wrapper


def barrier() -> None:
    """Barrier for all GPUs."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def rank0_first(func: Callable) -> Callable:
    """Run the function on rank 0 first, then on the other ranks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_rank0():
            result = func(*args, **kwargs)
        barrier()
        if not is_rank0():
            result = func(*args, **kwargs)
        return result

    return wrapper

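Taken together, `rank0_only` fits side effects that should happen exactly once (logging, checkpoint writes), while `rank0_first` fits cache-populating work that every rank needs but only one should perform. A small sketch under that reading; the index file and its contents are hypothetical:

    import json
    import os

    @rank0_first
    def build_index(path: str) -> dict:
        # rank 0 creates the file; after the barrier, the other ranks read the cached copy
        if not os.path.exists(path):
            with open(path, "w") as f:
                json.dump({"num_shards": 8}, f)
        with open(path) as f:
            return json.load(f)
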
def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
    """Wraps the model to enable data parallelism for training across multiple GPU devices.

    Args:
        config_ddp (DDPConfig): The data parallel config.
        model (torch.nn.Module): The PyTorch module.

    Returns:
        model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
            if the distributed environment is available, otherwise the original model.
    """
    if dist.is_available() and dist.is_initialized():
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        try:
            ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
        except Exception as e:
            log.info(e)
            log.info("parallel_state not initialized, treating all GPUs equally for DDP")
            ddp_group = None

        model = DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=config_ddp.find_unused_parameters,
            static_graph=config_ddp.static_graph,
            broadcast_buffers=config_ddp.broadcast_buffers,
            process_group=ddp_group,
        )
    return model


class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """This extends torch.nn.parallel.DistributedDataParallel with .training_step().

    This borrows the concept of `forward-redirection` from PyTorch Lightning. It wraps an ImaginaireModel such that
    model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
    model() for PyTorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
    training_step), allowing us to preserve the function names and signatures.
    """

    def __init__(self, model: torch.nn.Module, *args, **kwargs):
        super().__init__(model, *args, **kwargs)
        self.show_sync_grad_static_graph_warning = True

    def training_step(self, *args, **kwargs) -> Any:
        # Cache the original model.forward() method.
        original_forward = self.module.forward

        def wrapped_training_step(*_args, **_kwargs):
            # Unpatch immediately before calling training_step() because it may itself want to call the real forward.
            self.module.forward = original_forward
            # The actual .training_step().
            return self.module.training_step(*_args, **_kwargs)

        # Patch the original module's forward so we can redirect the arguments back to the real method.
        self.module.forward = wrapped_training_step
        # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
        # Without calling self.forward() or model.forward() explicitly, implicit hooks are also executed.
        return self(*args, **kwargs)

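The indirection matters because calling `self.module.training_step()` directly would skip DDP's implicit forward hooks, and with them gradient bucketing. A minimal call-site sketch; the config path, batch variables, and the `(output, loss)` return shape are assumptions, not guaranteed by this file:

    ddp_model = parallel_model_wrapper(config.trainer.ddp, model)
    output_batch, loss = ddp_model.training_step(data_batch, iteration)
    loss.backward()  # DDP hooks ran during the redirected forward, so gradients all-reduce here
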
@contextmanager
def ddp_sync_grad(model, enabled):
    r"""
    Context manager to enable/disable gradient synchronizations across DDP processes for a DDP model.
    Modified from:
    https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
    Note that this is incompatible with static_graph=True and will be a no-op if static_graph=True.

    Within this context, gradients will be accumulated on module
    variables, which will later be synchronized in the first
    forward-backward pass exiting the context.

    .. warning::
        The forward pass should be included inside the context manager, or
        else gradients will still be synchronized.
    """
    assert isinstance(model, torch.nn.Module)
    if isinstance(model, DistributedDataParallel):
        old_require_backward_grad_sync = model.require_backward_grad_sync
        if model.static_graph and model.require_backward_grad_sync != enabled:
            if model.show_sync_grad_static_graph_warning:
                log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
                model.show_sync_grad_static_graph_warning = False
        else:
            model.require_backward_grad_sync = enabled
    try:
        yield
    finally:
        if isinstance(model, DistributedDataParallel):
            model.require_backward_grad_sync = old_require_backward_grad_sync

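The usual client of `ddp_sync_grad` is gradient accumulation: skip the all-reduce on the first N-1 micro-batches and pay for it once on the last. A sketch under that assumption; the micro-batch loop, `ddp_model`, `iteration`, and `optimizer` are illustrative names:

    accum_steps = 4
    for i, micro_batch in enumerate(micro_batches):
        with ddp_sync_grad(ddp_model, enabled=(i == accum_steps - 1)):
            # the forward pass must run inside the context, per the warning above
            _, loss = ddp_model.training_step(micro_batch, iteration)
            (loss / accum_steps).backward()
    optimizer.step()
    optimizer.zero_grad()
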
def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
    """Aggregate the list of data batches from all devices and process the results.

    This is used for gathering validation data batches with imaginaire.utils.dataloader.DistributedEvalSampler.
    It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
    in different ranks may differ by 1 (if the dataset size is not evenly divisible), in which case a dummy sample will
    be created before calling dist.all_gather().

    Args:
        data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
            leaf entries are tensors.

    Returns:
        data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
            leaf entries are concatenated tensors.
    """
    if isinstance(data_batches[0], torch.Tensor):
        # Concatenate the local data batches.
        data_concat = torch.cat(data_batches, dim=0)  # type: ignore
        # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
        max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
        dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
        if len(data_concat) < max_num_local_samples:
            assert len(data_concat) + 1 == max_num_local_samples
            dummy = torch.empty_like(data_concat[:1])
            data_concat = torch.cat([data_concat, dummy], dim=0)
            dummy_count = torch.tensor(1, device="cuda")
        else:
            dummy_count = torch.tensor(0, device="cuda")
        # Get all concatenated batches from all ranks and concatenate again.
        dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
        data_concat = all_gather_tensor(data_concat.contiguous())
        data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
        # Remove the dummy samples.
        if dummy_count > 0:
            data_collate = data_collate[:-dummy_count]
    elif isinstance(data_batches[0], collections.abc.Mapping):
        data_collate = dict()
        for key in data_batches[0].keys():
            data_collate[key] = collate_batches([data[key] for data in data_batches])  # type: ignore
    else:
        raise TypeError
    return data_collate

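Because `all_gather_tensor` returns per-rank shards in rank order, the stack/flatten round-trip above re-interleaves them back into the original dataset order. A usage sketch, assuming a validation loop that yields a dict of tensors per batch (`validation_step` is an assumed name):

    local_outputs = [model.validation_step(batch) for batch in val_dataloader]
    gathered = collate_batches(local_outputs)  # full validation set, original order, on every rank
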
@torch.no_grad()
def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
    """Gather the corresponding tensor from all GPU devices to a list.

    Args:
        tensor (torch.Tensor): PyTorch tensor.

    Returns:
        tensor_list (list[torch.Tensor]): A list of PyTorch tensors gathered from all GPU devices.
    """
    tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
    dist.all_gather(tensor_list, tensor)
    return tensor_list


def broadcast(tensor, src, group=None, async_op=False):
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    dist.broadcast(tensor, src=src, group=group, async_op=async_op)
    return tensor  # dist.broadcast() operates in place; return the tensor for a consistent interface


def dist_reduce_tensor(tensor, rank=0, reduce="mean"):
    r"""Reduce the tensor to rank 0."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.reduce(tensor, dst=rank)
        if get_rank() == rank:
            if reduce == "mean":
                tensor /= world_size
            elif reduce == "sum":
                pass
            else:
                raise NotImplementedError
    return tensor

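A typical use of `dist_reduce_tensor` is averaging a scalar metric so that only rank 0 logs it; `loss_value` below is an illustrative stand-in:

    loss_tensor = torch.tensor(loss_value, device="cuda")
    mean_loss = dist_reduce_tensor(loss_tensor, rank=0, reduce="mean")
    if is_rank0():
        log.info(f"mean loss across ranks: {mean_loss.item()}")
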
def sync_model_states(
    model: torch.nn.Module,
    process_group: dist.ProcessGroup | None = None,
    src: int = 0,
    params_and_buffers_to_ignore: Container[str] | None = None,
    broadcast_buffers: bool = True,
):
    """
    Modified from the DDP source code.
    Synchronizes the parameters and buffers of a model across different processes in a distributed setting.

    This function ensures that all processes in the specified process group have the same initial parameters and
    buffers from the source rank, typically rank 0. It is useful when different processes start with different model
    states and a synchronization is required to ensure consistency across all ranks.

    Args:
        model (nn.Module): The model whose parameters and buffers are to be synchronized.
        process_group (dist.ProcessGroup, optional): The process group for communication. If None,
            the default group is used. Defaults to None.
        src (int, optional): The source rank from which parameters and buffers will be broadcasted.
            Defaults to 0.
        params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer
            names to exclude from synchronization. Defaults to None, which means all parameters and buffers are
            included.
        broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True.

    Side Effects:
        This function modifies the state of the model in-place to synchronize it with the source rank's model state.

    Raises:
        RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised.

    Examples:
        >>> # Download model weights on one rank only, instead of duplicating the download on every rank;
        >>> # this saves network bandwidth and time when model weights are huge.
        >>> if dist.get_rank() == 0:
        >>>     model.load_state_dict(network_bound_weights_download_fn(s3_weights_path))
        >>> dist.barrier()
        >>> sync_model_states(model)  # sync rank-0 weights to the other ranks
    """
    if not dist.is_available() or not dist.is_initialized():
        return
    if process_group is None:
        process_group = _get_default_group()
    if not params_and_buffers_to_ignore:
        params_and_buffers_to_ignore = set()

    log.info(
        f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}."
    )

    # Build tuple of (module, parameter) for all parameters that require grads.
    modules_and_parameters = [
        (module, parameter)
        for module_name, module in model.named_modules()
        for parameter in [
            param
            # Note that we access module.named_parameters instead of
            # parameters(module). parameters(module) is only needed in the
            # single-process multi device case, where it accesses replicated
            # parameters through _former_parameters.
            for param_name, param in module.named_parameters(recurse=False)
            if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
            # if param.requires_grad
            # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
        ]
    ]

    # Deduplicate any parameters that might be shared across child modules.
    memo = set()
    modules_and_parameters = [
        # "p not in memo" is the deduplication check.
        # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
        (m, p)
        for m, p in modules_and_parameters
        if p not in memo and not memo.add(p)  # type: ignore[func-returns-value]
    ]

    # Build list of parameters.
    parameters = [parameter for _, parameter in modules_and_parameters]
    if len(parameters) == 0:
        return

    _verify_param_shape_across_processes(process_group, parameters)

    _sync_module_states(
        module=model,
        process_group=process_group,
        broadcast_bucket_size=(250 * 1024 * 1024),
        src=src,
        params_and_buffers_to_ignore=params_and_buffers_to_ignore,
        broadcast_buffers=broadcast_buffers,
    )
imaginaire/utils/easy_io/__init__.py
ADDED
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
imaginaire/utils/easy_io/backends/__init__.py
ADDED
@@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
from imaginaire.utils.easy_io.backends.local_backend import LocalBackend
from imaginaire.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend

__all__ = [
    "BaseStorageBackend",
    "HTTPBackend",
    "LocalBackend",
    "backends",
    "prefix_to_backends",
    "register_backend",
]
imaginaire/utils/easy_io/backends/base_backend.py
ADDED
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
from abc import ABCMeta, abstractmethod


def mkdir_or_exist(dir_name, mode=0o777):
    if dir_name == "":
        return
    dir_name = osp.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)


def has_method(obj, method):
    return hasattr(obj, method) and callable(getattr(obj, method))


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two APIs: :meth:`get()` and
    :meth:`get_text()`.

    - :meth:`get()` reads the file as a byte stream.
    - :meth:`get_text()` reads the file as text.
    """

    # A flag to indicate whether the backend can create a symlink for a file.
    # This attribute will be deprecated in the future.
    _allow_symlink = False

    @property
    def allow_symlink(self):
        return self._allow_symlink

    @property
    def name(self):
        return self.__class__.__name__

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass
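A concrete backend only has to supply the two abstract reads. A minimal conforming sketch, purely illustrative and not part of this commit:

    class DictBackend(BaseStorageBackend):
        """In-memory backend keyed by path string (illustrative only)."""

        def __init__(self, store: dict):
            self._store = store

        def get(self, filepath) -> bytes:
            return self._store[str(filepath)]

        def get_text(self, filepath, encoding: str = "utf-8") -> str:
            return self._store[str(filepath)].decode(encoding)
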
imaginaire/utils/easy_io/backends/http_backend.py
ADDED
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from urllib.request import urlopen

from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend


class HTTPBackend(BaseStorageBackend):
    """HTTP and HTTPS storage backend."""

    def get(self, filepath: str) -> bytes:
        """Read bytes from a given ``filepath``.

        Args:
            filepath (str): Path to read data.

        Returns:
            bytes: Expected bytes object.

        Examples:
            >>> backend = HTTPBackend()
            >>> backend.get('http://path/of/file')
            b'hello world'
        """
        return urlopen(filepath).read()

    def get_text(self, filepath, encoding="utf-8") -> str:
        """Read text from a given ``filepath``.

        Args:
            filepath (str): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Defaults to 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.

        Examples:
            >>> backend = HTTPBackend()
            >>> backend.get_text('http://path/of/file')
            'hello world'
        """
        return urlopen(filepath).read().decode(encoding)

    @contextmanager
    def get_local_path(self, filepath: str) -> Generator[str | Path, None, None]:
        """Download a file from ``filepath`` to a local temporary directory,
        and return the temporary path.

        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with a ``with`` statement, and when exiting the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Yields:
            Iterable[str]: Only yield one temporary path.

        Examples:
            >>> backend = HTTPBackend()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with backend.get_local_path('http://path/of/file') as path:
            ...     # do something here
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.get(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)
imaginaire/utils/easy_io/backends/local_backend.py
ADDED
@@ -0,0 +1,551 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import os.path as osp
import shutil
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from pathlib import Path

from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist


class LocalBackend(BaseStorageBackend):
    """Raw local storage backend."""

    _allow_symlink = True

    def get(self, filepath: str | Path) -> bytes:
        """Read bytes from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/file'
            >>> backend.get(filepath)
            b'hello world'
        """
        with open(filepath, "rb") as f:
            value = f.read()
        return value

    def get_text(self, filepath: str | Path, encoding: str = "utf-8") -> str:
        """Read text from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Defaults to 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/file'
            >>> backend.get_text(filepath)
            'hello world'
        """
        with open(filepath, encoding=encoding) as f:
            text = f.read()
        return text

    def put(self, obj: bytes | io.BytesIO, filepath: str | Path) -> None:
        """Write bytes to a given ``filepath`` with 'wb' mode.

        Note:
            ``put`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/file'
            >>> backend.put(b'hello world', filepath)
        """
        mkdir_or_exist(osp.dirname(filepath))
        if isinstance(obj, io.BytesIO):
            obj.seek(0)
            obj = obj.getvalue()
        with open(filepath, "wb") as f:
            f.write(obj)

    def put_text(self, obj: str, filepath: str | Path, encoding: str = "utf-8") -> None:
        """Write text to a given ``filepath`` with 'w' mode.

        Note:
            ``put_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Defaults to 'utf-8'.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/file'
            >>> backend.put_text('hello world', filepath)
        """
        mkdir_or_exist(osp.dirname(filepath))
        with open(filepath, "w", encoding=encoding) as f:
            f.write(obj)

    def exists(self, filepath: str | Path) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether it exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/file'
            >>> backend.exists(filepath)
            True
        """
        return osp.exists(filepath)

    def isdir(self, filepath: str | Path) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
            ``False`` otherwise.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/dir'
            >>> backend.isdir(filepath)
            True
        """
        return osp.isdir(filepath)

    def isfile(self, filepath: str | Path) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
            otherwise.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/file'
            >>> backend.isfile(filepath)
            True
        """
        return osp.isfile(filepath)

    def join_path(self, filepath: str | Path, *filepaths: str | Path) -> str:
        r"""Concatenate all file paths.

        Join one or more filepath components intelligently. The return value
        is the concatenation of filepath and any members of \*filepaths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result of concatenation.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath1 = '/path/of/dir1'
            >>> filepath2 = 'dir2'
            >>> filepath3 = 'path/of/file'
            >>> backend.join_path(filepath1, filepath2, filepath3)
            '/path/of/dir1/dir2/path/of/file'
        """
        return osp.join(filepath, *filepaths)

    @contextmanager
    def get_local_path(self, filepath: str) -> Generator[str, None, None]:
        """Download data from ``filepath`` to a local path with a context manager.

        If ``filepath`` exists on localhost, it just returns ``filepath``.
        If ``filepath`` doesn't exist on localhost, it will download the data
        to a local path, return the path, and remove the path when exiting
        from the ``with`` statement.

        Args:
            filepath (str): Path to read data from.

        Yields:
            str: Local path.

        Examples:
            >>> with backend.get_local_path('http://example.com/abc.jpg') as path:
            ...     # do something here
        """
        yield filepath

    def copyfile(
        self,
        src: str | Path,
        dst: str | Path,
    ) -> str:
        """Copy a file src to dst and return the destination file.

        src and dst should have the same prefix. If dst specifies a directory,
        the file will be copied into dst using the base filename from src. If
        dst specifies a file that already exists, it will be replaced.

        Args:
            src (str or Path): A file to be copied.
            dst (str or Path): Copy file to dst.

        Returns:
            str: The destination file.

        Raises:
            SameFileError: If src and dst are the same file, a SameFileError
                will be raised.

        Examples:
            >>> backend = LocalBackend()
            >>> # dst is a file
            >>> src = '/path/of/file'
            >>> dst = '/path1/of/file1'
            >>> # src will be copied to '/path1/of/file1'
            >>> backend.copyfile(src, dst)
            '/path1/of/file1'

            >>> # dst is a directory
            >>> dst = '/path1/of/dir'
            >>> # src will be copied to '/path1/of/dir/file'
            >>> backend.copyfile(src, dst)
            '/path1/of/dir/file'
        """
        return shutil.copy(src, dst)

    def copytree(
        self,
        src: str | Path,
        dst: str | Path,
    ) -> str:
        """Recursively copy an entire directory tree rooted at src to a
        directory named dst and return the destination directory.

        src and dst should have the same prefix and dst must not already exist.

        Args:
            src (str or Path): A directory to be copied.
            dst (str or Path): Copy directory to dst.

        Returns:
            str: The destination directory.

        Raises:
            FileExistsError: If dst already exists, a FileExistsError will
                be raised.

        Examples:
            >>> backend = LocalBackend()
            >>> src = '/path/of/dir1'
            >>> dst = '/path/of/dir2'
            >>> backend.copytree(src, dst)
            '/path/of/dir2'
        """
        return shutil.copytree(src, dst)

    def copyfile_from_local(
        self,
        src: str | Path,
        dst: str | Path,
    ) -> str:
        """Copy a local file src to dst and return the destination file. Same
        as :meth:`copyfile`.

        Args:
            src (str or Path): A local file to be copied.
            dst (str or Path): Copy file to dst.

        Returns:
            str: If dst specifies a directory, the file will be copied into dst
            using the base filename from src.

        Raises:
            SameFileError: If src and dst are the same file, a SameFileError
                will be raised.

        Examples:
            >>> backend = LocalBackend()
            >>> # dst is a file
            >>> src = '/path/of/file'
            >>> dst = '/path1/of/file1'
            >>> # src will be copied to '/path1/of/file1'
            >>> backend.copyfile_from_local(src, dst)
            '/path1/of/file1'

            >>> # dst is a directory
            >>> dst = '/path1/of/dir'
            >>> # src will be copied to '/path1/of/dir/file'
            >>> backend.copyfile_from_local(src, dst)
            '/path1/of/dir/file'
        """
        return self.copyfile(src, dst)

    def copytree_from_local(
        self,
        src: str | Path,
        dst: str | Path,
    ) -> str:
        """Recursively copy an entire directory tree rooted at src to a
        directory named dst and return the destination directory. Same as
        :meth:`copytree`.

        Args:
            src (str or Path): A local directory to be copied.
            dst (str or Path): Copy directory to dst.

        Returns:
            str: The destination directory.

        Examples:
            >>> backend = LocalBackend()
            >>> src = '/path/of/dir1'
            >>> dst = '/path/of/dir2'
            >>> backend.copytree_from_local(src, dst)
            '/path/of/dir2'
        """
        return self.copytree(src, dst)

    def copyfile_to_local(
        self,
        src: str | Path,
        dst: str | Path,
        dst_type: str | None = None,
    ) -> str:
        """Copy the file src to local dst and return the destination file. Same
        as :meth:`copyfile`.

        If dst specifies a directory, the file will be copied into dst using
        the base filename from src. If dst specifies a file that already
        exists, it will be replaced.

        Args:
            src (str or Path): A file to be copied.
            dst (str or Path): Copy file to local dst.

        Returns:
            str: If dst specifies a directory, the file will be copied into dst
            using the base filename from src.

        Examples:
            >>> backend = LocalBackend()
            >>> # dst is a file
            >>> src = '/path/of/file'
            >>> dst = '/path1/of/file1'
            >>> # src will be copied to '/path1/of/file1'
            >>> backend.copyfile_to_local(src, dst)
            '/path1/of/file1'

            >>> # dst is a directory
            >>> dst = '/path1/of/dir'
            >>> # src will be copied to '/path1/of/dir/file'
            >>> backend.copyfile_to_local(src, dst)
            '/path1/of/dir/file'
        """
        return self.copyfile(src, dst)

    def copytree_to_local(
        self,
        src: str | Path,
        dst: str | Path,
    ) -> str:
        """Recursively copy an entire directory tree rooted at src to a local
        directory named dst and return the destination directory.

        Args:
            src (str or Path): A directory to be copied.
            dst (str or Path): Copy directory to local dst.

        Returns:
            str: The destination directory.

        Examples:
            >>> backend = LocalBackend()
            >>> src = '/path/of/dir1'
            >>> dst = '/path/of/dir2'
            >>> backend.copytree_to_local(src, dst)
            '/path/of/dir2'
        """
        return self.copytree(src, dst)

    def remove(self, filepath: str | Path) -> None:
        """Remove a file.

        Args:
            filepath (str or Path): Path to be removed.

        Raises:
            IsADirectoryError: If filepath is a directory, an IsADirectoryError
                will be raised.
            FileNotFoundError: If filepath does not exist, a FileNotFoundError
                will be raised.

        Examples:
            >>> backend = LocalBackend()
            >>> filepath = '/path/of/file'
            >>> backend.remove(filepath)
        """
        if not self.exists(filepath):
            raise FileNotFoundError(f"filepath {filepath} does not exist")

        if self.isdir(filepath):
            raise IsADirectoryError("filepath should be a file")

        os.remove(filepath)

    def rmtree(self, dir_path: str | Path) -> None:
        """Recursively delete a directory tree.

        Args:
            dir_path (str or Path): A directory to be removed.

        Examples:
            >>> dir_path = '/path/of/dir'
            >>> backend.rmtree(dir_path)
        """
        shutil.rmtree(dir_path)

    def copy_if_symlink_fails(
        self,
        src: str | Path,
        dst: str | Path,
    ) -> bool:
        """Create a symbolic link pointing to src named dst.

        If it fails to create a symbolic link pointing to src, it directly
        copies src to dst instead.

        Args:
            src (str or Path): Create a symbolic link pointing to src.
            dst (str or Path): Create a symbolic link named dst.

        Returns:
            bool: Return True if a symbolic link pointing to src was
            successfully created. Otherwise, return False.

        Examples:
            >>> backend = LocalBackend()
            >>> src = '/path/of/file'
            >>> dst = '/path1/of/file1'
            >>> backend.copy_if_symlink_fails(src, dst)
            True
            >>> src = '/path/of/dir'
            >>> dst = '/path1/of/dir1'
            >>> backend.copy_if_symlink_fails(src, dst)
            True
        """
        try:
            os.symlink(src, dst)
            return True
        except Exception:
            if self.isfile(src):
                self.copyfile(src, dst)
            else:
                self.copytree(src, dst)
            return False

    def list_dir_or_file(
        self,
        dir_path: str | Path,
        list_dir: bool = True,
        list_file: bool = True,
        suffix: str | tuple[str] | None = None,
        recursive: bool = False,
    ) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.

        Args:
            dir_path (str or Path): Path of the directory.
            list_dir (bool): List the directories. Defaults to True.
            list_file (bool): List the path of files. Defaults to True.
            suffix (str or tuple[str], optional): File suffix that we are
                interested in. Defaults to None.
            recursive (bool): If set to True, recursively scan the directory.
                Defaults to False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.

        Examples:
            >>> backend = LocalBackend()
            >>> dir_path = '/path/of/dir'
            >>> # list those files and directories in current directory
            >>> for file_path in backend.list_dir_or_file(dir_path):
            ...     print(file_path)
            >>> # only list files
            >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False):
            ...     print(file_path)
            >>> # only list directories
            >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False):
            ...     print(file_path)
            >>> # only list files ending with specified suffixes
            >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'):
            ...     print(file_path)
            >>> # list all files and directories recursively
            >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True):
            ...     print(file_path)
        """
        if list_dir and suffix is not None:
            raise TypeError("`suffix` should be None when `list_dir` is True")

        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
            raise TypeError("`suffix` must be a string or tuple of strings")

        root = dir_path

        def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
            for entry in os.scandir(dir_path):
                if not entry.name.startswith(".") and entry.is_file():
                    rel_path = osp.relpath(entry.path, root)
                    if (suffix is None or rel_path.endswith(suffix)) and list_file:
                        yield rel_path
                elif osp.isdir(entry.path):
                    if list_dir:
                        rel_dir = osp.relpath(entry.path, root)
                        yield rel_dir
                    if recursive:
                        yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive)

        return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
imaginaire/utils/easy_io/backends/registry_utils.py
ADDED
@@ -0,0 +1,125 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
from imaginaire.utils.easy_io.backends.local_backend import LocalBackend

backends: dict = {}
prefix_to_backends: dict = {}


def _register_backend(
    name: str,
    backend: type[BaseStorageBackend],
    force: bool = False,
    prefixes: str | list | tuple | None = None,
):
    """Register a backend.

    Args:
        name (str): The name of the registered backend.
        backend (BaseStorageBackend): The backend class to be registered,
            which must be a subclass of :class:`BaseStorageBackend`.
        force (bool): Whether to override the backend if the name has already
            been registered. Defaults to False.
        prefixes (str or list[str] or tuple[str], optional): The prefix
            of the registered storage backend. Defaults to None.
    """
    global backends, prefix_to_backends

    if not isinstance(name, str):
        raise TypeError(f"the backend name should be a string, but got {type(name)}")

    if not inspect.isclass(backend):
        raise TypeError(f"backend should be a class, but got {type(backend)}")
    if not issubclass(backend, BaseStorageBackend):
        raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")

    if name in backends and not force:
        raise ValueError(
            f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
        )
    backends[name] = backend

    if prefixes is not None:
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        else:
            assert isinstance(prefixes, (list, tuple))

        for prefix in prefixes:
            if prefix in prefix_to_backends and not force:
                raise ValueError(
                    f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it'
                )

            prefix_to_backends[prefix] = backend


def register_backend(
    name: str,
    backend: type[BaseStorageBackend] | None = None,
    force: bool = False,
    prefixes: str | list | tuple | None = None,
):
    """Register a backend.

    Args:
        name (str): The name of the registered backend.
        backend (class, optional): The backend class to be registered,
            which must be a subclass of :class:`BaseStorageBackend`.
            When this method is used as a decorator, backend is None.
            Defaults to None.
        force (bool): Whether to override the backend if the name has already
            been registered. Defaults to False.
        prefixes (str or list[str] or tuple[str], optional): The prefix
            of the registered storage backend. Defaults to None.

    This method can be used as a normal method or a decorator.

    Examples:

        >>> class NewBackend(BaseStorageBackend):
        ...     def get(self, filepath):
        ...         return filepath
        ...
        ...     def get_text(self, filepath):
        ...         return filepath
        >>> register_backend('new', NewBackend)

        >>> @register_backend('new')
        ... class NewBackend(BaseStorageBackend):
        ...     def get(self, filepath):
        ...         return filepath
        ...
        ...     def get_text(self, filepath):
        ...         return filepath
    """
    if backend is not None:
        _register_backend(name, backend, force=force, prefixes=prefixes)
        return

    def _register(backend_cls):
        _register_backend(name, backend_cls, force=force, prefixes=prefixes)
        return backend_cls

    return _register


register_backend("local", LocalBackend, prefixes="")
register_backend("http", HTTPBackend, prefixes=["http", "https"])
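The two module-level dicts make prefix dispatch possible for callers. The lookup below is an illustrative sketch of mapping a URI to a backend; it is not necessarily how easy_io itself dispatches:

    def resolve_backend(uri: str) -> BaseStorageBackend:
        # prefer the longest matching prefix, e.g. "https" over "http"; the
        # empty prefix registered for LocalBackend matches everything last
        for prefix in sorted(prefix_to_backends, key=len, reverse=True):
            if uri.startswith(prefix):
                return prefix_to_backends[prefix]()
        raise ValueError(f"no backend registered for {uri!r}")
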
imaginaire/utils/easy_io/easy_io.py
ADDED
|
@@ -0,0 +1,1034 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import warnings
|
| 18 |
+
from collections.abc import Generator, Iterator
|
| 19 |
+
from contextlib import contextmanager
|
| 20 |
+
from io import BytesIO, StringIO
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import IO, Any
|
| 23 |
+
|
| 24 |
+
from imaginaire.utils.easy_io.backends import backends, prefix_to_backends
|
| 25 |
+
from imaginaire.utils.easy_io.file_client import FileClient
|
| 26 |
+
from imaginaire.utils.easy_io.handlers import file_handlers
|
| 27 |
+
|
| 28 |
+
backend_instances: dict = {}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def is_filepath(filepath):
|
| 32 |
+
return isinstance(filepath, (str, Path))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _parse_uri_prefix(uri: str | Path) -> str:
|
| 36 |
+
"""Parse the prefix of uri.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
uri (str or Path): Uri to be parsed that contains the file prefix.
|
| 40 |
+
|
| 41 |
+
Examples:
|
| 42 |
+
>>> _parse_uri_prefix('/home/path/of/your/file')
|
| 43 |
+
''
|
| 44 |
+
>>> _parse_uri_prefix('http://path/of/your/file')
|
| 45 |
+
'http'
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
str: Return the prefix of uri if the uri contains '://'. Otherwise,
|
| 49 |
+
return ''.
|
| 50 |
+
"""
|
| 51 |
+
assert is_filepath(uri)
|
| 52 |
+
uri = str(uri)
|
| 53 |
+
# if uri does not contain '://', the uri will be handled by
|
| 54 |
+
# LocalBackend by default
|
| 55 |
+
if "://" not in uri:
|
| 56 |
+
return ""
|
| 57 |
+
else:
|
| 58 |
+
prefix, _ = uri.split("://")
|
| 59 |
+
if ":" in prefix:
|
| 60 |
+
_, prefix = prefix.split(":")
|
| 61 |
+
return prefix
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _get_file_backend(prefix: str, backend_args: dict):
|
| 65 |
+
"""Return a file backend based on the prefix or backend_args.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
prefix (str): Prefix of uri.
|
| 69 |
+
backend_args (dict): Arguments to instantiate the corresponding
|
| 70 |
+
backend.
|
| 71 |
+
"""
|
| 72 |
+
# backend name has a higher priority
|
| 73 |
+
if "backend" in backend_args:
|
| 74 |
+
# backend_args should not be modified
|
| 75 |
+
backend_args_bak = backend_args.copy()
|
| 76 |
+
backend_name = backend_args_bak.pop("backend")
|
| 77 |
+
backend = backends[backend_name](**backend_args_bak)
|
| 78 |
+
else:
|
| 79 |
+
backend = prefix_to_backends[prefix](**backend_args)
|
| 80 |
+
return backend
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_file_backend(
|
| 84 |
+
uri: str | Path | None = None,
|
| 85 |
+
*,
|
| 86 |
+
backend_args: dict | None = None,
|
| 87 |
+
enable_singleton: bool = False,
|
| 88 |
+
backend_key: str | None = None,
|
| 89 |
+
):
|
| 90 |
+
"""Return a file backend based on the prefix of uri or backend_args.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
uri (str or Path): Uri to be parsed that contains the file prefix.
|
| 94 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 95 |
+
corresponding backend. Defaults to None.
|
| 96 |
+
enable_singleton (bool): Whether to enable the singleton pattern.
|
| 97 |
+
If it is True, the backend created will be reused if the
|
| 98 |
+
signature is same with the previous one. Defaults to False.
|
| 99 |
+
backend_key (str, optional): The key to register the backend. Defaults to None.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
BaseStorageBackend: Instantiated Backend object.
|
| 103 |
+
|
| 104 |
+
Examples:
|
| 105 |
+
>>> # get file backend based on the prefix of uri
|
| 106 |
+
>>> uri = 'http://path/of/your/file'
|
| 107 |
+
>>> backend = get_file_backend(uri)
|
| 108 |
+
>>> # get file backend based on the backend_args
|
| 109 |
+
>>> backend = get_file_backend(backend_args={'backend': 'http'})
|
| 110 |
+
>>> # backend name has a higher priority if 'backend' in backend_args
|
| 111 |
+
>>> backend = get_file_backend(uri, backend_args={'backend': 'http'})
|
| 112 |
+
"""
|
| 113 |
+
global backend_instances
|
| 114 |
+
if backend_key is not None:
|
| 115 |
+
if backend_key in backend_instances:
|
| 116 |
+
return backend_instances[backend_key]
|
| 117 |
+
|
| 118 |
+
if backend_args is None:
|
| 119 |
+
backend_args = {}
|
| 120 |
+
|
| 121 |
+
if uri is None and "backend" not in backend_args and backend_key is None:
|
| 122 |
+
raise ValueError('uri should not be None when "backend" does not exist in backend_args and backend_key is None')
|
| 123 |
+
|
| 124 |
+
if uri is not None:
|
| 125 |
+
prefix = _parse_uri_prefix(uri)
|
| 126 |
+
else:
|
| 127 |
+
prefix = ""
|
| 128 |
+
|
| 129 |
+
if enable_singleton:
|
| 130 |
+
unique_key = f"{prefix}:{json.dumps(backend_args)}"
|
| 131 |
+
if unique_key in backend_instances:
|
| 132 |
+
return backend_instances[unique_key]
|
| 133 |
+
|
| 134 |
+
backend = _get_file_backend(prefix, backend_args)
|
| 135 |
+
backend_instances[unique_key] = backend
|
| 136 |
+
if backend_key is not None:
|
| 137 |
+
backend_instances[backend_key] = backend
|
| 138 |
+
return backend
|
| 139 |
+
else:
|
| 140 |
+
backend = _get_file_backend(prefix, backend_args)
|
| 141 |
+
return backend
|
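A minimal usage sketch of the function above (import path assumed to mirror the file layout; the URLs are placeholders). With ``enable_singleton=True`` the instance is cached under the prefix plus the JSON-serialized ``backend_args``, so equal arguments return the same object; note that ``backend_key`` is only recorded when a new instance is actually created.

from imaginaire.utils.easy_io.easy_io import get_file_backend

# Same "http" prefix and same (empty) backend_args -> cached instance reused.
b1 = get_file_backend("http://host/a.bin", enable_singleton=True)
b2 = get_file_backend("http://host/b.bin", enable_singleton=True)
assert b1 is b2

# Without the singleton flag, a fresh backend object is built on every call.
b3 = get_file_backend("http://host/a.bin")
assert b3 is not b1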
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get(
|
| 145 |
+
filepath: str | Path,
|
| 146 |
+
backend_args: dict | None = None,
|
| 147 |
+
backend_key: str | None = None,
|
| 148 |
+
) -> bytes:
|
| 149 |
+
"""Read bytes from a given ``filepath`` with 'rb' mode.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
filepath (str or Path): Path to read data.
|
| 153 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 154 |
+
corresponding backend. Defaults to None.
|
| 155 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
bytes: Expected bytes object.
|
| 159 |
+
|
| 160 |
+
Examples:
|
| 161 |
+
>>> filepath = '/path/of/file'
|
| 162 |
+
>>> get(filepath)
|
| 163 |
+
b'hello world'
|
| 164 |
+
"""
|
| 165 |
+
backend = get_file_backend(
|
| 166 |
+
filepath,
|
| 167 |
+
backend_args=backend_args,
|
| 168 |
+
enable_singleton=True,
|
| 169 |
+
backend_key=backend_key,
|
| 170 |
+
)
|
| 171 |
+
return backend.get(filepath)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_text(
|
| 175 |
+
filepath: str | Path,
|
| 176 |
+
encoding="utf-8",
|
| 177 |
+
backend_args: dict | None = None,
|
| 178 |
+
backend_key: str | None = None,
|
| 179 |
+
) -> str:
|
| 180 |
+
"""Read text from a given ``filepath`` with 'r' mode.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
filepath (str or Path): Path to read data.
|
| 184 |
+
encoding (str): The encoding format used to open the ``filepath``.
|
| 185 |
+
Defaults to 'utf-8'.
|
| 186 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 187 |
+
corresponding backend. Defaults to None.
|
| 188 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
str: Expected text reading from ``filepath``.
|
| 192 |
+
|
| 193 |
+
Examples:
|
| 194 |
+
>>> filepath = '/path/of/file'
|
| 195 |
+
>>> get_text(filepath)
|
| 196 |
+
'hello world'
|
| 197 |
+
"""
|
| 198 |
+
backend = get_file_backend(
|
| 199 |
+
filepath,
|
| 200 |
+
backend_args=backend_args,
|
| 201 |
+
enable_singleton=True,
|
| 202 |
+
backend_key=backend_key,
|
| 203 |
+
)
|
| 204 |
+
return backend.get_text(filepath, encoding)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def put(
|
| 208 |
+
obj: bytes,
|
| 209 |
+
filepath: str | Path,
|
| 210 |
+
backend_args: dict | None = None,
|
| 211 |
+
backend_key: str | None = None,
|
| 212 |
+
) -> None:
|
| 213 |
+
"""Write bytes to a given ``filepath`` with 'wb' mode.
|
| 214 |
+
|
| 215 |
+
Note:
|
| 216 |
+
``put`` should create a directory if the directory of
|
| 217 |
+
``filepath`` does not exist.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
obj (bytes): Data to be written.
|
| 221 |
+
filepath (str or Path): Path to write data.
|
| 222 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 223 |
+
corresponding backend. Defaults to None.
|
| 224 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 225 |
+
|
| 226 |
+
Examples:
|
| 227 |
+
>>> filepath = '/path/of/file'
|
| 228 |
+
>>> put(b'hello world', filepath)
|
| 229 |
+
"""
|
| 230 |
+
backend = get_file_backend(
|
| 231 |
+
filepath,
|
| 232 |
+
backend_args=backend_args,
|
| 233 |
+
enable_singleton=True,
|
| 234 |
+
backend_key=backend_key,
|
| 235 |
+
)
|
| 236 |
+
backend.put(obj, filepath)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def put_text(
|
| 240 |
+
obj: str,
|
| 241 |
+
filepath: str | Path,
|
| 242 |
+
backend_args: dict | None = None,
|
| 243 |
+
backend_key: str | None = None,
|
| 244 |
+
) -> None:
|
| 245 |
+
"""Write text to a given ``filepath`` with 'w' mode.
|
| 246 |
+
|
| 247 |
+
Note:
|
| 248 |
+
``put_text`` should create a directory if the directory of
|
| 249 |
+
``filepath`` does not exist.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
obj (str): Data to be written.
|
| 253 |
+
filepath (str or Path): Path to write data.
|
| 254 |
+
encoding (str, optional): The encoding format used to open the
|
| 255 |
+
``filepath``. Defaults to 'utf-8'.
|
| 256 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 257 |
+
corresponding backend. Defaults to None.
|
| 258 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 259 |
+
|
| 260 |
+
Examples:
|
| 261 |
+
>>> filepath = '/path/of/file'
|
| 262 |
+
>>> put_text('hello world', filepath)
|
| 263 |
+
"""
|
| 264 |
+
backend = get_file_backend(
|
| 265 |
+
filepath,
|
| 266 |
+
backend_args=backend_args,
|
| 267 |
+
enable_singleton=True,
|
| 268 |
+
backend_key=backend_key,
|
| 269 |
+
)
|
| 270 |
+
backend.put_text(obj, filepath)
|
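Taken together, ``get``/``put`` and ``get_text``/``put_text`` form a symmetric bytes/text API. A minimal round-trip sketch with the local backend (the empty prefix registered earlier); the ``/tmp`` paths are placeholders.

from imaginaire.utils.easy_io.easy_io import get, get_text, put, put_text

put(b"\x00\x01", "/tmp/demo.bin")            # parent dirs are created if missing
assert get("/tmp/demo.bin") == b"\x00\x01"   # bytes round trip

put_text("hello world", "/tmp/demo.txt")
assert get_text("/tmp/demo.txt") == "hello world"  # text round trip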
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def exists(
|
| 274 |
+
filepath: str | Path,
|
| 275 |
+
backend_args: dict | None = None,
|
| 276 |
+
backend_key: str | None = None,
|
| 277 |
+
) -> bool:
|
| 278 |
+
"""Check whether a file path exists.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
filepath (str or Path): Path to be checked whether exists.
|
| 282 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 283 |
+
corresponding backend. Defaults to None.
|
| 284 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
|
| 288 |
+
|
| 289 |
+
Examples:
|
| 290 |
+
>>> filepath = '/path/of/file'
|
| 291 |
+
>>> exists(filepath)
|
| 292 |
+
True
|
| 293 |
+
"""
|
| 294 |
+
backend = get_file_backend(
|
| 295 |
+
filepath,
|
| 296 |
+
backend_args=backend_args,
|
| 297 |
+
enable_singleton=True,
|
| 298 |
+
backend_key=backend_key,
|
| 299 |
+
)
|
| 300 |
+
return backend.exists(filepath)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def isdir(
|
| 304 |
+
filepath: str | Path,
|
| 305 |
+
backend_args: dict | None = None,
|
| 306 |
+
backend_key: str | None = None,
|
| 307 |
+
) -> bool:
|
| 308 |
+
"""Check whether a file path is a directory.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
filepath (str or Path): Path to be checked whether it is a
|
| 312 |
+
directory.
|
| 313 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 314 |
+
corresponding backend. Defaults to None.
|
| 315 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
bool: Return ``True`` if ``filepath`` points to a directory,
|
| 319 |
+
``False`` otherwise.
|
| 320 |
+
|
| 321 |
+
Examples:
|
| 322 |
+
>>> filepath = '/path/of/dir'
|
| 323 |
+
>>> isdir(filepath)
|
| 324 |
+
True
|
| 325 |
+
"""
|
| 326 |
+
backend = get_file_backend(
|
| 327 |
+
filepath,
|
| 328 |
+
backend_args=backend_args,
|
| 329 |
+
enable_singleton=True,
|
| 330 |
+
backend_key=backend_key,
|
| 331 |
+
)
|
| 332 |
+
return backend.isdir(filepath)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def isfile(
|
| 336 |
+
filepath: str | Path,
|
| 337 |
+
backend_args: dict | None = None,
|
| 338 |
+
backend_key: str | None = None,
|
| 339 |
+
) -> bool:
|
| 340 |
+
"""Check whether a file path is a file.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
filepath (str or Path): Path to be checked whether it is a file.
|
| 344 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 345 |
+
corresponding backend. Defaults to None.
|
| 346 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
bool: Return ``True`` if ``filepath`` points to a file, ``False``
|
| 350 |
+
otherwise.
|
| 351 |
+
|
| 352 |
+
Examples:
|
| 353 |
+
>>> filepath = '/path/of/file'
|
| 354 |
+
>>> isfile(filepath)
|
| 355 |
+
True
|
| 356 |
+
"""
|
| 357 |
+
backend = get_file_backend(
|
| 358 |
+
filepath,
|
| 359 |
+
backend_args=backend_args,
|
| 360 |
+
enable_singleton=True,
|
| 361 |
+
backend_key=backend_key,
|
| 362 |
+
)
|
| 363 |
+
return backend.isfile(filepath)
|
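A short sketch combining the three predicates above on a local path (the path is a placeholder):

from imaginaire.utils.easy_io.easy_io import exists, isdir, isfile

path = "/tmp/demo.txt"
if exists(path):
    kind = "dir" if isdir(path) else ("file" if isfile(path) else "other")
    print(f"{path} exists and is a {kind}")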
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def join_path(
|
| 367 |
+
filepath: str | Path,
|
| 368 |
+
*filepaths: str | Path,
|
| 369 |
+
backend_args: dict | None = None,
|
| 370 |
+
backend_key: str | None = None,
|
| 371 |
+
) -> str | Path:
|
| 372 |
+
r"""Concatenate all file paths.
|
| 373 |
+
|
| 374 |
+
Join one or more filepath components intelligently. The return value
|
| 375 |
+
is the concatenation of filepath and any members of \*filepaths.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
filepath (str or Path): Path to be concatenated.
|
| 379 |
+
*filepaths (str or Path): Other paths to be concatenated.
|
| 380 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 381 |
+
corresponding backend. Defaults to None.
|
| 382 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
str: The result of concatenation.
|
| 386 |
+
|
| 387 |
+
Examples:
|
| 388 |
+
>>> filepath1 = '/path/of/dir1'
|
| 389 |
+
>>> filepath2 = 'dir2'
|
| 390 |
+
>>> filepath3 = 'path/of/file'
|
| 391 |
+
>>> join_path(filepath1, filepath2, filepath3)
|
| 392 |
+
'/path/of/dir1/dir2/path/of/file'
|
| 393 |
+
"""
|
| 394 |
+
backend = get_file_backend(
|
| 395 |
+
filepath,
|
| 396 |
+
backend_args=backend_args,
|
| 397 |
+
enable_singleton=True,
|
| 398 |
+
backend_key=backend_key,
|
| 399 |
+
)
|
| 400 |
+
return backend.join_path(filepath, *filepaths)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@contextmanager
|
| 404 |
+
def get_local_path(
|
| 405 |
+
filepath: str | Path,
|
| 406 |
+
backend_args: dict | None = None,
|
| 407 |
+
backend_key: str | None = None,
|
| 408 |
+
) -> Generator[str | Path, None, None]:
|
| 409 |
+
"""Download data from ``filepath`` and write the data to local path.
|
| 410 |
+
|
| 411 |
+
``get_local_path`` is decorated by :func:`contextlib.contextmanager`. It
|
| 412 |
+
can be used in a ``with`` statement, and when exiting the
|
| 413 |
+
``with`` statement, the temporary path will be released.
|
| 414 |
+
|
| 415 |
+
Note:
|
| 416 |
+
If the ``filepath`` is a local path, just return itself and it will
|
| 417 |
+
not be released (removed).
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
filepath (str or Path): Path to be read data.
|
| 421 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 422 |
+
corresponding backend. Defaults to None.
|
| 423 |
+
|
| 424 |
+
Yields:
|
| 425 |
+
Iterable[str]: Only yield one path.
|
| 426 |
+
|
| 427 |
+
Examples:
|
| 428 |
+
>>> with get_local_path('http://example.com/file.jpg') as path:
|
| 429 |
+
... # do something here
|
| 430 |
+
"""
|
| 431 |
+
backend = get_file_backend(
|
| 432 |
+
filepath,
|
| 433 |
+
backend_args=backend_args,
|
| 434 |
+
enable_singleton=True,
|
| 435 |
+
backend_key=backend_key,
|
| 436 |
+
)
|
| 437 |
+
with backend.get_local_path(str(filepath)) as local_path:
|
| 438 |
+
yield local_path
|
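A usage sketch, assuming a reachable URL: for a remote ``filepath`` the data is materialized to a temporary local path for the duration of the ``with`` block; for a local path the path is yielded unchanged and nothing is removed afterwards.

from imaginaire.utils.easy_io.easy_io import get_local_path

with get_local_path("http://example.com/file.jpg") as local_path:
    with open(local_path, "rb") as f:
        header = f.read(16)  # work with the local copy inside the block
# any temporary copy has been cleaned up at this point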
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def copyfile(
|
| 442 |
+
src: str | Path,
|
| 443 |
+
dst: str | Path,
|
| 444 |
+
backend_args: dict | None = None,
|
| 445 |
+
backend_key: str | None = None,
|
| 446 |
+
) -> str | Path:
|
| 447 |
+
"""Copy a file src to dst and return the destination file.
|
| 448 |
+
|
| 449 |
+
src and dst should have the same prefix. If dst specifies a directory,
|
| 450 |
+
the file will be copied into dst using the base filename from src. If
|
| 451 |
+
dst specifies a file that already exists, it will be replaced.
|
| 452 |
+
|
| 453 |
+
Args:
|
| 454 |
+
src (str or Path): A file to be copied.
|
| 455 |
+
dst (str or Path): Copy file to dst.
|
| 456 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 457 |
+
corresponding backend. Defaults to None.
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
str: The destination file.
|
| 461 |
+
|
| 462 |
+
Raises:
|
| 463 |
+
SameFileError: If src and dst are the same file, a SameFileError will
|
| 464 |
+
be raised.
|
| 465 |
+
|
| 466 |
+
Examples:
|
| 467 |
+
>>> # dst is a file
|
| 468 |
+
>>> src = '/path/of/file'
|
| 469 |
+
>>> dst = '/path1/of/file1'
|
| 470 |
+
>>> # src will be copied to '/path1/of/file1'
|
| 471 |
+
>>> copyfile(src, dst)
|
| 472 |
+
'/path1/of/file1'
|
| 473 |
+
|
| 474 |
+
>>> # dst is a directory
|
| 475 |
+
>>> dst = '/path1/of/dir'
|
| 476 |
+
>>> # src will be copied to '/path1/of/dir/file'
|
| 477 |
+
>>> copyfile(src, dst)
|
| 478 |
+
'/path1/of/dir/file'
|
| 479 |
+
"""
|
| 480 |
+
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 481 |
+
return backend.copyfile(src, dst)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def copytree(
|
| 485 |
+
src: str | Path,
|
| 486 |
+
dst: str | Path,
|
| 487 |
+
backend_args: dict | None = None,
|
| 488 |
+
backend_key: str | None = None,
|
| 489 |
+
) -> str | Path:
|
| 490 |
+
"""Recursively copy an entire directory tree rooted at src to a directory
|
| 491 |
+
named dst and return the destination directory.
|
| 492 |
+
|
| 493 |
+
src and dst should have the same prefix and dst must not already exist.
|
| 494 |
+
|
| 495 |
+
Args:
|
| 496 |
+
src (str or Path): A directory to be copied.
|
| 497 |
+
dst (str or Path): Copy directory to dst.
|
| 498 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 499 |
+
corresponding backend. Defaults to None.
|
| 500 |
+
backend_key (str, optional): The key to get the backend from the registry.
|
| 501 |
+
|
| 502 |
+
Returns:
|
| 503 |
+
str: The destination directory.
|
| 504 |
+
|
| 505 |
+
Raises:
|
| 506 |
+
FileExistsError: If dst had already existed, a FileExistsError will be
|
| 507 |
+
raised.
|
| 508 |
+
|
| 509 |
+
Examples:
|
| 510 |
+
>>> src = '/path/of/dir1'
|
| 511 |
+
>>> dst = '/path/of/dir2'
|
| 512 |
+
>>> copytree(src, dst)
|
| 513 |
+
'/path/of/dir2'
|
| 514 |
+
"""
|
| 515 |
+
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 516 |
+
return backend.copytree(src, dst)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def copyfile_from_local(
|
| 520 |
+
src: str | Path,
|
| 521 |
+
dst: str | Path,
|
| 522 |
+
backend_args: dict | None = None,
|
| 523 |
+
backend_key: str | None = None,
|
| 524 |
+
) -> str | Path:
|
| 525 |
+
"""Copy a local file src to dst and return the destination file.
|
| 526 |
+
|
| 527 |
+
Note:
|
| 528 |
+
If the backend is an instance of LocalBackend, it does the same
|
| 529 |
+
thing as :func:`copyfile`.
|
| 530 |
+
|
| 531 |
+
Args:
|
| 532 |
+
src (str or Path): A local file to be copied.
|
| 533 |
+
dst (str or Path): Copy file to dst.
|
| 534 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 535 |
+
corresponding backend. Defaults to None.
|
| 536 |
+
|
| 537 |
+
Returns:
|
| 538 |
+
str: If dst specifies a directory, the file will be copied into dst
|
| 539 |
+
using the base filename from src.
|
| 540 |
+
|
| 541 |
+
Examples:
|
| 542 |
+
>>> # dst is a file
|
| 543 |
+
>>> src = '/path/of/file'
|
| 544 |
+
>>> dst = 'http://example.com/file1'
|
| 545 |
+
>>> # src will be copied to 'http://example.com/file1'
|
| 546 |
+
>>> copyfile_from_local(src, dst)
|
| 547 |
+
'http://example.com/file1'
|
| 548 |
+
|
| 549 |
+
>>> # dst is a directory
|
| 550 |
+
>>> dst = 'http://example.com/dir'
|
| 551 |
+
>>> # src will be copied to 'http://example.com/dir/file'
|
| 552 |
+
>>> copyfile_from_local(src, dst)
|
| 553 |
+
'http://example.com/dir/file'
|
| 554 |
+
"""
|
| 555 |
+
backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 556 |
+
return backend.copyfile_from_local(src, dst)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def copytree_from_local(
|
| 560 |
+
src: str | Path,
|
| 561 |
+
dst: str | Path,
|
| 562 |
+
backend_args: dict | None = None,
|
| 563 |
+
backend_key: str | None = None,
|
| 564 |
+
) -> str | Path:
|
| 565 |
+
"""Recursively copy an entire directory tree rooted at src to a directory
|
| 566 |
+
named dst and return the destination directory.
|
| 567 |
+
|
| 568 |
+
Note:
|
| 569 |
+
If the backend is an instance of LocalBackend, it does the same
|
| 570 |
+
thing as :func:`copytree`.
|
| 571 |
+
|
| 572 |
+
Args:
|
| 573 |
+
src (str or Path): A local directory to be copied.
|
| 574 |
+
dst (str or Path): Copy directory to dst.
|
| 575 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 576 |
+
corresponding backend. Defaults to None.
|
| 577 |
+
|
| 578 |
+
Returns:
|
| 579 |
+
str: The destination directory.
|
| 580 |
+
|
| 581 |
+
Examples:
|
| 582 |
+
>>> src = '/path/of/dir'
|
| 583 |
+
>>> dst = 'http://example.com/dir'
|
| 584 |
+
>>> copytree_from_local(src, dst)
|
| 585 |
+
'http://example.com/dir'
|
| 586 |
+
"""
|
| 587 |
+
backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 588 |
+
return backend.copytree_from_local(src, dst)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def copyfile_to_local(
|
| 592 |
+
src: str | Path,
|
| 593 |
+
dst: str | Path,
|
| 594 |
+
dst_type: str, # Choose from ["file", "dir"]
|
| 595 |
+
backend_args: dict | None = None,
|
| 596 |
+
backend_key: str | None = None,
|
| 597 |
+
) -> str | Path:
|
| 598 |
+
"""Copy the file src to local dst and return the destination file.
|
| 599 |
+
|
| 600 |
+
If dst specifies a directory, the file will be copied into dst using
|
| 601 |
+
the base filename from src. If dst specifies a file that already
|
| 602 |
+
exists, it will be replaced.
|
| 603 |
+
|
| 604 |
+
Note:
|
| 605 |
+
If the backend is an instance of LocalBackend, it does the same
|
| 606 |
+
thing as :func:`copyfile`.
|
| 607 |
+
|
| 608 |
+
Args:
|
| 609 |
+
src (str or Path): A file to be copied.
|
| 610 |
+
dst (str or Path): Copy file to local dst.
dst_type (str): Whether ``dst`` names a "file" or a "dir".
|
| 611 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 612 |
+
corresponding backend. Defaults to None.
|
| 613 |
+
|
| 614 |
+
Returns:
|
| 615 |
+
str: If dst specifies a directory, the file will be copied into dst
|
| 616 |
+
using the base filename from src.
|
| 617 |
+
|
| 618 |
+
Examples:
|
| 619 |
+
>>> # dst is a file
|
| 620 |
+
>>> src = 'http://example.com/file'
|
| 621 |
+
>>> dst = '/path/of/file'
|
| 622 |
+
>>> # src will be copied to '/path/of/file'
|
| 623 |
+
>>> copyfile_to_local(src, dst)
|
| 624 |
+
'/path/of/file'
|
| 625 |
+
|
| 626 |
+
>>> # dst is a directory
|
| 627 |
+
>>> dst = '/path/of/dir'
|
| 628 |
+
>>> # src will be copied to '/path/of/dir/file'
|
| 629 |
+
>>> copyfile_to_local(src, dst)
|
| 630 |
+
'/path/of/dir/file'
|
| 631 |
+
"""
|
| 632 |
+
assert dst_type in ["file", "dir"]
|
| 633 |
+
Path(dst).parent.mkdir(parents=True, exist_ok=True)
|
| 634 |
+
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 635 |
+
return backend.copyfile_to_local(src, dst, dst_type=dst_type)
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def copytree_to_local(
|
| 639 |
+
src: str | Path,
|
| 640 |
+
dst: str | Path,
|
| 641 |
+
backend_args: dict | None = None,
|
| 642 |
+
backend_key: str | None = None,
|
| 643 |
+
) -> str | Path:
|
| 644 |
+
"""Recursively copy an entire directory tree rooted at src to a local
|
| 645 |
+
directory named dst and return the destination directory.
|
| 646 |
+
|
| 647 |
+
Note:
|
| 648 |
+
If the backend is an instance of LocalBackend, it does the same
|
| 649 |
+
thing as :func:`copytree`.
|
| 650 |
+
|
| 651 |
+
Args:
|
| 652 |
+
src (str or Path): A directory to be copied.
|
| 653 |
+
dst (str or Path): Copy directory to local dst.
|
| 654 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 655 |
+
corresponding backend. Defaults to None.
|
| 656 |
+
|
| 657 |
+
Returns:
|
| 658 |
+
str: The destination directory.
|
| 659 |
+
|
| 660 |
+
Examples:
|
| 661 |
+
>>> src = 'http://example.com/dir'
|
| 662 |
+
>>> dst = '/path/of/dir'
|
| 663 |
+
>>> copytree_to_local(src, dst)
|
| 664 |
+
'/path/of/dir'
|
| 665 |
+
"""
|
| 666 |
+
Path(dst).parent.mkdir(parents=True, exist_ok=True)
|
| 667 |
+
backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 668 |
+
return backend.copytree_to_local(src, dst)
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def remove(
|
| 672 |
+
filepath: str | Path,
|
| 673 |
+
backend_args: dict | None = None,
|
| 674 |
+
backend_key: str | None = None,
|
| 675 |
+
) -> None:
|
| 676 |
+
"""Remove a file.
|
| 677 |
+
|
| 678 |
+
Args:
|
| 679 |
+
filepath (str or Path): Path to be removed.
|
| 680 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 681 |
+
corresponding backend. Defaults to None.
|
| 682 |
+
|
| 683 |
+
Raises:
|
| 684 |
+
FileNotFoundError: If filepath does not exist, a FileNotFoundError
|
| 685 |
+
will be raised.
|
| 686 |
+
IsADirectoryError: If filepath is a directory, an IsADirectoryError
|
| 687 |
+
will be raised.
|
| 688 |
+
|
| 689 |
+
Examples:
|
| 690 |
+
>>> filepath = '/path/of/file'
|
| 691 |
+
>>> remove(filepath)
|
| 692 |
+
"""
|
| 693 |
+
backend = get_file_backend(
|
| 694 |
+
filepath,
|
| 695 |
+
backend_args=backend_args,
|
| 696 |
+
enable_singleton=True,
|
| 697 |
+
backend_key=backend_key,
|
| 698 |
+
)
|
| 699 |
+
backend.remove(filepath)
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def rmtree(
|
| 703 |
+
dir_path: str | Path,
|
| 704 |
+
backend_args: dict | None = None,
|
| 705 |
+
backend_key: str | None = None,
|
| 706 |
+
) -> None:
|
| 707 |
+
"""Recursively delete a directory tree.
|
| 708 |
+
|
| 709 |
+
Args:
|
| 710 |
+
dir_path (str or Path): A directory to be removed.
|
| 711 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 712 |
+
corresponding backend. Defaults to None.
|
| 713 |
+
|
| 714 |
+
Examples:
|
| 715 |
+
>>> dir_path = '/path/of/dir'
|
| 716 |
+
>>> rmtree(dir_path)
|
| 717 |
+
"""
|
| 718 |
+
backend = get_file_backend(
|
| 719 |
+
dir_path,
|
| 720 |
+
backend_args=backend_args,
|
| 721 |
+
enable_singleton=True,
|
| 722 |
+
backend_key=backend_key,
|
| 723 |
+
)
|
| 724 |
+
backend.rmtree(dir_path)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def copy_if_symlink_fails(
|
| 728 |
+
src: str | Path,
|
| 729 |
+
dst: str | Path,
|
| 730 |
+
backend_args: dict | None = None,
|
| 731 |
+
backend_key: str | None = None,
|
| 732 |
+
) -> bool:
|
| 733 |
+
"""Create a symbolic link pointing to src named dst.
|
| 734 |
+
|
| 735 |
+
If creating a symbolic link pointing to src fails, copy src directly to
|
| 736 |
+
dst instead.
|
| 737 |
+
|
| 738 |
+
Args:
|
| 739 |
+
src (str or Path): Create a symbolic link pointing to src.
|
| 740 |
+
dst (str or Path): Create a symbolic link named dst.
|
| 741 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 742 |
+
corresponding backend. Defaults to None.
|
| 743 |
+
|
| 744 |
+
Returns:
|
| 745 |
+
bool: Return True if successfully create a symbolic link pointing to
|
| 746 |
+
src. Otherwise, return False.
|
| 747 |
+
|
| 748 |
+
Examples:
|
| 749 |
+
>>> src = '/path/of/file'
|
| 750 |
+
>>> dst = '/path1/of/file1'
|
| 751 |
+
>>> copy_if_symlink_fails(src, dst)
|
| 752 |
+
True
|
| 753 |
+
>>> src = '/path/of/dir'
|
| 754 |
+
>>> dst = '/path1/of/dir1'
|
| 755 |
+
>>> copy_if_symlink_fails(src, dst)
|
| 756 |
+
True
|
| 757 |
+
"""
|
| 758 |
+
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 759 |
+
return backend.copy_if_symlink_fails(src, dst)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
def list_dir(
|
| 763 |
+
dir_path: str | Path,
|
| 764 |
+
backend_args: dict | None = None,
|
| 765 |
+
backend_key: str | None = None,
|
| 766 |
+
):
|
| 767 |
+
"""List all folders in a directory with a given path.
|
| 768 |
+
|
| 769 |
+
Args:
|
| 770 |
+
dir_path (str | Path): Path of the directory.
|
| 771 |
+
|
| 772 |
+
Examples:
|
| 773 |
+
>>> dir_path = '/path/of/dir'
|
| 774 |
+
>>> for file_path in list_dir(dir_path):
|
| 775 |
+
... print(file_path)
|
| 776 |
+
"""
|
| 777 |
+
dir_path = str(dir_path)  # normalize Path inputs before the string checks below
if not dir_path.endswith("/"):
|
| 778 |
+
dir_path += "/"
|
| 779 |
+
backend = get_file_backend(
|
| 780 |
+
dir_path,
|
| 781 |
+
backend_args=backend_args,
|
| 782 |
+
enable_singleton=True,
|
| 783 |
+
backend_key=backend_key,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
return backend.list_dir(dir_path)
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
def list_dir_or_file(
|
| 790 |
+
dir_path: str | Path,
|
| 791 |
+
list_dir: bool = True,
|
| 792 |
+
list_file: bool = True,
|
| 793 |
+
suffix: str | tuple[str] | None = None,
|
| 794 |
+
recursive: bool = False,
|
| 795 |
+
backend_args: dict | None = None,
|
| 796 |
+
backend_key: str | None = None,
|
| 797 |
+
) -> Iterator[str]:
|
| 798 |
+
"""Scan a directory to find the interested directories or files in
|
| 799 |
+
arbitrary order.
|
| 800 |
+
|
| 801 |
+
Note:
|
| 802 |
+
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
|
| 803 |
+
|
| 804 |
+
Args:
|
| 805 |
+
dir_path (str or Path): Path of the directory.
|
| 806 |
+
list_dir (bool): List the directories. Defaults to True.
|
| 807 |
+
list_file (bool): List the path of files. Defaults to True.
|
| 808 |
+
suffix (str or tuple[str], optional): File suffix that we are
|
| 809 |
+
interested in. Defaults to None.
|
| 810 |
+
recursive (bool): If set to True, recursively scan the directory.
|
| 811 |
+
Defaults to False.
|
| 812 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 813 |
+
corresponding backend. Defaults to None.
|
| 814 |
+
|
| 815 |
+
Yields:
|
| 816 |
+
Iterable[str]: A relative path to ``dir_path``.
|
| 817 |
+
|
| 818 |
+
Examples:
|
| 819 |
+
>>> dir_path = '/path/of/dir'
|
| 820 |
+
>>> for file_path in list_dir_or_file(dir_path):
|
| 821 |
+
... print(file_path)
|
| 822 |
+
>>> # list those files and directories in current directory
|
| 823 |
+
>>> for file_path in list_dir_or_file(dir_path):
|
| 824 |
+
... print(file_path)
|
| 825 |
+
>>> # only list files
|
| 826 |
+
>>> for file_path in list_dir_or_file(dir_path, list_dir=False):
|
| 827 |
+
... print(file_path)
|
| 828 |
+
>>> # only list directories
|
| 829 |
+
>>> for file_path in list_dir_or_file(dir_path, list_file=False):
|
| 830 |
+
... print(file_path)
|
| 831 |
+
>>> # only list files ending with specified suffixes
|
| 832 |
+
>>> for file_path in list_dir_or_file(dir_path, suffix='.txt'):
|
| 833 |
+
... print(file_path)
|
| 834 |
+
>>> # list all files and directory recursively
|
| 835 |
+
>>> for file_path in list_dir_or_file(dir_path, recursive=True):
|
| 836 |
+
... print(file_path)
|
| 837 |
+
"""
|
| 838 |
+
backend = get_file_backend(
|
| 839 |
+
dir_path,
|
| 840 |
+
backend_args=backend_args,
|
| 841 |
+
enable_singleton=True,
|
| 842 |
+
backend_key=backend_key,
|
| 843 |
+
)
|
| 844 |
+
yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
|
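A sketch of a common pattern with the generator above (the directory path is a placeholder): collect all ``.json`` files under a tree. The yielded paths are relative to ``dir_path``, so join them back before opening.

import os
from imaginaire.utils.easy_io.easy_io import list_dir_or_file

root = "/tmp/data"
json_files = [
    os.path.join(root, rel)
    for rel in list_dir_or_file(root, list_dir=False, suffix=".json", recursive=True)
]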
| 845 |
+
|
| 846 |
+
|
| 847 |
+
def load(
|
| 848 |
+
file: str | Path | IO[Any],
|
| 849 |
+
file_format: str | None = None,
|
| 850 |
+
file_client_args: dict | None = None,
|
| 851 |
+
fast_backend: bool = False,
|
| 852 |
+
backend_args: dict | None = None,
|
| 853 |
+
backend_key: str | None = None,
|
| 854 |
+
**kwargs,
|
| 855 |
+
):
|
| 856 |
+
"""Load data from json/yaml/pickle files.
|
| 857 |
+
|
| 858 |
+
This method provides a unified API for loading data from serialized files.
|
| 859 |
+
|
| 860 |
+
``load`` supports loading data from serialized files that can be stored
|
| 861 |
+
in different backends.
|
| 862 |
+
|
| 863 |
+
Args:
|
| 864 |
+
file (str or :obj:`Path` or file-like object): Filename or a file-like
|
| 865 |
+
object.
|
| 866 |
+
file_format (str, optional): If not specified, the file format will be
|
| 867 |
+
inferred from the file extension, otherwise use the specified one.
|
| 868 |
+
Currently supported formats include "json", "yaml/yml" and
|
| 869 |
+
"pickle/pkl".
|
| 870 |
+
file_client_args (dict, optional): Arguments to instantiate a
|
| 871 |
+
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
| 872 |
+
Defaults to None. It will be deprecated in future. Please use
|
| 873 |
+
``backend_args`` instead.
|
| 874 |
+
fast_backend (bool): Whether to use the multiprocess fast path. Defaults to False.
|
| 875 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 876 |
+
backend corresponding to the prefix of the uri. Defaults to None.
|
| 877 |
+
New in v0.2.0.
|
| 878 |
+
|
| 879 |
+
Examples:
|
| 880 |
+
>>> load('/path/of/your/file') # file is stored on disk
|
| 881 |
+
>>> load('https://path/of/your/file') # file is stored on the Internet
|
| 882 |
+
|
| 883 |
+
Returns:
|
| 884 |
+
The content from the file.
|
| 885 |
+
"""
|
| 886 |
+
if isinstance(file, Path):
|
| 887 |
+
file = str(file)
|
| 888 |
+
if file_format is None and isinstance(file, str):
|
| 889 |
+
file_format = file.split(".")[-1]
|
| 890 |
+
# convert file_format to lower case
|
| 891 |
+
file_format = file_format.lower()
|
| 892 |
+
if file_format not in file_handlers:
|
| 893 |
+
raise TypeError(f"Unsupported format: {file_format}")
|
| 894 |
+
|
| 895 |
+
if file_client_args is not None:
|
| 896 |
+
warnings.warn( # noqa: B028
|
| 897 |
+
'"file_client_args" will be deprecated in future. Please use "backend_args" instead',
|
| 898 |
+
DeprecationWarning,
|
| 899 |
+
)
|
| 900 |
+
if backend_args is not None:
|
| 901 |
+
raise ValueError('"file_client_args and "backend_args" cannot be set at the same time.')
|
| 902 |
+
|
| 903 |
+
handler = file_handlers[file_format]
|
| 904 |
+
if isinstance(file, str):
|
| 905 |
+
if file_client_args is not None:
|
| 906 |
+
file_client = FileClient.infer_client(file_client_args, file)
|
| 907 |
+
file_backend = file_client
|
| 908 |
+
else:
|
| 909 |
+
file_backend = get_file_backend(
|
| 910 |
+
file,
|
| 911 |
+
backend_args=backend_args,
|
| 912 |
+
backend_key=backend_key,
|
| 913 |
+
enable_singleton=True,
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
if handler.str_like:
|
| 917 |
+
with StringIO(file_backend.get_text(file)) as f:
|
| 918 |
+
obj = handler.load_from_fileobj(f, **kwargs)
|
| 919 |
+
else:
|
| 920 |
+
if fast_backend:
|
| 921 |
+
if hasattr(file_backend, "fast_get"):
|
| 922 |
+
with BytesIO(file_backend.fast_get(file)) as f:
|
| 923 |
+
obj = handler.load_from_fileobj(f, **kwargs)
|
| 924 |
+
else:
|
| 925 |
+
warnings.warn( # noqa: B028
|
| 926 |
+
f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get"
|
| 927 |
+
)
|
| 928 |
+
with BytesIO(file_backend.get(file)) as f:
|
| 929 |
+
obj = handler.load_from_fileobj(f, **kwargs)
|
| 930 |
+
else:
|
| 931 |
+
with BytesIO(file_backend.get(file)) as f:
|
| 932 |
+
obj = handler.load_from_fileobj(f, **kwargs)
|
| 933 |
+
elif hasattr(file, "read"):
|
| 934 |
+
obj = handler.load_from_fileobj(file, **kwargs)
|
| 935 |
+
else:
|
| 936 |
+
raise TypeError('"file" must be a filepath str or a file-object')
|
| 937 |
+
return obj
|
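A minimal sketch of ``load`` (the paths are placeholders): the format is inferred from the extension unless ``file_format`` overrides it.

from imaginaire.utils.easy_io.easy_io import load

cfg = load("/tmp/config.yaml")                    # format inferred from ".yaml"
meta = load("/tmp/metadata", file_format="json")  # extension-less file, explicit format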
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def dump(
|
| 941 |
+
obj: Any,
|
| 942 |
+
file: str | Path | IO[Any] | None = None,
|
| 943 |
+
file_format: str | None = None,
|
| 944 |
+
file_client_args: dict | None = None,
|
| 945 |
+
fast_backend: bool = False,
|
| 946 |
+
backend_args: dict | None = None,
|
| 947 |
+
backend_key: str | None = None,
|
| 948 |
+
**kwargs,
|
| 949 |
+
):
|
| 950 |
+
"""Dump data to json/yaml/pickle strings or files.
|
| 951 |
+
|
| 952 |
+
This method provides a unified API for dumping data as strings or to files,
|
| 953 |
+
and also supports custom arguments for each file format.
|
| 954 |
+
|
| 955 |
+
``dump`` supports dumping data as strings or to files, which are saved to
|
| 956 |
+
different backends.
|
| 957 |
+
|
| 958 |
+
Args:
|
| 959 |
+
obj (any): The python object to be dumped.
|
| 960 |
+
file (str or :obj:`Path` or file-like object, optional): If not
|
| 961 |
+
specified, then the object is dumped to a str, otherwise to a file
|
| 962 |
+
specified by the filename or file-like object.
|
| 963 |
+
file_format (str, optional): Same as :func:`load`.
|
| 964 |
+
file_client_args (dict, optional): Arguments to instantiate a
|
| 965 |
+
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
| 966 |
+
Defaults to None. It will be deprecated in future. Please use
|
| 967 |
+
``backend_args`` instead.
|
| 968 |
+
fast_backend (bool): Whether to use the multiprocess fast path. Defaults to False.
|
| 969 |
+
backend_args (dict, optional): Arguments to instantiate the
|
| 970 |
+
backend corresponding to the prefix of the uri. Defaults to None.
|
| 971 |
+
New in v0.2.0.
|
| 972 |
+
backend_key (str, optional): The key to register the backend. Defaults to None.
|
| 973 |
+
|
| 974 |
+
Examples:
|
| 975 |
+
>>> dump('hello world', '/path/of/your/file') # disk
|
| 976 |
+
>>> dump('hello world', 'http://path/of/your/file') # http
|
| 977 |
+
|
| 978 |
+
Returns:
|
| 979 |
+
str or None: The dumped string when ``file`` is None, otherwise None.
|
| 980 |
+
"""
|
| 981 |
+
if isinstance(file, Path):
|
| 982 |
+
file = str(file)
|
| 983 |
+
if file_format is None:
|
| 984 |
+
if isinstance(file, str):
|
| 985 |
+
file_format = file.split(".")[-1]
|
| 986 |
+
elif file is None:
|
| 987 |
+
raise ValueError("file_format must be specified since file is None")
|
| 988 |
+
# convert file_format to lower case
|
| 989 |
+
file_format = file_format.lower()
|
| 990 |
+
if file_format not in file_handlers:
|
| 991 |
+
raise TypeError(f"Unsupported format: {file_format}")
|
| 992 |
+
|
| 993 |
+
if file_client_args is not None:
|
| 994 |
+
warnings.warn( # noqa: B028
|
| 995 |
+
'"file_client_args" will be deprecated in future. Please use "backend_args" instead',
|
| 996 |
+
DeprecationWarning,
|
| 997 |
+
)
|
| 998 |
+
if backend_args is not None:
|
| 999 |
+
raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
|
| 1000 |
+
|
| 1001 |
+
handler = file_handlers[file_format]
|
| 1002 |
+
if file is None:
|
| 1003 |
+
return handler.dump_to_str(obj, **kwargs)
|
| 1004 |
+
elif isinstance(file, str):
|
| 1005 |
+
if file_client_args is not None:
|
| 1006 |
+
file_client = FileClient.infer_client(file_client_args, file)
|
| 1007 |
+
file_backend = file_client
|
| 1008 |
+
else:
|
| 1009 |
+
file_backend = get_file_backend(
|
| 1010 |
+
file,
|
| 1011 |
+
backend_args=backend_args,
|
| 1012 |
+
backend_key=backend_key,
|
| 1013 |
+
enable_singleton=True,
|
| 1014 |
+
)
|
| 1015 |
+
|
| 1016 |
+
if handler.str_like:
|
| 1017 |
+
with StringIO() as f:
|
| 1018 |
+
handler.dump_to_fileobj(obj, f, **kwargs)
|
| 1019 |
+
file_backend.put_text(f.getvalue(), file)
|
| 1020 |
+
else:
|
| 1021 |
+
with BytesIO() as f:
|
| 1022 |
+
handler.dump_to_fileobj(obj, f, **kwargs)
|
| 1023 |
+
if fast_backend:
|
| 1024 |
+
if hasattr(file_backend, "fast_put"):
|
| 1025 |
+
file_backend.fast_put(f, file)
|
| 1026 |
+
else:
|
| 1027 |
+
warnings.warn("fast_backend is not supported by the backend, fallback to normal put") # noqa: B028
|
| 1028 |
+
file_backend.put(f, file)
|
| 1029 |
+
else:
|
| 1030 |
+
file_backend.put(f, file)
|
| 1031 |
+
elif hasattr(file, "write"):
|
| 1032 |
+
handler.dump_to_fileobj(obj, file, **kwargs)
|
| 1033 |
+
else:
|
| 1034 |
+
raise TypeError('"file" must be a filename str or a file-object')
|
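And the symmetric sketch for ``dump`` (the paths are placeholders): with a file target the object is written through the resolved backend; with ``file=None`` the serialized string is returned instead.

from imaginaire.utils.easy_io.easy_io import dump

dump({"lr": 1e-4}, "/tmp/config.json")        # write through the local backend
s = dump({"lr": 1e-4}, file_format="json")    # no file -> returns the JSON string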
imaginaire/utils/easy_io/file_client.py
ADDED
|
@@ -0,0 +1,448 @@
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
from collections.abc import Generator, Iterator
|
| 18 |
+
from contextlib import contextmanager
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
from imaginaire.utils.easy_io.backends import BaseStorageBackend, HTTPBackend, LocalBackend
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def is_filepath(filepath):
|
| 26 |
+
return isinstance(filepath, (str, Path))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class HardDiskBackend(LocalBackend):
|
| 30 |
+
"""Raw hard disks storage backend."""
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
def name(self):
|
| 34 |
+
return self.__class__.__name__
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FileClient:
|
| 38 |
+
"""A general file client to access files in different backends.
|
| 39 |
+
|
| 40 |
+
The client loads a file or text in a specified backend from its path
|
| 41 |
+
and returns it as a binary or text file. There are two ways to choose a
|
| 42 |
+
backend, the name of backend and the prefix of path. Although both of them
|
| 43 |
+
can be used to choose a storage backend, ``backend`` has a higher priority
|
| 44 |
+
that is if they are all set, the storage backend will be chosen by the
|
| 45 |
+
backend argument. If they are all `None`, the disk backend will be chosen.
|
| 46 |
+
Note that It can also register other backend accessor with a given name,
|
| 47 |
+
prefixes, and backend class. In addition, We use the singleton pattern to
|
| 48 |
+
avoid repeated object creation. If the arguments are the same, the same
|
| 49 |
+
object will be returned.
|
| 50 |
+
|
| 51 |
+
Warning:
|
| 52 |
+
`FileClient` will be deprecated in future. Please use io functions
|
| 53 |
+
in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
backend (str, optional): The storage backend type. Options are "disk",
|
| 57 |
+
"memcached", "lmdb" and "http". Defaults to None.
|
| 58 |
+
prefix (str, optional): The prefix of the registered storage backend.
|
| 59 |
+
Options are "http", "https". Defaults to None.
|
| 60 |
+
|
| 61 |
+
Examples:
|
| 62 |
+
>>> # only set backend
|
| 63 |
+
>>> file_client = FileClient(backend='disk')
|
| 64 |
+
>>> # only set prefix
|
| 65 |
+
>>> file_client = FileClient(prefix='http')
|
| 66 |
+
>>> # set both backend and prefix but use backend to choose client
|
| 67 |
+
>>> file_client = FileClient(backend='http', prefix='http')
|
| 68 |
+
>>> # if the arguments are the same, the same object is returned
|
| 69 |
+
>>> file_client1 = FileClient(backend='disk')
|
| 70 |
+
>>> file_client1 is file_client
|
| 71 |
+
True
|
| 72 |
+
|
| 73 |
+
Attributes:
|
| 74 |
+
client (:obj:`BaseStorageBackend`): The backend object.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
_backends = { # noqa: RUF012
|
| 78 |
+
"disk": HardDiskBackend,
|
| 79 |
+
"http": HTTPBackend,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
_prefix_to_backends: dict = { # noqa: RUF012
|
| 83 |
+
"http": HTTPBackend,
|
| 84 |
+
"https": HTTPBackend,
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
_instances: dict = {} # noqa: RUF012
|
| 88 |
+
|
| 89 |
+
client: Any
|
| 90 |
+
|
| 91 |
+
def __new__(cls, backend=None, prefix=None, **kwargs):
|
| 92 |
+
if backend is None and prefix is None:
|
| 93 |
+
backend = "disk"
|
| 94 |
+
if backend is not None and backend not in cls._backends:
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"Backend {backend} is not supported. Currently supported ones are {list(cls._backends.keys())}"
|
| 97 |
+
)
|
| 98 |
+
if prefix is not None and prefix not in cls._prefix_to_backends:
|
| 99 |
+
raise ValueError(
|
| 100 |
+
f"prefix {prefix} is not supported. Currently supported ones are {list(cls._prefix_to_backends.keys())}"
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# concatenate the arguments to a unique key for determining whether
|
| 104 |
+
# objects with the same arguments were created
|
| 105 |
+
arg_key = f"{backend}:{prefix}"
|
| 106 |
+
for key, value in kwargs.items():
|
| 107 |
+
arg_key += f":{key}:{value}"
|
| 108 |
+
|
| 109 |
+
# if a backend was overridden, it will create a new object
|
| 110 |
+
if arg_key in cls._instances:
|
| 111 |
+
_instance = cls._instances[arg_key]
|
| 112 |
+
else:
|
| 113 |
+
# create a new object and put it into _instances
|
| 114 |
+
_instance = super().__new__(cls)
|
| 115 |
+
if backend is not None:
|
| 116 |
+
_instance.client = cls._backends[backend](**kwargs)
|
| 117 |
+
else:
|
| 118 |
+
_instance.client = cls._prefix_to_backends[prefix](**kwargs)
|
| 119 |
+
|
| 120 |
+
cls._instances[arg_key] = _instance
|
| 121 |
+
|
| 122 |
+
return _instance
|
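A short sketch of the singleton behavior implemented in ``__new__`` above: equal constructor arguments map to the same ``arg_key`` and return the cached instance.

from imaginaire.utils.easy_io.file_client import FileClient

c1 = FileClient(backend="disk")
c2 = FileClient()              # backend and prefix both None -> defaults to "disk"
assert c1 is c2                # same arg_key -> same cached object
c3 = FileClient(prefix="http")
assert c3 is not c1            # different arg_key -> a separate instance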
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def name(self):
|
| 126 |
+
return self.client.name
|
| 127 |
+
|
| 128 |
+
@property
|
| 129 |
+
def allow_symlink(self):
|
| 130 |
+
return self.client.allow_symlink
|
| 131 |
+
|
| 132 |
+
@staticmethod
|
| 133 |
+
def parse_uri_prefix(uri: str | Path) -> str | None:
|
| 134 |
+
"""Parse the prefix of a uri.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
uri (str | Path): Uri to be parsed that contains the file prefix.
|
| 138 |
+
|
| 139 |
+
Examples:
|
| 140 |
+
>>> FileClient.parse_uri_prefix('http://path/of/your/file')
|
| 141 |
+
'http'
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
str | None: Return the prefix of uri if the uri contains '://' else
|
| 145 |
+
``None``.
|
| 146 |
+
"""
|
| 147 |
+
assert is_filepath(uri)
|
| 148 |
+
uri = str(uri)
|
| 149 |
+
if "://" not in uri:
|
| 150 |
+
return None
|
| 151 |
+
else:
|
| 152 |
+
prefix, _ = uri.split("://")
|
| 153 |
+
return prefix
|
| 154 |
+
|
| 155 |
+
@classmethod
|
| 156 |
+
def infer_client(
|
| 157 |
+
cls,
|
| 158 |
+
file_client_args: dict | None = None,
|
| 159 |
+
uri: str | Path | None = None,
|
| 160 |
+
) -> "FileClient":
|
| 161 |
+
"""Infer a suitable file client based on the URI and arguments.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
file_client_args (dict, optional): Arguments to instantiate a
|
| 165 |
+
FileClient. Defaults to None.
|
| 166 |
+
uri (str | Path, optional): Uri to be parsed that contains the file
|
| 167 |
+
prefix. Defaults to None.
|
| 168 |
+
|
| 169 |
+
Examples:
|
| 170 |
+
>>> uri = 'http://path/of/your/file'
|
| 171 |
+
>>> file_client = FileClient.infer_client(uri=uri)
|
| 172 |
+
>>> file_client_args = {'backend': 'disk'}
|
| 173 |
+
>>> file_client = FileClient.infer_client(file_client_args)
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
FileClient: Instantiated FileClient object.
|
| 177 |
+
"""
|
| 178 |
+
assert file_client_args is not None or uri is not None
|
| 179 |
+
if file_client_args is None:
|
| 180 |
+
file_prefix = cls.parse_uri_prefix(uri) # type: ignore
|
| 181 |
+
return cls(prefix=file_prefix)
|
| 182 |
+
else:
|
| 183 |
+
return cls(**file_client_args)
|
| 184 |
+
|
| 185 |
+
@classmethod
|
| 186 |
+
def _register_backend(cls, name, backend, force=False, prefixes=None):
|
| 187 |
+
if not isinstance(name, str):
|
| 188 |
+
raise TypeError(f"the backend name should be a string, but got {type(name)}")
|
| 189 |
+
if not inspect.isclass(backend):
|
| 190 |
+
raise TypeError(f"backend should be a class but got {type(backend)}")
|
| 191 |
+
if not issubclass(backend, BaseStorageBackend):
|
| 192 |
+
raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")
|
| 193 |
+
if not force and name in cls._backends:
|
| 194 |
+
raise KeyError(
|
| 195 |
+
f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
if name in cls._backends and force:
|
| 199 |
+
for arg_key, instance in list(cls._instances.items()):
|
| 200 |
+
if isinstance(instance.client, cls._backends[name]):
|
| 201 |
+
cls._instances.pop(arg_key)
|
| 202 |
+
cls._backends[name] = backend
|
| 203 |
+
|
| 204 |
+
if prefixes is not None:
|
| 205 |
+
if isinstance(prefixes, str):
|
| 206 |
+
prefixes = [prefixes]
|
| 207 |
+
else:
|
| 208 |
+
assert isinstance(prefixes, (list, tuple))
|
| 209 |
+
for prefix in prefixes:
|
| 210 |
+
if prefix not in cls._prefix_to_backends:
|
| 211 |
+
cls._prefix_to_backends[prefix] = backend
|
| 212 |
+
elif (prefix in cls._prefix_to_backends) and force:
|
| 213 |
+
overridden_backend = cls._prefix_to_backends[prefix]
|
| 214 |
+
for arg_key, instance in list(cls._instances.items()):
|
| 215 |
+
if isinstance(instance.client, overridden_backend):
|
| 216 |
+
cls._instances.pop(arg_key)
|
| 217 |
+
else:
|
| 218 |
+
raise KeyError(
|
| 219 |
+
f"{prefix} is already registered as a storage backend,"
|
| 220 |
+
' add "force=True" if you want to override it'
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
@classmethod
|
| 224 |
+
def register_backend(cls, name, backend=None, force=False, prefixes=None):
|
| 225 |
+
"""Register a backend to FileClient.
|
| 226 |
+
|
| 227 |
+
This method can be used as a normal class method or a decorator.
|
| 228 |
+
|
| 229 |
+
.. code-block:: python
|
| 230 |
+
|
| 231 |
+
class NewBackend(BaseStorageBackend):
|
| 232 |
+
|
| 233 |
+
def get(self, filepath):
|
| 234 |
+
return filepath
|
| 235 |
+
|
| 236 |
+
def get_text(self, filepath):
|
| 237 |
+
return filepath
|
| 238 |
+
|
| 239 |
+
FileClient.register_backend('new', NewBackend)
|
| 240 |
+
|
| 241 |
+
or
|
| 242 |
+
|
| 243 |
+
.. code-block:: python
|
| 244 |
+
|
| 245 |
+
@FileClient.register_backend('new')
|
| 246 |
+
class NewBackend(BaseStorageBackend):
|
| 247 |
+
|
| 248 |
+
def get(self, filepath):
|
| 249 |
+
return filepath
|
| 250 |
+
|
| 251 |
+
def get_text(self, filepath):
|
| 252 |
+
return filepath
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
name (str): The name of the registered backend.
|
| 256 |
+
backend (class, optional): The backend class to be registered,
|
| 257 |
+
which must be a subclass of :class:`BaseStorageBackend`.
|
| 258 |
+
When this method is used as a decorator, backend is None.
|
| 259 |
+
Defaults to None.
|
| 260 |
+
force (bool, optional): Whether to override the backend if the name
|
| 261 |
+
has already been registered. Defaults to False.
|
| 262 |
+
prefixes (str or list[str] or tuple[str], optional): The prefixes
|
| 263 |
+
of the registered storage backend. Defaults to None.
|
| 264 |
+
`New in version 1.3.15.`
|
| 265 |
+
"""
|
| 266 |
+
if backend is not None:
|
| 267 |
+
cls._register_backend(name, backend, force=force, prefixes=prefixes)
|
| 268 |
+
return
|
| 269 |
+
|
| 270 |
+
def _register(backend_cls):
|
| 271 |
+
cls._register_backend(name, backend_cls, force=force, prefixes=prefixes)
|
| 272 |
+
return backend_cls
|
| 273 |
+
|
| 274 |
+
return _register
|
| 275 |
+
|
| 276 |
+
def get(self, filepath: str | Path) -> bytes | memoryview:
|
| 277 |
+
"""Read data from a given ``filepath`` with 'rb' mode.
|
| 278 |
+
|
| 279 |
+
Note:
|
| 280 |
+
There are two types of return values for ``get``, one is ``bytes``
|
| 281 |
+
and the other is ``memoryview``. The advantage of using memoryview
|
| 282 |
+
is that you can avoid copying, and if you want to convert it to
|
| 283 |
+
``bytes``, you can use ``.tobytes()``.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
filepath (str or Path): Path to read data.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
bytes | memoryview: Expected bytes object or a memory view of the
|
| 290 |
+
bytes object.
|
| 291 |
+
"""
|
| 292 |
+
return self.client.get(filepath)
|
| 293 |
+
|
| 294 |
+
def get_text(self, filepath: str | Path, encoding="utf-8") -> str:
|
| 295 |
+
"""Read data from a given ``filepath`` with 'r' mode.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
filepath (str or Path): Path to read data.
|
| 299 |
+
encoding (str): The encoding format used to open the ``filepath``.
|
| 300 |
+
Defaults to 'utf-8'.
|
| 301 |
+
|
| 302 |
+
Returns:
|
| 303 |
+
str: Expected text reading from ``filepath``.
|
| 304 |
+
"""
|
| 305 |
+
return self.client.get_text(filepath, encoding)
|
| 306 |
+
|
| 307 |
+
def put(self, obj: bytes, filepath: str | Path) -> None:
|
| 308 |
+
"""Write data to a given ``filepath`` with 'wb' mode.
|
| 309 |
+
|
| 310 |
+
Note:
|
| 311 |
+
``put`` should create a directory if the directory of ``filepath``
|
| 312 |
+
does not exist.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
obj (bytes): Data to be written.
|
| 316 |
+
filepath (str or Path): Path to write data.
|
| 317 |
+
"""
|
| 318 |
+
self.client.put(obj, filepath)
|
| 319 |
+
|
| 320 |
+
def put_text(self, obj: str, filepath: str | Path) -> None:
|
| 321 |
+
"""Write data to a given ``filepath`` with 'w' mode.
|
| 322 |
+
|
| 323 |
+
Note:
|
| 324 |
+
``put_text`` should create a directory if the directory of
|
| 325 |
+
``filepath`` does not exist.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
obj (str): Data to be written.
|
| 329 |
+
filepath (str or Path): Path to write data.
|
| 330 |
+
encoding (str, optional): The encoding format used to open the
|
| 331 |
+
`filepath`. Defaults to 'utf-8'.
|
| 332 |
+
"""
|
| 333 |
+
self.client.put_text(obj, filepath)
|
| 334 |
+
|
| 335 |
+
def remove(self, filepath: str | Path) -> None:
|
| 336 |
+
"""Remove a file.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
filepath (str, Path): Path to be removed.
|
| 340 |
+
"""
|
| 341 |
+
self.client.remove(filepath)
|
| 342 |
+
|
| 343 |
+
def exists(self, filepath: str | Path) -> bool:
|
| 344 |
+
"""Check whether a file path exists.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
filepath (str or Path): Path to be checked whether exists.
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
|
| 351 |
+
"""
|
| 352 |
+
return self.client.exists(filepath)
|
| 353 |
+
|
| 354 |
+
def isdir(self, filepath: str | Path) -> bool:
|
| 355 |
+
"""Check whether a file path is a directory.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
filepath (str or Path): Path to be checked whether it is a
|
| 359 |
+
directory.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
bool: Return ``True`` if ``filepath`` points to a directory,
|
| 363 |
+
``False`` otherwise.
|
| 364 |
+
"""
|
| 365 |
+
return self.client.isdir(filepath)
|
| 366 |
+
|
| 367 |
+
def isfile(self, filepath: str | Path) -> bool:
|
| 368 |
+
"""Check whether a file path is a file.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
filepath (str or Path): Path to be checked whether it is a file.
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
bool: Return ``True`` if ``filepath`` points to a file, ``False``
|
| 375 |
+
otherwise.
|
| 376 |
+
"""
|
| 377 |
+
return self.client.isfile(filepath)
|
| 378 |
+
|
| 379 |
+
def join_path(self, filepath: str | Path, *filepaths: str | Path) -> str:
|
| 380 |
+
r"""Concatenate all file paths.
|
| 381 |
+
|
| 382 |
+
Join one or more filepath components intelligently. The return value
|
| 383 |
+
is the concatenation of filepath and any members of \*filepaths.
|
| 384 |
+
|
| 385 |
+
Args:
|
| 386 |
+
filepath (str or Path): Path to be concatenated.
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
str: The result of concatenation.
|
| 390 |
+
"""
|
| 391 |
+
return self.client.join_path(filepath, *filepaths)
|
| 392 |
+
|
| 393 |
+
@contextmanager
|
| 394 |
+
def get_local_path(self, filepath: str | Path) -> Generator[str | Path, None, None]:
|
| 395 |
+
"""Download data from ``filepath`` and write the data to local path.
|
| 396 |
+
|
| 397 |
+
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
|
| 398 |
+
can be called with ``with`` statement, and when exists from the
|
| 399 |
+
``with`` statement, the temporary path will be released.
|
| 400 |
+
|
| 401 |
+
Note:
|
| 402 |
+
If the ``filepath`` is a local path, just return itself.
|
| 403 |
+
|
| 404 |
+
.. warning::
|
| 405 |
+
``get_local_path`` is an experimental interface that may change in
|
| 406 |
+
the future.
|
| 407 |
+
|
| 408 |
+
Args:
|
| 409 |
+
filepath (str or Path): Path to be read data.
|
| 410 |
+
|
| 411 |
+
Examples:
|
| 412 |
+
>>> file_client = FileClient(prefix='http')
|
| 413 |
+
>>> with file_client.get_local_path('http://example.com/abc.jpg') as path:
|
| 414 |
+
... # do something here
|
| 415 |
+
|
| 416 |
+
Yields:
|
| 417 |
+
Iterable[str]: Only yield one path.
|
| 418 |
+
"""
|
| 419 |
+
with self.client.get_local_path(str(filepath)) as local_path:
|
| 420 |
+
yield local_path
|
| 421 |
+
|
| 422 |
+
def list_dir_or_file( # pylint: disable=too-many-arguments
|
| 423 |
+
self,
|
| 424 |
+
dir_path: str | Path,
|
| 425 |
+
list_dir: bool = True,
|
| 426 |
+
list_file: bool = True,
|
| 427 |
+
suffix: str | tuple[str] | None = None,
|
| 428 |
+
recursive: bool = False,
|
| 429 |
+
) -> Iterator[str]:
|
| 430 |
+
"""Scan a directory to find the interested directories or files in
|
| 431 |
+
arbitrary order.
|
| 432 |
+
|
| 433 |
+
Note:
|
| 434 |
+
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
|
| 435 |
+
|
| 436 |
+
Args:
|
| 437 |
+
dir_path (str | Path): Path of the directory.
|
| 438 |
+
list_dir (bool): List the directories. Defaults to True.
|
| 439 |
+
list_file (bool): List the path of files. Defaults to True.
|
| 440 |
+
suffix (str or tuple[str], optional): File suffix
|
| 441 |
+
that we are interested in. Defaults to None.
|
| 442 |
+
recursive (bool): If set to True, recursively scan the
|
| 443 |
+
directory. Defaults to False.
|
| 444 |
+
|
| 445 |
+
Yields:
|
| 446 |
+
Iterable[str]: A relative path to ``dir_path``.
|
| 447 |
+
"""
|
| 448 |
+
yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
|
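As a quick orientation for `FileClient`, here is a minimal sketch of registering a custom backend and resolving it by URI prefix. The `EchoBackend` class and the `echo://` scheme are invented for illustration, and the `BaseStorageBackend` import path is assumed from this repo's layout:

from imaginaire.utils.easy_io.backends import BaseStorageBackend
from imaginaire.utils.easy_io.file_client import FileClient

@FileClient.register_backend("echo", prefixes="echo")  # hypothetical backend name
class EchoBackend(BaseStorageBackend):
    def get(self, filepath):
        return str(filepath).encode()  # pretend the path itself is the payload

    def get_text(self, filepath):
        return str(filepath)

# infer_client parses the "echo" prefix and instantiates the matching backend
client = FileClient.infer_client(uri="echo://some/file")
assert client.get_text("echo://some/file") == "echo://some/file"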
imaginaire/utils/easy_io/handlers/__init__.py
ADDED
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
from imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler
from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler

__all__ = [
    "BaseFileHandler",
    "JsonHandler",
    "PickleHandler",
    "YamlHandler",
    "file_handlers",
    "register_handler",
]
imaginaire/utils/easy_io/handlers/base.py
ADDED
@@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod


class BaseFileHandler(metaclass=ABCMeta):
    # `str_like` is a flag to indicate whether the type of file object is
    # str-like object or bytes-like object. Pickle only processes bytes-like
    # objects but json only processes str-like object. If it is str-like
    # object, `StringIO` will be used to process the buffer.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        pass

    def load_from_path(self, filepath, mode="r", **kwargs):
        with open(filepath, mode) as f:
            return self.load_from_fileobj(f, **kwargs)

    def dump_to_path(self, obj, filepath, mode="w", **kwargs):
        with open(filepath, mode) as f:
            self.dump_to_fileobj(obj, f, **kwargs)
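To make the contract above concrete, here is a small sketch of a custom handler built on `BaseFileHandler`; the `UpperTxtHandler` class and the `demo.txt` path are illustrative:

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler

class UpperTxtHandler(BaseFileHandler):
    """Hypothetical handler that upper-cases text on load."""

    str_like = True  # buffers go through StringIO, not BytesIO

    def load_from_fileobj(self, file, **kwargs):
        return file.read().upper()

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)

handler = UpperTxtHandler()
handler.dump_to_path("hello", "demo.txt")   # uses the base class's "w" mode
print(handler.load_from_path("demo.txt"))   # -> "HELLO"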
imaginaire/utils/easy_io/handlers/byte_handler.py
ADDED
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import IO

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class ByteHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file: IO[bytes], **kwargs):
        file.seek(0)
        # extract all bytes and return them
        return file.read()

    def dump_to_fileobj(
        self,
        obj: bytes,
        file: IO[bytes],
        **kwargs,
    ):
        # write all bytes to file
        file.write(obj)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError
imaginaire/utils/easy_io/handlers/csv_handler.py
ADDED
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
from io import StringIO

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class CsvHandler(BaseFileHandler):
    def load_from_fileobj(self, file, **kwargs):
        del kwargs
        reader = csv.reader(file)
        return list(reader)

    def dump_to_fileobj(self, obj, file, **kwargs):
        del kwargs
        writer = csv.writer(file)
        if not all(isinstance(row, list) for row in obj):
            raise ValueError("Each row must be a list")
        writer.writerows(obj)

    def dump_to_str(self, obj, **kwargs):
        del kwargs
        output = StringIO()
        writer = csv.writer(output)
        if not all(isinstance(row, list) for row in obj):
            raise ValueError("Each row must be a list")
        writer.writerows(obj)
        return output.getvalue()
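A short round trip through `CsvHandler` (the sample rows are illustrative); note that `csv.reader` yields every field back as a string:

from io import StringIO
from imaginaire.utils.easy_io.handlers.csv_handler import CsvHandler

handler = CsvHandler()
rows = [["name", "score"], ["a", "1"], ["b", "2"]]
text = handler.dump_to_str(rows)                 # "name,score\r\na,1\r\n..."
loaded = handler.load_from_fileobj(StringIO(text))
assert loaded == rows                            # all fields come back as strings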
imaginaire/utils/easy_io/handlers/gzip_handler.py
ADDED
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gzip
import pickle
from io import BytesIO
from typing import Any

from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler


class GzipHandler(PickleHandler):
    str_like = False

    def load_from_fileobj(self, file: BytesIO, **kwargs):
        with gzip.GzipFile(fileobj=file, mode="rb") as f:
            return pickle.load(f)

    def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
        with gzip.GzipFile(fileobj=file, mode="wb") as f:
            pickle.dump(obj, f)
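A round trip through `GzipHandler`, which pickles into a gzip stream; remember to rewind the buffer before loading (the sample values are illustrative):

from io import BytesIO
from imaginaire.utils.easy_io.handlers.gzip_handler import GzipHandler

handler = GzipHandler()
buf = BytesIO()
handler.dump_to_fileobj({"step": 10, "loss": 0.25}, buf)  # pickle inside a gzip stream
buf.seek(0)                                               # rewind before reading back
assert handler.load_from_fileobj(buf) == {"step": 10, "loss": 0.25}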
imaginaire/utils/easy_io/handlers/imageio_video_handler.py
ADDED
@@ -0,0 +1,168 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import IO, Any

import imageio
import imageio.v3 as iio_v3
import numpy as np
import torch

from imaginaire.utils import log
from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class ImageioVideoHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(
        self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs
    ) -> tuple[np.ndarray, dict[str, Any]]:
        """
        Load video from a file-like object using imageio.v3 with specified format and color mode.

        Parameters:
            file (IO[bytes]): A file-like object containing video data.
            format (str): Format of the video file (default 'mp4').
            mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb').

        Returns:
            tuple: A tuple containing an array of video frames and metadata about the video.
        """
        file.seek(0)

        # The plugin argument in v3 replaces the format argument in v2
        plugin = kwargs.pop("plugin", "pyav")

        # Load all frames at once using v3 API
        video_frames = iio_v3.imread(file, plugin=plugin, **kwargs)

        # Handle grayscale conversion if needed
        if mode == "gray":
            import cv2

            if len(video_frames.shape) == 4:  # (frames, height, width, channels)
                gray_frames = []
                for frame in video_frames:
                    gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
                    gray_frame = np.expand_dims(gray_frame, axis=2)  # Keep dimensions consistent
                    gray_frames.append(gray_frame)
                video_frames = np.array(gray_frames)

        # Extract metadata
        # Note: iio_v3.imread doesn't return metadata directly like v2 did,
        # so we need to extract it separately
        file.seek(0)
        metadata = self._extract_metadata(file, plugin=plugin)

        return video_frames, metadata

    def _extract_metadata(self, file: IO[bytes], plugin: str = "pyav") -> dict[str, Any]:
        """
        Extract metadata from a video file.

        Parameters:
            file (IO[bytes]): File-like object containing video data.
            plugin (str): Plugin to use for reading.

        Returns:
            dict: Video metadata.
        """
        try:
            # Read metadata for the stream
            metadata = iio_v3.immeta(file, plugin=plugin)

            # Add some standard fields similar to v2 metadata format
            if "fps" not in metadata and "duration" in metadata:
                # Read the first frame to get shape information
                file.seek(0)
                first_frame = iio_v3.imread(file, plugin=plugin, index=0)
                metadata["size"] = first_frame.shape[1::-1]  # (width, height)
                metadata["source_size"] = metadata["size"]

            # Create a metadata structure consistent with v2
            metadata["plugin"] = plugin
            if "codec" not in metadata:
                metadata["codec"] = "unknown"
            if "pix_fmt" not in metadata:
                metadata["pix_fmt"] = "unknown"

            # Calculate nframes if possible
            if "fps" in metadata and "duration" in metadata:
                metadata["nframes"] = int(metadata["fps"] * metadata["duration"])
            else:
                metadata["nframes"] = float("inf")

            return metadata

        except Exception:
            # Fall back to basic metadata
            return {
                "plugin": plugin,
                "nframes": float("inf"),
                "codec": "unknown",
                "fps": 30.0,  # Default values
                "duration": 0,
                "size": (0, 0),
            }

    def dump_to_fileobj(
        self,
        obj: np.ndarray | torch.Tensor,
        file: IO[bytes],
        format: str = "mp4",  # pylint: disable=redefined-builtin
        fps: int = 17,
        quality: int = 7,
        ffmpeg_params=None,
        **kwargs,
    ):
        """
        Save an array of video frames to a file-like object using imageio.

        Parameters:
            obj (Union[np.ndarray, torch.Tensor]): An array of frames to be saved as video.
            file (IO[bytes]): A file-like object to which the video data will be written.
            format (str): Format of the video file (default 'mp4').
            fps (int): Frames per second of the output video (default 17).
            quality (int): Quality of the video (0-10, default 7).
            ffmpeg_params (list): Additional parameters to pass to ffmpeg.
        """
        if isinstance(obj, torch.Tensor):
            assert obj.dtype == torch.uint8, "Tensor must be of type uint8"
            obj = obj.cpu().numpy()
        h, w = obj.shape[1:-1]

        # Default ffmpeg params that ensure width and height are set
        default_ffmpeg_params = ["-s", f"{w}x{h}"]

        # Use provided ffmpeg_params if any, otherwise use defaults
        final_ffmpeg_params = ffmpeg_params if ffmpeg_params is not None else default_ffmpeg_params

        mimsave_kwargs = {
            "fps": fps,
            "quality": quality,
            "macro_block_size": 1,
            "ffmpeg_params": final_ffmpeg_params,
            "output_params": ["-f", "mp4"],
        }
        # Update with any other kwargs
        mimsave_kwargs.update(kwargs)
        log.debug(f"mimsave_kwargs: {mimsave_kwargs}")

        imageio.mimsave(file, obj, format, **mimsave_kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError
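A sketch of writing and re-reading a short clip with `ImageioVideoHandler`. This assumes ffmpeg and the pyav plugin are installed, and whether `imageio.mimsave` accepts this particular file object depends on the available backends; the file name and frame sizes are illustrative:

import numpy as np
from imaginaire.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler

handler = ImageioVideoHandler()
frames = np.random.randint(0, 256, size=(16, 64, 64, 3), dtype=np.uint8)  # (T, H, W, C)
with open("demo.mp4", "wb") as f:
    handler.dump_to_fileobj(frames, f, fps=8)  # "-s 64x64" is injected via ffmpeg_params
with open("demo.mp4", "rb") as f:
    clip, meta = handler.load_from_fileobj(f)
print(clip.shape, meta.get("fps"))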
imaginaire/utils/easy_io/handlers/json_handler.py
ADDED
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import numpy as np

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain numbers of plain python built-in types.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"{type(obj)} is unsupported for json dump")


class JsonHandler(BaseFileHandler):
    def load_from_fileobj(self, file):
        return json.load(file)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault("default", set_default)
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault("default", set_default)
        return json.dumps(obj, **kwargs)
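A quick demonstration of how `set_default` lets `JsonHandler` serialize NumPy and container types that `json` rejects out of the box (the sample record is illustrative):

import numpy as np
from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler

handler = JsonHandler()
record = {"ids": {1, 2, 3}, "steps": range(3), "w": np.float32(0.5), "v": np.arange(2)}
# set/range/ndarray become lists and np.float32 becomes a plain float via set_default
print(handler.dump_to_str(record, sort_keys=True))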
imaginaire/utils/easy_io/handlers/jsonl_handler.py
ADDED
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import IO

import numpy as np

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain numbers of plain python built-in types.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"{type(obj)} is unsupported for json dump")


class JsonlHandler(BaseFileHandler):
    """Handler for JSON lines (JSONL) files."""

    def load_from_fileobj(self, file: IO[str]):
        """Load JSON objects from a newline-delimited JSON (JSONL) file object.

        Returns:
            A list of Python objects loaded from each JSON line.
        """
        data = []
        for line in file:
            line = line.strip()
            if not line:
                continue  # skip empty lines if any
            data.append(json.loads(line))
        return data

    def dump_to_fileobj(self, obj, file: IO[str], **kwargs):
        """Dump a list of objects to a newline-delimited JSON (JSONL) file object.

        Args:
            obj: A list (or iterable) of objects to dump line by line.
        """
        kwargs.setdefault("default", set_default)
        for item in obj:
            file.write(json.dumps(item, **kwargs) + "\n")

    def dump_to_str(self, obj, **kwargs):
        """Dump a list of objects to a newline-delimited JSON (JSONL) string."""
        kwargs.setdefault("default", set_default)
        lines = [json.dumps(item, **kwargs) for item in obj]
        return "\n".join(lines)


if __name__ == "__main__":
    from imaginaire.utils.easy_io import easy_io

    easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl")
    print(easy_io.load("test.jsonl"))
    easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl")
    print(easy_io.load("test.jsonl"))
imaginaire/utils/easy_io/handlers/np_handler.py
ADDED
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from io import BytesIO
from typing import IO, Any

import numpy as np

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class NumpyHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any:
        """
        Load a NumPy array from a file-like object.

        Parameters:
            file (IO[bytes]): The file-like object containing the NumPy array data.
            **kwargs: Additional keyword arguments passed to `np.load`.

        Returns:
            numpy.ndarray: The loaded NumPy array.
        """
        return np.load(file, **kwargs)

    def load_from_path(self, filepath: str, **kwargs) -> Any:
        """
        Load a NumPy array from a file path.

        Parameters:
            filepath (str): The path to the file to load.
            **kwargs: Additional keyword arguments passed to `np.load`.

        Returns:
            numpy.ndarray: The loaded NumPy array.
        """
        return super().load_from_path(filepath, mode="rb", **kwargs)

    def dump_to_str(self, obj: np.ndarray, **kwargs) -> bytes:
        """
        Serialize a NumPy array to raw bytes in NumPy's binary ``.npy`` format.

        Parameters:
            obj (np.ndarray): The NumPy array to serialize.
            **kwargs: Additional keyword arguments passed to `np.save`.

        Returns:
            bytes: The serialized NumPy array as a bytes object.
        """
        with BytesIO() as f:
            np.save(f, obj, **kwargs)
            return f.getvalue()

    def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs):
        """
        Dump a NumPy array to a file-like object.

        Parameters:
            obj (np.ndarray): The NumPy array to dump.
            file (IO[bytes]): The file-like object to which the array is dumped.
            **kwargs: Additional keyword arguments passed to `np.save`.
        """
        np.save(file, obj, **kwargs)

    def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs):
        """
        Dump a NumPy array to a file path.

        Parameters:
            obj (np.ndarray): The NumPy array to dump.
            filepath (str): The file path where the array should be saved.
            **kwargs: Additional keyword arguments passed to `np.save`.
        """
        with open(filepath, "wb") as f:
            np.save(f, obj, **kwargs)
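A minimal `.npy` round trip through `NumpyHandler` (the file name and array values are illustrative):

import numpy as np
from imaginaire.utils.easy_io.handlers.np_handler import NumpyHandler

handler = NumpyHandler()
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
handler.dump_to_path(arr, "demo.npy")
loaded = handler.load_from_path("demo.npy")  # opened in "rb" mode by the override above
assert np.array_equal(arr, loaded)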
imaginaire/utils/easy_io/handlers/pandas_handler.py
ADDED
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler  # isort:skip


class PandasHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return pd.read_csv(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        obj.to_csv(file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError("PandasHandler does not support dumping to str")
imaginaire/utils/easy_io/handlers/pickle_handler.py
ADDED
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from io import BytesIO
from typing import Any

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class PickleHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file: BytesIO, **kwargs):
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        return super().load_from_path(filepath, mode="rb", **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault("protocol", 2)
        return pickle.dumps(obj, **kwargs)

    def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
        kwargs.setdefault("protocol", 2)
        pickle.dump(obj, file, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        with open(filepath, "wb") as f:
            pickle.dump(obj, f, **kwargs)
imaginaire/utils/easy_io/handlers/pil_handler.py
ADDED
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import IO

import numpy as np

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler

try:
    from PIL import Image
except ImportError:
    Image = None


class PILHandler(BaseFileHandler):
    format: str
    str_like = False

    def load_from_fileobj(
        self,
        file: IO[bytes],
        fmt: str = "pil",
        size: int | tuple[int, int] | None = None,
        **kwargs,
    ):
        """
        Load an image from a file-like object and return it in a specified format.

        Args:
            file (IO[bytes]): A file-like object containing the image data.
            fmt (str): The format to convert the image into. Options are \
                'numpy', 'np', 'npy' (all return numpy arrays), \
                'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor).
            size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \
                or a tuple of (width, height). If specified, the image is resized accordingly.
            **kwargs: Additional keyword arguments that can be passed to conversion functions.

        Returns:
            Image data in the format specified by `fmt`.

        Raises:
            IOError: If the image cannot be loaded or processed.
            ValueError: If the specified format is unsupported.
        """
        try:
            img = Image.open(file)
            img.load()  # Explicitly load the image data
            if size is not None:
                if isinstance(size, int):
                    size = (
                        size,
                        size,
                    )  # create a tuple if only one integer is provided
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
                img = img.resize(size, Image.LANCZOS)

            # Return the image in the requested format
            if fmt in ["numpy", "np", "npy"]:
                return np.array(img, **kwargs)
            if fmt == "pil":
                return img
            if fmt in ["th", "torch"]:
                import torch

                # Convert to tensor
                img_tensor = torch.from_numpy(np.array(img, **kwargs))
                # Convert image from HxWxC to CxHxW
                if img_tensor.ndim == 3:
                    img_tensor = img_tensor.permute(2, 0, 1)
                return img_tensor
            raise ValueError(
                "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'."
            )
        except Exception as e:
            raise OSError(f"Unable to load image: {e}") from e

    def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs):
        # Fall back to the handler's own format only when the caller did not
        # pass one; normalize "jpg" to the "JPEG" name that PIL expects.
        if "format" not in kwargs:
            kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper()
        obj.save(file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError
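A sketch of the PNG round trip and the tensor conversion path in `PILHandler`; assigning `handler.format` by hand mirrors what the registry in the next file does per extension (the buffer and sizes are illustrative):

from io import BytesIO
import numpy as np
from PIL import Image
from imaginaire.utils.easy_io.handlers.pil_handler import PILHandler

handler = PILHandler()
handler.format = "png"                  # the registry below sets this per extension
img = Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))
buf = BytesIO()
handler.dump_to_fileobj(img, buf)       # saved with format="PNG"
buf.seek(0)
tensor = handler.load_from_fileobj(buf, fmt="torch", size=16)  # resized to 16x16, CxHxW
print(tensor.shape)                     # torch.Size([3, 16, 16])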
imaginaire/utils/easy_io/handlers/registry_utils.py
ADDED
@@ -0,0 +1,82 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
from imaginaire.utils.easy_io.handlers.byte_handler import ByteHandler
from imaginaire.utils.easy_io.handlers.csv_handler import CsvHandler
from imaginaire.utils.easy_io.handlers.gzip_handler import GzipHandler
from imaginaire.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler
from imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
from imaginaire.utils.easy_io.handlers.jsonl_handler import JsonlHandler
from imaginaire.utils.easy_io.handlers.np_handler import NumpyHandler
from imaginaire.utils.easy_io.handlers.pandas_handler import PandasHandler
from imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
from imaginaire.utils.easy_io.handlers.pil_handler import PILHandler
from imaginaire.utils.easy_io.handlers.tarfile_handler import TarHandler
from imaginaire.utils.easy_io.handlers.torch_handler import TorchHandler
from imaginaire.utils.easy_io.handlers.torchjit_handler import TorchJitHandler
from imaginaire.utils.easy_io.handlers.txt_handler import TxtHandler
from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler

file_handlers = {
    "json": JsonHandler(),
    "yaml": YamlHandler(),
    "yml": YamlHandler(),
    "pickle": PickleHandler(),
    "pkl": PickleHandler(),
    "tar": TarHandler(),
    "jit": TorchJitHandler(),
    "npy": NumpyHandler(),
    "txt": TxtHandler(),
    "csv": CsvHandler(),
    "pandas": PandasHandler(),
    "gz": GzipHandler(),
    "jsonl": JsonlHandler(),
    "byte": ByteHandler(),
}

for torch_type in ["pt", "pth", "ckpt"]:
    file_handlers[torch_type] = TorchHandler()
for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]:
    file_handlers[img_type] = PILHandler()
    file_handlers[img_type].format = img_type
for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]:
    file_handlers[video_type] = ImageioVideoHandler()


def _register_handler(handler, file_formats):
    """Register a handler for some file extensions.

    Args:
        handler (:obj:`BaseFileHandler`): Handler to be registered.
        file_formats (str or list[str]): File formats to be handled by this
            handler.
    """
    if not isinstance(handler, BaseFileHandler):
        raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}")
    if isinstance(file_formats, str):
        file_formats = [file_formats]
    if not all([isinstance(item, str) for item in file_formats]):
        raise TypeError("file_formats must be a str or a list of str")
    for ext in file_formats:
        file_handlers[ext] = handler


def register_handler(file_formats, **kwargs):
    def wrap(cls):
        _register_handler(cls(**kwargs), file_formats)
        return cls

    return wrap
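The decorator form of `register_handler` in action; the `PlainLogHandler` class and the `log`/`out` extensions are invented for illustration:

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
from imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler

@register_handler(["log", "out"])        # map two invented extensions to one handler
class PlainLogHandler(BaseFileHandler):
    def load_from_fileobj(self, file, **kwargs):
        return file.read().splitlines()

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write("\n".join(obj))

    def dump_to_str(self, obj, **kwargs):
        return "\n".join(obj)

# the registry now dispatches both extensions to one shared instance
assert isinstance(file_handlers["log"], PlainLogHandler)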
imaginaire/utils/easy_io/handlers/tarfile_handler.py
ADDED
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tarfile

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class TarHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file, mode="r|*", **kwargs):
        return tarfile.open(fileobj=file, mode=mode, **kwargs)

    def load_from_path(self, filepath, mode="r|*", **kwargs):
        return tarfile.open(filepath, mode=mode, **kwargs)

    def dump_to_fileobj(self, obj, file, mode="w", **kwargs):
        with tarfile.open(fileobj=file, mode=mode) as tar:
            tar.add(obj, **kwargs)

    def dump_to_path(self, obj, filepath, mode="w", **kwargs):
        with tarfile.open(filepath, mode=mode) as tar:
            tar.add(obj, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError
imaginaire/utils/easy_io/handlers/torch_handler.py
ADDED
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    import torch
except ImportError:
    torch = None

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class TorchHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return torch.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        torch.save(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError
imaginaire/utils/easy_io/handlers/torchjit_handler.py
ADDED
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    import torch
except ImportError:
    torch = None

from imaginaire.utils.easy_io.handlers.base import BaseFileHandler


class TorchJitHandler(BaseFileHandler):
    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return torch.jit.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        torch.jit.save(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        raise NotImplementedError
|
imaginaire/utils/easy_io/handlers/txt_handler.py
ADDED
|
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from imaginaire.utils.easy_io.handlers.base import BaseFileHandler
+
+
+class TxtHandler(BaseFileHandler):
+    def load_from_fileobj(self, file, **kwargs):
+        del kwargs
+        return file.read()
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        del kwargs
+        if not isinstance(obj, str):
+            obj = str(obj)
+        file.write(obj)
+
+    def dump_to_str(self, obj, **kwargs):
+        del kwargs
+        if not isinstance(obj, str):
+            obj = str(obj)
+        return obj
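Note that both dump paths coerce non-string inputs with str(), so this handler never fails on type. A minimal sketch (hypothetical usage, not part of the commit):

# Hypothetical usage sketch of TxtHandler (not part of this commit).
import io

from imaginaire.utils.easy_io.handlers.txt_handler import TxtHandler

handler = TxtHandler()
buf = io.StringIO()
handler.dump_to_fileobj(123, buf)  # non-strings are coerced with str()
assert buf.getvalue() == "123"
assert handler.dump_to_str(["a", "b"]) == "['a', 'b']"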
imaginaire/utils/easy_io/handlers/yaml_handler.py ADDED
@@ -0,0 +1,38 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+
+try:
+    from yaml import CDumper as Dumper  # type: ignore
+    from yaml import CLoader as Loader  # type: ignore
+except ImportError:
+    from yaml import Dumper, Loader  # type: ignore
+
+from imaginaire.utils.easy_io.handlers.base import BaseFileHandler  # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+    def load_from_fileobj(self, file, **kwargs):
+        kwargs.setdefault("Loader", Loader)
+        return yaml.load(file, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        kwargs.setdefault("Dumper", Dumper)
+        yaml.dump(obj, file, **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        kwargs.setdefault("Dumper", Dumper)
+        return yaml.dump(obj, **kwargs)
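The try/except prefers LibYAML's C-accelerated Loader/Dumper and silently falls back to the pure-Python implementations when PyYAML was built without LibYAML; behavior is identical, only speed differs. A minimal round-trip sketch (hypothetical usage, not part of the commit):

# Hypothetical usage sketch of YamlHandler (not part of this commit).
import io

from imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler

handler = YamlHandler()
text = handler.dump_to_str({"lr": 1e-4, "betas": [0.9, 0.999]})
config = handler.load_from_fileobj(io.StringIO(text))
assert config["lr"] == 1e-4  # round-trips through the C loader when LibYAML is present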
imaginaire/utils/ema.py ADDED
@@ -0,0 +1,315 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+import torch
+
+try:
+    from megatron.core import parallel_state
+
+    USE_MEGATRON = True
+except ImportError:
+    USE_MEGATRON = False
+
+from imaginaire.utils import distributed, log
+
+if TYPE_CHECKING:
+    from imaginaire.model import ImaginaireModel
+
+
+class FastEmaModelUpdater:
+    """
+    This class is used to update the target model (EMA) given the source model (regular model) and beta.
+    The method interface mimics :class:`EMAModelTracker` and :class:`PowerEMATracker`.
+    Different from those two classes, this class does not maintain the EMA model weights as buffers. It expects the user to have two modules with the same architecture and weight shapes.
+    The class is proposed to work with FSDP models, where the above two classes do not work as expected. Besides, it is strange to register model weights as buffers and do unnecessary name changing as in :class:`EMAModelTracker` and :class:`PowerEMATracker`. Moving forward, we should use this class instead of the above two.
+    """
+
+    def __init__(self):
+        # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite
+        self.is_cached = False
+
+    @torch.no_grad()
+    def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None:
+        for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters(), strict=False):
+            tgt_params.data.copy_(src_params.data)
+
+    @torch.no_grad()
+    def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None:
+        target_list = []
+        source_list = []
+        for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters(), strict=False):
+            assert tgt_params.dtype == torch.float32, (
+                f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead."
+            )
+            target_list.append(tgt_params)
+            source_list.append(src_params.data)
+        torch._foreach_mul_(target_list, beta)
+        torch._foreach_add_(target_list, source_list, alpha=1.0 - beta)
+
+    @torch.no_grad()
+    def cache(self, parameters: Any, is_cpu: bool = False) -> None:
+        """Save the current parameters for restoring later.
+
+        Args:
+            parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored.
+        """
+        assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?"
+        device = "cpu" if is_cpu else "cuda"
+        self.collected_params = [param.clone().to(device) for param in parameters]
+        self.is_cached = True
+
+    @torch.no_grad()
+    def restore(self, parameters: Any) -> None:
+        """Restore the parameters in self.collected_params.
+
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before copy_to().
+        After validation (or model saving), use this to restore the former parameters.
+
+        Args:
+            parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters.
+        """
+        assert self.is_cached, "EMA cache is not taken yet."
+        for c_param, param in zip(self.collected_params, parameters, strict=False):
+            param.data.copy_(c_param.data.type_as(param.data))
+        self.collected_params = []
+        # Release the cache after we call restore
+        self.is_cached = False
+
+
+def get_buffer_name(param_name: str, torch_compile_buffer_renaming: bool = False) -> str:
+    """
+    This function creates the buffer name used by EMA from the parameter's name.
+
+    Args:
+        param_name (str): Model's parameter name
+    Returns:
+        buffer_name (str): buffer name to be used for the given parameter name
+    """
+
+    buffer_name = param_name.replace(".", "-")
+
+    if torch_compile_buffer_renaming:
+        # torch.compile() adds _orig_mod to state dict names, this way we get the original name
+        buffer_name = buffer_name.replace("_orig_mod-", "")
+
+    return buffer_name
+
+
+class EMAModelTracker(torch.nn.Module):
+    """This is a class to track the EMA model weights.
+
+    The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the
+    regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's
+    implementation of EMA. There are no optimizable parameters.
+
+    Attributes:
+        collected_params (list): temporarily stores the regular weights while in EMA mode.
+        beta (float): EMA decay rate. (default: 0.9999).
+        torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used
+    """
+
+    def __init__(self, model: ImaginaireModel, beta: float = 0.9999, torch_compile_buffer_renaming: bool = False):
+        """Constructor of the EMA model weight tracker.
+
+        Args:
+            model (ImaginaireModel): The PyTorch model.
+            beta (float): EMA decay rate. (default: 0.9999).
+        """
+        super().__init__()
+        self.torch_compile_buffer_renaming: bool = torch_compile_buffer_renaming
+        if not 0.0 <= beta <= 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+        self.beta = beta
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming)
+                self.register_buffer(buffer_name, param.clone().detach().data)
+        self.collected_params = []
+        # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite
+        self.is_cached = False
+
+    @torch.no_grad()
+    def update_average(self, model: ImaginaireModel, iteration: int | None = None) -> None:
+        del iteration
+        target_list = []
+        source_list = []
+        ema_buffers = self.state_dict()
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming)
+                buffer = ema_buffers[buffer_name]
+                assert buffer.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {buffer.dtype} instead."
+                target_list.append(buffer)
+                source_list.append(param.data)
+        torch._foreach_mul_(target_list, self.beta)
+        torch._foreach_add_(target_list, source_list, alpha=1.0 - self.beta)
+
+    def copy_to(self, model: ImaginaireModel) -> None:
+        ema_buffers = self.state_dict()
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming)
+                buffer = ema_buffers[buffer_name]
+                param.data.copy_(buffer.data)
+
+    def cache(self, parameters: Any, is_cpu: bool = False) -> None:
+        """Save the current parameters for restoring later.
+
+        Args:
+            parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored.
+        """
+        assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?"
+        device = "cpu" if is_cpu else "cuda"
+        self.collected_params = [param.clone().to(device) for param in parameters]
+        self.is_cached = True
+
+    def restore(self, parameters: Any) -> None:
+        """Restore the parameters in self.collected_params.
+
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before copy_to().
+        After validation (or model saving), use this to restore the former parameters.
+
+        Args:
+            parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters.
+        """
+        assert self.is_cached, "EMA cache is not taken yet."
+        for c_param, param in zip(self.collected_params, parameters, strict=False):
+            param.data.copy_(c_param.data.type_as(param.data))
+        self.collected_params = []
+        # Release the cache after we call restore
+        self.is_cached = False
+
+    @classmethod
+    def initialize_multi_rank_ema(
+        cls, model: torch.nn.Module, rate: float | list[float], num: int = 1, enabled: bool = True
+    ) -> EMAModelTracker | None:
+        """
+        Class method to initialize a per-rank EMA Model Tracker with a different rate.
+        Each rank will have a different rate based on the given configuration, resulting in different EMA weights.
+
+        Args:
+            model (torch.nn.Module): The neural network model to be tracked.
+            rate (Union[float, List[float]]): The decay rate(s) for the EMA. If a list is provided,
+                it corresponds to rates for different ranks.
+            num (int, optional): The number of leading ranks to consider for different rates.
+                Defaults to 1.
+            enabled (bool, optional): Flag to enable or disable the creation of the tracker.
+                If False, returns None. Defaults to True.
+
+        Returns:
+            Optional[EMAModelTracker]: An instance of EMAModelTracker if enabled, otherwise None.
+
+        Example:
+            >>> model = torch.nn.Linear(10, 2)
+            >>> tracker = EMAModelTracker.initialize_multi_rank_ema(model, rate=[0.1, 0.2], num=2)
+            >>> print(tracker)
+
+        Notes:
+            If `rate` is a list and the current rank is less than `num`, the rate for the current rank
+            is used. If the current rank exceeds `num`, the first rate in the list is used by default.
+        """
+        if not enabled:
+            return None
+        if USE_MEGATRON and parallel_state.is_initialized():
+            cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True)
+            log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
+            log.warning("It should not be used together with FSDP!")
+        else:
+            cur_dp_rank = distributed.get_rank()
+            log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
+        rate = rate if isinstance(rate, list) else [rate]
+        num = min(num, len(rate))
+        rate = rate[cur_dp_rank] if cur_dp_rank < num else rate[0]
+        if cur_dp_rank < num:
+            print(f"EMAModelTracker: rank {cur_dp_rank}, rate {rate}")
+        return cls(model, rate)
+
+
+class PowerEMATracker(EMAModelTracker):
+    def __init__(self, model: ImaginaireModel, s: float = 0.1, torch_compile_buffer_renaming: bool = False):
+        """Constructor of the EMA model weight tracker.
+
+        Args:
+            model (ImaginaireModel): The PyTorch model.
+            s (float): EMA decay rate. See the EDM2 paper.
+            torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used
+        """
+        super().__init__(model=model, beta=0.0, torch_compile_buffer_renaming=torch_compile_buffer_renaming)
+        self.exp = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max()
+
+    @torch.no_grad()
+    def update_average(self, model: ImaginaireModel, iteration: int | None = None) -> None:
+        if iteration == 0:
+            beta = 0.0
+        else:
+            i = iteration + 1
+            beta = (1 - 1 / i) ** (self.exp + 1)
+        self.beta = beta
+
+        super().update_average(model, iteration)
+
+    @classmethod
+    def initialize_multi_rank_ema(
+        cls, model: torch.nn.Module, rate: float, num: int, enabled: bool = True
+    ) -> PowerEMATracker | None:
+        """
+        Class method to initialize a per-rank EMA Model Tracker with a different rate.
+        Each rank will have a different rate based on the given configuration, resulting in different EMA weights.
+
+        Args:
+            model (torch.nn.Module): The neural network model for which the EMA tracker is being set up.
+            num (int): The number of ranks for which the rate adjustment is applied. Beyond this, the rate remains unchanged.
+            rate (float): The base decay rate for the EMA calculation.
+            enabled (bool, optional): Flag to enable or disable the initialization of the tracker. If False, returns None.
+                Defaults to True.
+
+        Returns:
+            Optional[PowerEMATracker]: An instance of PowerEMATracker with adjusted rate if enabled, otherwise None.
+
+        Raises:
+            None
+
+        Example:
+            >>> model = torch.nn.Linear(10, 2)
+            >>> tracker = PowerEMATracker.initialize_multi_rank_ema(model, num=3, rate=0.99)
+            >>> print(tracker)
+
+        Notes:
+            The decay rate is modified by dividing it by 2 raised to the power of the rank for each rank less than `num`.
+            If the rank is greater than or equal to `num`, the base rate is used without modification. This approach
+            allows higher-ranked processes to have a less aggressive decay, potentially reflecting their delayed synchronization
+            in a distributed training scenario.
+        """
+        if not enabled:
+            return None
+        if USE_MEGATRON and parallel_state.is_initialized():
+            cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True)
+            log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
+            log.warning("It should not be used together with FSDP!")
+        else:
+            cur_dp_rank = distributed.get_rank()
+            log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False)
+
+        divider = 2**cur_dp_rank if cur_dp_rank < num else 1
+        if cur_dp_rank < num:
+            print(f"PowerEMATracker: rank {cur_dp_rank}, rate {rate / divider}")
+        return cls(model, rate / divider)
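FastEmaModelUpdater operates on two plain modules of identical architecture rather than on registered buffers, which is what makes it FSDP-friendly. A minimal CPU sketch of the intended train/eval cycle (hypothetical usage, not part of the commit):

# Hypothetical sketch of the FastEmaModelUpdater workflow (not part of this commit).
import copy

import torch

from imaginaire.utils.ema import FastEmaModelUpdater

net = torch.nn.Linear(8, 8)
ema_net = copy.deepcopy(net)  # identical architecture and parameter shapes, FP32

updater = FastEmaModelUpdater()
opt = torch.optim.SGD(net.parameters(), lr=1e-2)
for _ in range(10):
    opt.zero_grad()
    net(torch.randn(4, 8)).sum().backward()
    opt.step()
    updater.update_average(net, ema_net, beta=0.9999)  # ema <- beta * ema + (1 - beta) * net

updater.cache(net.parameters(), is_cpu=True)  # stash the training weights
updater.copy_to(src_model=ema_net, tgt_model=net)  # evaluate with EMA weights
# ... run validation with `net` here ...
updater.restore(net.parameters())  # put the training weights back

For PowerEMATracker, the constructor solves the cubic gamma^3 + 7*gamma^2 + (16 - s^-2)*gamma + (12 - s^-2) = 0 from the EDM2 power-function EMA profile; with the default s = 0.1 the largest real root is gamma ≈ 6.94, and update_average then uses beta_t = (1 - 1/t)^(gamma + 1) with t = iteration + 1, so beta ramps toward 1 as training progresses.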
imaginaire/utils/fused_adam.py ADDED
@@ -0,0 +1,398 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from apex.multi_tensor_apply import multi_tensor_applier
+
+from imaginaire.utils import distributed, log
+
+
+class FusedAdam(torch.optim.Optimizer):
+    """Implements Adam algorithm.
+
+    Currently GPU-only. Requires Apex to be installed via
+    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
+
+    This version of fused Adam implements 2 fusions.
+
+    * Fusion of the Adam update's elementwise operations
+    * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters
+      into one or a few kernel launches.
+
+    :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
+    or ``torch.optim.Adam`` with ``adam_w_mode=False``::
+
+        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
+        ...
+        opt.step()
+
+    :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp,
+    you may choose any ``opt_level``::
+
+        opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
+        model, opt = amp.initialize(model, opt, opt_level="O0", "O1", or "O2")
+        ...
+        opt.step()
+
+    In general, ``opt_level="O1"`` is recommended.
+
+
+    .. warning::
+        A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``.
+        These additional arguments are now deprecated and unnecessary.
+
+    Adam has been proposed in `Adam: A Method for Stochastic Optimization`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square. (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False) NOT SUPPORTED in FusedAdam!
+        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay.
+            True for decoupled weight decay (also known as AdamW) (default: True)
+        capturable (bool, optional): whether to use the version of the optimizer
+            that can be used with CUDA Graphs. (default: False)
+        master_weights (bool, optional): whether to maintain FP32 master weights
+            in the optimizer with FP16 mixed precision training, currently can
+            only be used with capturable set to True. (default: False)
+
+    .. _Adam - A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        bias_correction=True,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        adam_w_mode=True,
+        weight_decay=0.0,
+        amsgrad=False,
+        capturable=False,
+        master_weights=False,
+    ):
+        if amsgrad:
+            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
+        if master_weights and not capturable:
+            raise RuntimeError("Master weights is currently only supported with the capturable version.")
+        # If the optimizer is capturable then LR should be a tensor (on GPU)
+        log.warning(f"FusedAdam master_weights: {master_weights} capturable: {capturable}")
+        lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
+        defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
+        super(FusedAdam, self).__init__(params, defaults)  # noqa: UP008
+        self.adam_w_mode = 1 if adam_w_mode else 0
+
+        self.capturable = capturable
+        self.master_weights = master_weights
+
+        self.param_groups_master = None
+
+        if capturable:
+            for idx, group in enumerate(self.param_groups):
+                if len(group["params"]) == 0:
+                    continue
+                device = group["params"][0].device
+                for item in ["lr"]:
+                    if isinstance(group[item], float):
+                        group[item] = torch.tensor(group[item], dtype=torch.float32)
+                    self.param_groups[idx][item] = group[item].to(device=device)
+
+            self._step_supports_amp_scaling = True
+
+        if multi_tensor_applier.available:
+            import amp_C
+
+            # Skip buffer
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
+            self.multi_tensor_adam = amp_C.multi_tensor_adam
+            self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable
+            self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master
+        else:
+            raise RuntimeError("apex.optimizers.FusedAdam requires cuda extensions")
+
+    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+
+        The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
+        """
+        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
+            raise RuntimeError(
+                "FusedAdam has been updated. "
+                "Simply initialize it identically to torch.optim.Adam, and call step() with no arguments."
+            )
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        if self.param_groups_master is None:
+            # Create full precision master weights
+            self.param_groups_master = []
+            for i, pg in enumerate(self.param_groups):  # noqa: B007
+                param_list = pg["params"]
+                self.param_groups_master.append(
+                    {
+                        "params": [p.clone().detach().float() if self.master_weights else None for p in param_list],
+                    }
+                )
+
+        for group, group_master in zip(self.param_groups, self.param_groups_master, strict=False):
+            if len(group["params"]) == 0:
+                continue
+            device = group["params"][0].device
+            bias_correction = 1 if group.get("bias_correction") else 0
+            beta1, beta2 = group["betas"]
+
+            # assume the same step across the group for now to simplify things
+            # a per-parameter step can easily be supported by making it a tensor, or passing a list into the kernel
+            if "step" in group:
+                if self.capturable:
+                    group["step"] = (
+                        group["step"].to(device=device)
+                        if isinstance(group["step"], torch.Tensor)
+                        else torch.tensor(group["step"], dtype=torch.int32, device=device)
+                    )
+                    group["step"] += (self._dummy_overflow_buf != 1).to(torch.int)
+                else:
+                    group["step"] += 1
+            else:
+                group["step"] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device)
+
+            if self.capturable:
+                group["lr"] = (
+                    group["lr"].to(device=device)
+                    if isinstance(group["lr"], torch.Tensor)
+                    else torch.tensor(group["lr"], dtype=torch.float32, device=device)
+                )
+
+            # create lists for multi-tensor apply
+            g_16, p_16, m_16, v_16 = [], [], [], []
+            g_bf, p_bf, m_bf, v_bf = [], [], [], []
+            g_32, p_32, m_32, v_32 = [], [], [], []
+            p_16_master = []
+            p_32_master = []
+            bf16_master = []
+
+            for p, p_master in zip(group["params"], group_master["params"], strict=False):
+                if p.grad is None:
+                    continue
+                if p.grad.data.is_sparse:
+                    raise RuntimeError(
+                        "FusedAdam does not support sparse gradients, please consider SparseAdam instead"
+                    )
+
+                state = self.state[p]
+                # State initialization
+                if len(state) == 0:
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p.data).float()
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p.data).float()
+
+                if p.dtype == torch.float16:
+                    if self.master_weights:
+                        p_16_master.append(p_master.data)
+                    g_16.append(p.grad.data)
+                    p_16.append(p.data)
+                    m_16.append(state["exp_avg"])
+                    v_16.append(state["exp_avg_sq"])
+                elif p.dtype == torch.bfloat16:
+                    if self.master_weights:
+                        bf16_master.append(p_master.data)
+                    g_bf.append(p.grad)
+                    p_bf.append(p)
+                    m_bf.append(state["exp_avg"])
+                    v_bf.append(state["exp_avg_sq"])
+                elif p.dtype == torch.float32:
+                    if self.master_weights:
+                        p_32_master.append(p_master.data)
+                    g_32.append(p.grad.data)
+                    p_32.append(p.data)
+                    m_32.append(state["exp_avg"])
+                    v_32.append(state["exp_avg_sq"])
+                else:
+                    raise RuntimeError("FusedAdam only supports fp16, bf16, and fp32.")
+
+            # If the optimizer is capturable, then if there's a grad scaler it works
+            # on the GPU + a different multi_tensor_applier should be called
+            if self.capturable:
+                # overflow check of gradients
+                found_inf = (
+                    grad_scaler._check_inf_per_device(self)[device]
+                    if grad_scaler is not None
+                    else torch.zeros((1,), device=device)
+                )
+                self._dummy_overflow_buf.copy_(found_inf)
+
+                # get the unscale scale factor
+                scale, inv_scale = None, None
+                if grad_scaler:
+                    scale = grad_scaler._get_scale_async()
+                    inv_scale = scale.double().reciprocal().float()
+                else:
+                    scale = torch.ones((1,), device=device, dtype=torch.float32)
+                    inv_scale = torch.ones((1,), device=device, dtype=torch.float32)
+
+                if len(g_16) > 0:
+                    multi_tensor_applier(
+                        (
+                            self.multi_tensor_adam_capturable_master
+                            if self.master_weights
+                            else self.multi_tensor_adam_capturable
+                        ),
+                        self._dummy_overflow_buf,
+                        [g_16, p_16, m_16, v_16, p_16_master] if self.master_weights else [g_16, p_16, m_16, v_16],
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        group["step"],
+                        self.adam_w_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                        inv_scale,
+                    )
+
+                if len(g_bf) > 0:
+                    multi_tensor_applier(
+                        (
+                            self.multi_tensor_adam_capturable_master
+                            if self.master_weights
+                            else self.multi_tensor_adam_capturable
+                        ),
+                        self._dummy_overflow_buf,
+                        [g_bf, p_bf, m_bf, v_bf, bf16_master] if self.master_weights else [g_bf, p_bf, m_bf, v_bf],
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        group["step"],
+                        self.adam_w_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                        inv_scale,
+                    )
+
+                if len(g_32) > 0:
+                    multi_tensor_applier(
+                        (
+                            self.multi_tensor_adam_capturable_master
+                            if self.master_weights
+                            else self.multi_tensor_adam_capturable
+                        ),
+                        self._dummy_overflow_buf,
+                        [g_32, p_32, m_32, v_32, p_32_master] if self.master_weights else [g_32, p_32, m_32, v_32],
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        group["step"],
+                        self.adam_w_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                        inv_scale,
+                    )
+            else:
+                if len(g_16) > 0:
+                    multi_tensor_applier(
+                        self.multi_tensor_adam,
+                        self._dummy_overflow_buf,
+                        [g_16, p_16, m_16, v_16],
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        group["step"],
+                        self.adam_w_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                    )
+
+                if len(g_bf) > 0:
+                    multi_tensor_applier(
+                        self.multi_tensor_adam,
+                        self._dummy_overflow_buf,
+                        [g_bf, p_bf, m_bf, v_bf],
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        group["step"],
+                        self.adam_w_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                    )
+
+                if len(g_32) > 0:
+                    multi_tensor_applier(
+                        self.multi_tensor_adam,
+                        self._dummy_overflow_buf,
+                        [g_32, p_32, m_32, v_32],
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        group["step"],
+                        self.adam_w_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                    )
+
+        return loss
+
+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            if self.capturable:
+                group["lr"] = (
+                    group["lr"].cuda()
+                    if isinstance(group["lr"], torch.Tensor)
+                    else torch.tensor(group["lr"], dtype=torch.float32).cuda()
+                )
+
+            if "step" in group:
+                if self.capturable:
+                    if distributed.get_rank() == 0:
+                        step = (
+                            group["step"].cuda()
+                            if isinstance(group["step"], torch.Tensor)
+                            else torch.tensor([group["step"]], dtype=torch.int32).cuda()
+                        )
+                    else:
+                        step = torch.zeros(1, dtype=torch.int32).cuda()
+                    # make it compatible with the FSDP optimizer
+                    distributed.broadcast(step, 0)
+                    group["step"] = step
+                elif isinstance(group["step"], torch.Tensor):
+                    group["step"] = group["step"].item()
+            for p in group["params"]:
+                state = self.state[p]
+                if "exp_avg" in state:
+                    state["exp_avg"] = state["exp_avg"].float()
+                    state["exp_avg_sq"] = state["exp_avg_sq"].float()
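Since FusedAdam hard-requires Apex's fused CUDA kernels (it raises at construction when amp_C is unavailable), a typical drop-in use looks like the following sketch (hypothetical usage, not part of the commit; assumes Apex built with CUDA extensions and a GPU):

# Hypothetical usage sketch of FusedAdam (not part of this commit).
# Requires apex built with CUDA extensions (amp_C) and a CUDA device.
import torch

from imaginaire.utils.fused_adam import FusedAdam

model = torch.nn.Linear(16, 16).cuda()
opt = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

for _ in range(3):
    opt.zero_grad()
    loss = model(torch.randn(4, 16, device="cuda")).pow(2).mean()
    loss.backward()
    opt.step()  # batched multi_tensor_applier launches, one per dtype bucket per param group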