
Improved Quality, Synchrony, and Preference Alignment for Joint Audio-Video Generation

This codebase is built upon JavisDiT. Many thanks to the authors for their contribution.

Installation

For CUDA 12.1, you can install the dependencies with the following commands.

# create a virtual env and activate (conda as an example)
conda create -n javisdit python=3.10
conda activate javisdit

# install torch, torchvision and xformers
pip install -r requirements/requirements-cu121.txt

# install ffmpeg
conda install "ffmpeg<7" -c conda-forge -y

# the default installation is for inference only
pip install -v .
# for development mode, `pip install -v -e .`
# to skip dependencies, `pip install -v -e . --no-deps`

# replace files in the installed packages with the patched versions provided in this repo
export PYTHON_SITE_PACKAGES=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
cp assets/src/pytorchvideo_augmentations.py ${PYTHON_SITE_PACKAGES}/pytorchvideo/transforms/augmentations.py
cp assets/src/funasr_utils_load_utils.py ${PYTHON_SITE_PACKAGES}/funasr/utils/load_utils.py
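
# (optional) verify that the patched files are in place
ls -l ${PYTHON_SITE_PACKAGES}/pytorchvideo/transforms/augmentations.py \
      ${PYTHON_SITE_PACKAGES}/funasr/utils/load_utils.py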

# (optional but recommended) install flash attention
# set enable_flash_attn=False in config to disable flash attention
pip install packaging ninja
pip install flash-attn --no-build-isolation
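
As a quick sanity check (a minimal sketch, not part of the official setup), you can verify that the core dependencies import cleanly:

python -c "import torch; print(torch.__version__, 'CUDA available:', torch.cuda.is_available())"
python -c "import flash_attn" 2>/dev/null && echo "flash-attn OK" || echo "flash-attn missing (set enable_flash_attn=False in config)"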

Training

Data Preparation

In this project, we use a .csv file to manage all the training entries and their attributes for efficient training:

path id relpath num_frames height width aspect_ratio fps resolution audio_path audio_fps text
/path/to/xxx.mp4 xxx xxx.mp4 240 480 640 0.75 24 307200 /path/to/xxx.wav 16000 yyy

The column contents may vary across training stages. Detailed instructions for each training stage can be found here.
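
If your meta file follows a standard comma-delimited layout (an assumption; the actual delimiter may vary by stage), you can quickly list its columns. The path below is illustrative:

head -1 data/meta/video/train_av_sft.csv | tr ',' '\n'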

Stage1 - Audio Pre-Train

In this stage, we perform audio pretraining to initialize the text-to-audio generation capability:

torchrun --standalone --nproc_per_node 8 \
    scripts/train.py \
    configs/wan2.1/train/stage1_audio.py \
    --data-path data/meta/audio/train_audio.csv

The resulting checkpoints will be saved at runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc/model, where 0aa, 0bb, and ccc are placeholders for the run index, epoch, and global step. You can move the checkpoints to exps/audio_pretrain/ for later use.

mkdir -p exps/audio_pretrain
mv runs/000-Wan2_1_T2V_1_3B/epoch049-global_step53000 exps/audio_pretrain/

Stage2 - Audio-Video SFT

In this stage, we fine-tune the model for joint audio-video generation (with LoRA adaptation):

torchrun --standalone --nproc_per_node 8 \
    scripts/train_prior.py \
    configs/wan2.1/train/stage2_audio_video.py \
    --data-path data/meta/video/train_av_sft.csv

The resulting checkpoints will be saved at runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc with the model and lora subfolders. You can move the checkpoints to exps/audio_video_sft/ for later use.

mkdir -p exps/audio_video_sft
mv runs/000-Wan2_1_T2V_1_3B/epoch001-global_step13000 exps/audio_video_sft/

Stage3 - Audio-Video DPO

In this stage, we perform DPO to align joint audio-video generation with human preferences (reusing and updating the LoRA parameters learned in the previous stage):

torchrun --standalone --nproc_per_node 8 \
    scripts/train.py \
    configs/wan2.1/train/stage3_audio_video_dpo.py \
    --data-path data/meta/avdpo/train_av_dpo.csv

The resulting checkpoints will also be saved at runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc, with model and lora subfolders. You can move the checkpoints to checkpoints/ for inference and evaluation.

mv runs/0aa-Wan2_1_T2V_1_3B/epoch0bb-global_stepccc checkpoints/your_model
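
After the move, the checkpoint directory should look roughly like this (a sketch based on the model and lora subfolders described above):

checkpoints/your_model/
├── model/   # base model weights
└── lora/    # LoRA adapter weights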

Inference

The basic command line inference is as follows:

resolution=480p # or 240p
num_frames=65  # 4s
aspect_ratio="9:16"

DATASET="JavisBench"  # or JavisBench-mini
prompt_path="data/eval/JavisBench/${DATASET}.csv"
save_dir="samples/${DATASET}"

model_path="checkpoints/your_model"
ngpus=1

torchrun --standalone --nproc_per_node ${ngpus} \
    scripts/inference.py \
    configs/wan2.1/inference/sample.py \
    --resolution ${resolution} --num-frames ${num_frames} --aspect-ratio ${aspect_ratio} \
    --prompt-path ${prompt_path} --model-path ${model_path} \
    --save-dir ${save_dir} --verbose 1  

# (Optional, for evaluation) Extract audio tracks from the generated videos
python -m tools.datasets.convert video ${save_dir} --output ${save_dir}/meta.csv
python -m tools.datasets.datautil ${save_dir}/meta.csv --extract-audio --audio-sr 16000
rm -f ${save_dir}/meta*.csv

Setting --verbose 2 will display the progress of a single diffusion process. You can also replace --prompt-path ${prompt_path} with a single prompt to generate one video, e.g., --prompt "a beautiful waterfall".
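
For example, a single-prompt, single-GPU run might look like the following (the prompt and save directory are illustrative):

torchrun --standalone --nproc_per_node 1 \
    scripts/inference.py \
    configs/wan2.1/inference/sample.py \
    --resolution 480p --num-frames 65 --aspect-ratio "9:16" \
    --prompt "a beautiful waterfall" --model-path checkpoints/your_model \
    --save-dir samples/single_prompt --verbose 2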

Evaluation

Installation

Install necessary packages:

pip install -r requirements/requirements-eval.txt

Download the metadata and data files of JavisBench, and put them under data/eval/:

cd /path/to/JavisDiT
mkdir -p data/eval

huggingface-cli download --repo-type dataset JavisDiT/JavisBench --local-dir data/eval/JavisBench
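
Once downloaded, the files referenced by the evaluation script below should be in place, for example (filenames inferred from the evaluation variables in the next subsection):

ls data/eval/JavisBench/JavisBench.csv
ls data/eval/JavisBench/cache/fvd_fad/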

Evaluation on JavisBench or JavisBench-mini

Run the following commands; the results will be saved in ./evaluation_results. For more details, please refer to the JavisBench documentation.

MAX_FRAMES=16
IMAGE_SIZE=224
MAX_AUDIO_LEN_S=4.0

# Params to calculate JavisScore
WINDOW_SIZE_S=2.0
WINDOW_OVERLAP_S=1.5

METRICS="all" 
RESULTS_DIR="./evaluation_results"

DATASET="JavisBench"  # or JavisBench-mini
INPUT_FILE="data/eval/JavisBench/${DATASET}.csv"
FVD_AVCACHE_PATH="data/eval/JavisBench/cache/fvd_fad/${DATASET}-vanilla-max4s.pt"
INFER_DATA_DIR="samples/${DATASET}"

python -m eval.javisbench.main \
  --input_file "${INPUT_FILE}" \
  --infer_data_dir "${INFER_DATA_DIR}" \
  --output_file "${RESULTS_DIR}/${DATASET}.json" \
  --max_frames ${MAX_FRAMES} \
  --image_size ${IMAGE_SIZE} \
  --max_audio_len_s ${MAX_AUDIO_LEN_S} \
  --window_size_s ${WINDOW_SIZE_S} \
  --window_overlap_s ${WINDOW_OVERLAP_S} \
  --fvd_avcache_path ${FVD_AVCACHE_PATH} \
  --metrics ${METRICS}
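
The output is a JSON file; for a quick look at the scores, you can pretty-print it (a convenience snippet, not part of the evaluation suite):

python -m json.tool "${RESULTS_DIR}/${DATASET}.json"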