Improved Quality, Synchrony, and Preference Alignment for Joint Audio-Video Generation
This codebase is built upon JavisDiT. Many thanks for their contribution.
Installation
For CUDA 12.1, you can install the dependencies with the following commands.
# create a virtual env and activate (conda as an example)
conda create -n javisdit python=3.10
conda activate javisdit
# install torch, torchvision and xformers
pip install -r requirements/requirements-cu121.txt
# install ffmpeg
conda install "ffmpeg<7" -c conda-forge -y
# the default installation is for inference only
pip install -v .
# for development mode, `pip install -v -e .`
# to skip dependencies, `pip install -v -e . --no-deps`
# replace some third-party source files with the patched versions provided in assets/src
export PYTHON_SITE_PACKAGES=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
cp assets/src/pytorchvideo_augmentations.py ${PYTHON_SITE_PACKAGES}/pytorchvideo/transforms/augmentations.py
cp assets/src/funasr_utils_load_utils.py ${PYTHON_SITE_PACKAGES}/funasr/utils/load_utils.py
# (optional but recommended) install flash attention
# set enable_flash_attn=False in config to disable flash attention
pip install packaging ninja
pip install flash-attn --no-build-isolation
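If you installed flash attention, a quick optional sanity check is to confirm it can be imported; if the import fails, keep enable_flash_attn=False in your config as noted above.
# optional sanity check: flash attention should be importable
python -c "import flash_attn; print(flash_attn.__version__)"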
Training
Data Preparation
In this project, we use a .csv file to manage all the training entries and their attributes for efficient training:
| path | id | relpath | num_frames | height | width | aspect_ratio | fps | resolution | audio_path | audio_fps | text |
|---|---|---|---|---|---|---|---|---|---|---|---|
| /path/to/xxx.mp4 | xxx | xxx.mp4 | 240 | 480 | 640 | 0.75 | 24 | 307200 | /path/to/xxx.wav | 16000 | yyy |
The required columns may vary across training stages. The detailed instructions for each training stage can be found here.
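For illustration, a minimal metadata file matching the schema above could be created as follows (the file location, paths, id, and caption are placeholder values; in practice these files come from your own data preprocessing):
# write a one-row example metadata file with placeholder values
mkdir -p data/meta
cat > data/meta/example.csv << 'EOF'
path,id,relpath,num_frames,height,width,aspect_ratio,fps,resolution,audio_path,audio_fps,text
/path/to/xxx.mp4,xxx,xxx.mp4,240,480,640,0.75,24,307200,/path/to/xxx.wav,16000,yyy
EOF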
Stage1 - Audio Pre-Train
In this stage, we perform audio pretraining to initialize the text-to-audio generation capability:
torchrun --standalone --nproc_per_node 8 \
scripts/train.py \
configs/wan2.1/train/stage1_audio.py \
--data-path data/meta/audio/train_audio.csv
The resulting checkpoints will be saved at runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc/model, where 0aa, 0bb, and ccc denote the run index, epoch, and global step. You can move the checkpoints to exps/audio_pretrain/ for later use.
mkdir -p exps/audio_pretrain
mv runs/000-Wan2_1_T2V_1_3B/epoch049-global_step53000 exps/audio_pretrain/
Stage2 - Audio-Video SFT
In this stage, we perform supervised fine-tuning (SFT) for joint audio-video generation with LoRA adaptation:
torchrun --standalone --nproc_per_node 8 \
scripts/train_prior.py \
configs/wan2.1/train/stage2_audio_video.py \
--data-path data/meta/video/train_av_sft.csv
The resulting checkpoints will be saved at runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc with the model and lora subfolders. You can move the checkpoints to exps/audio_video_sft/ for later use.
mkdir -p exps/audio_video_sft
mv runs/000-Wan2_1_T2V_1_3B/epoch001-global_step13000 exps/audio_video_sft/
Stage3 - Audio-Video DPO
In this stage, we perform DPO to align joint audio-video generation with human preferences (reusing and updating the LoRA parameters learned in the previous stage):
torchrun --standalone --nproc_per_node 8 \
scripts/train.py \
configs/wan2.1/train/stage3_audio_video_dpo.py \
--data-path data/meta/avdpo/train_av_dpo.csv
The resulting checkpoints will also be saved at runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc with the model and lora subfolders. You can move the checkpoints to checkpoints/ for inference and evaluation.
mv runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc checkpoints/your_model
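Before running inference, it can be worth checking that the moved checkpoint directory contains the model and lora subfolders described above (your_model is just a placeholder name):
ls checkpoints/your_model
# expected to list the model/ and lora/ subfolders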
Inference
The basic command line inference is as follows:
resolution=480p # or 240p
num_frames=65 # 4s
aspect_ratio="9:16"
DATASET="JavisBench" # or JavisBench-mini
prompt_path="data/eval/JavisBench/${DATASET}.csv"
save_dir="samples/${DATASET}"
model_path="checkpoints/your_model"
ngpus=1
torchrun --standalone --nproc_per_node ${ngpus} \
scripts/inference.py \
configs/wan2.1/inference/sample.py \
--resolution ${resolution} --num-frames ${num_frames} --aspect-ratio ${aspect_ratio} \
--prompt-path ${prompt_path} --model-path ${model_path} \
--save-dir ${save_dir} --verbose 1
# (Optional, for evaluation) Extract audio tracks from the generated videos
python -m tools.datasets.convert video ${save_dir} --output ${save_dir}/meta.csv
python -m tools.datasets.datautil ${save_dir}/meta.csv --extract-audio --audio-sr 16000
rm -f ${save_dir}/meta*.csv
Setting --verbose 2 will display the progress of each individual diffusion process. You can also replace --prompt-path ${prompt_path} with a single prompt, such as --prompt "a beautiful waterfall", to generate a single video.
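For example, a single-prompt run on one GPU, reusing the settings above (the save directory here is only an illustrative choice), could look like:
torchrun --standalone --nproc_per_node 1 \
    scripts/inference.py \
    configs/wan2.1/inference/sample.py \
    --resolution 480p --num-frames 65 --aspect-ratio "9:16" \
    --prompt "a beautiful waterfall" \
    --model-path checkpoints/your_model \
    --save-dir samples/single_prompt --verbose 2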
Evaluation
Installation
Install necessary packages:
pip install -r requirements/requirements-eval.txt
Download the meta file and data of JavisBench, and put them into data/eval/:
cd /path/to/JavisDiT
mkdir -p data/eval
huggingface-cli download --repo-type dataset JavisDiT/JavisBench --local-dir data/eval/JavisBench
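After the download, a quick listing should show the files referenced by the evaluation commands below (the dataset may contain additional files):
ls data/eval/JavisBench
# expected (among others): JavisBench.csv, JavisBench-mini.csv, cache/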
Evaluation on JavisBench or JavisBench-mini
Run the following command; the results will be saved to ./evaluation_results. Please refer to JavisBench for more details.
MAX_FRAMES=16
IMAGE_SIZE=224
MAX_AUDIO_LEN_S=4.0
# Params to calculate JavisScore
WINDOW_SIZE_S=2.0
WINDOW_OVERLAP_S=1.5
METRICS="all"
RESULTS_DIR="./evaluation_results"
DATASET="JavisBench" # or JavisBench-mini
INPUT_FILE="data/eval/JavisBench/${DATASET}.csv"
FVD_AVCACHE_PATH="data/eval/JavisBench/cache/fvd_fad/${DATASET}-vanilla-max4s.pt"
INFER_DATA_DIR="samples/${DATASET}"
python -m eval.javisbench.main \
--input_file "${INPUT_FILE}" \
--infer_data_dir "${INFER_DATA_DIR}" \
--output_file "${RESULTS_DIR}/${DATASET}.json" \
--max_frames ${MAX_FRAMES} \
--image_size ${IMAGE_SIZE} \
--max_audio_len_s ${MAX_AUDIO_LEN_S} \
--window_size_s ${WINDOW_SIZE_S} \
--window_overlap_s ${WINDOW_OVERLAP_S} \
--fvd_avcache_path ${FVD_AVCACHE_PATH} \
--metrics ${METRICS}