|  | """ | 
					
						
						|  | finetune Phi-4-multimodal-instruct on an image task | 
					
						
						|  |  | 
					
						
						|  | scipy==1.15.1 | 
					
						
						|  | peft==0.13.2 | 
					
						
						|  | backoff==2.2.1 | 
					
						
						|  | transformers==4.47.0 | 
					
						
						|  | accelerate==1.3.0 | 
					
						
						|  | """ | 
					
						
						|  |  | 
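
# Example launch command (illustrative only, not part of the original header;
# assumes 2 GPUs and that flash-attn is installed when --use_flash_attention
# is passed; the script file name below is a placeholder):
#
#   accelerate launch --num_processes 2 finetune_phi4mm_pmc_vqa.py \
#       --use_flash_attention --batch_size 16 --batch_size_per_gpu 1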
					
						
import argparse
import gc
import json
import os
import shutil
import tempfile
import zipfile
from pathlib import Path

import torch
from accelerate import Accelerator
from accelerate.utils import gather_object
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BatchFeature,
    Trainer,
    TrainingArguments,
)

DEFAULT_INSTRUCTION = "Answer with the option's letter from the given choices directly."
_IGNORE_INDEX = -100
_TRAIN_SIZE = 8000
_EVAL_SIZE = 500
_MAX_TRAINING_LENGTH = 8192


class PmcVqaTrainDataset(Dataset):
    def __init__(self, processor, data_size, instruction=DEFAULT_INSTRUCTION):
        # Download the PMC-VQA image archive from the Hugging Face Hub.
        file_path = hf_hub_download(
            repo_id='xmcmic/PMC-VQA',
            filename='images_2.zip',
            repo_type='dataset',
        )
        print(f'File downloaded to: {file_path}')

        # Extract the images into a temporary directory.
        self.image_folder = Path(tempfile.mkdtemp())
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(self.image_folder)

        data_files = {
            'train': 'https://huggingface.co/datasets/xmcmic/PMC-VQA/resolve/main/train_2.csv',
        }
        split = 'train' if data_size is None else f'train[:{data_size}]'
        self.annotations = load_dataset('xmcmic/PMC-VQA', data_files=data_files, split=split)
        self.processor = processor
        self.instruction = instruction

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Example annotation:
        {'index': 35,
         'Figure_path': 'PMC8253797_Fig4_11.jpg',
         'Caption': 'A slightly altered cell . (c-c‴) A highly altered cell as seen from 4 different angles . Note mitochondria/mitochondrial networks (green), Golgi complexes (red), cell nuclei (light blue) and the cell outline (yellow).',
         'Question': ' What color is used to label the Golgi complexes in the image?',
         'Choice A': ' A: Green ',
         'Choice B': ' B: Red ',
         'Choice C': ' C: Light blue ',
         'Choice D': ' D: Yellow',
         'Answer': 'B',
         'split': 'train'}
        """
        annotation = self.annotations[idx]
        image = Image.open(self.image_folder / 'figures' / annotation['Figure_path'])
        question = annotation['Question']
        choices = [annotation[f'Choice {chr(ord("A") + i)}'] for i in range(4)]
        user_message = {
            'role': 'user',
            'content': '<|image_1|>' + '\n'.join([question] + choices + [self.instruction]),
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        answer = f'{annotation["Answer"]}<|end|><|endoftext|>'
        inputs = self.processor(prompt, images=[image], return_tensors='pt')

        answer_ids = self.processor.tokenizer(answer, return_tensors='pt').input_ids

        # Supervise only the answer tokens; prompt positions are masked with
        # _IGNORE_INDEX so they do not contribute to the loss.
        input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)
        labels = torch.full_like(input_ids, _IGNORE_INDEX)
        labels[:, -answer_ids.shape[1]:] = answer_ids

        if input_ids.size(1) > _MAX_TRAINING_LENGTH:
            input_ids = input_ids[:, :_MAX_TRAINING_LENGTH]
            labels = labels[:, :_MAX_TRAINING_LENGTH]
            if torch.all(labels == _IGNORE_INDEX).item():
                # If truncation dropped every answer token, keep one supervised
                # position so the loss stays well defined.
                labels[:, -1] = self.processor.tokenizer.eos_token_id

        return {
            'input_ids': input_ids,
            'labels': labels,
            'input_image_embeds': inputs.input_image_embeds,
            'image_attention_mask': inputs.image_attention_mask,
            'image_sizes': inputs.image_sizes,
        }

    def __del__(self):
        # Clean up the temporary image directory.
        shutil.rmtree(self.image_folder)
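
# Quick sanity check of the training dataset (illustrative sketch only; the
# PMC-VQA image archive is large, so run it only where the download is
# acceptable):
#
#   processor = AutoProcessor.from_pretrained(
#       'microsoft/Phi-4-multimodal-instruct', trust_remote_code=True)
#   ds = PmcVqaTrainDataset(processor, data_size=8)
#   item = ds[0]
#   # labels are _IGNORE_INDEX everywhere except the final answer tokens
#   print(item['input_ids'].shape, (item['labels'] != _IGNORE_INDEX).sum())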
					
						

class PmcVqaEvalDataset(Dataset):
    def __init__(
        self, processor, data_size, instruction=DEFAULT_INSTRUCTION, rank=0, world_size=1
    ):
        # Download the PMC-VQA image archive from the Hugging Face Hub.
        file_path = hf_hub_download(
            repo_id='xmcmic/PMC-VQA',
            filename='images_2.zip',
            repo_type='dataset',
        )
        print(f'File downloaded to: {file_path}')

        # Extract the images into a temporary directory.
        self.image_folder = Path(tempfile.mkdtemp())
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(self.image_folder)

        data_files = {
            'test': 'https://huggingface.co/datasets/xmcmic/PMC-VQA/resolve/main/test_2.csv',
        }
        split = 'test' if data_size is None else f'test[:{data_size}]'
        # Shard the eval split so each rank evaluates a disjoint subset.
        self.annotations = load_dataset(
            'xmcmic/PMC-VQA', data_files=data_files, split=split
        ).shard(num_shards=world_size, index=rank)
        self.processor = processor
        self.instruction = instruction

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Example annotation:
        {'index': 62,
         'Figure_path': 'PMC8253867_Fig2_41.jpg',
         'Caption': 'CT pulmonary angiogram reveals encasement and displacement of the left anterior descending coronary artery ( blue arrows ).',
         'Question': ' What is the name of the artery encased and displaced in the image? ',
         'Choice A': ' A: Right Coronary Artery ',
         'Choice B': ' B: Left Anterior Descending Coronary Artery ',
         'Choice C': ' C: Circumflex Coronary Artery ',
         'Choice D': ' D: Superior Mesenteric Artery ',
         'Answer': 'B',
         'split': 'test'}
        """
        annotation = self.annotations[idx]
        image = Image.open(self.image_folder / 'figures' / annotation['Figure_path'])
        question = annotation['Question']
        choices = [annotation[f'Choice {chr(ord("A") + i)}'] for i in range(4)]
        user_message = {
            'role': 'user',
            'content': '<|image_1|>' + '\n'.join([question] + choices + [self.instruction]),
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )
        answer = annotation['Answer']
        inputs = self.processor(prompt, images=[image], return_tensors='pt')

        unique_id = f'{annotation["index"]:010d}'
        return {
            'id': unique_id,
            'input_ids': inputs.input_ids,
            'input_image_embeds': inputs.input_image_embeds,
            'image_attention_mask': inputs.image_attention_mask,
            'image_sizes': inputs.image_sizes,
            'answer': answer,
        }

    def __del__(self):
        # Clean up the temporary image directory.
        shutil.rmtree(self.image_folder)

def pad_sequence(sequences, padding_side='right', padding_value=0):
    """
    Pad a list of sequences to the same length.
    sequences: list of tensors in [seq_len, *] shape
    """
    assert padding_side in ['right', 'left']
    max_size = sequences[0].size()
    trailing_dims = max_size[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if padding_side == 'right':
            output.data[i, :length] = seq
        else:
            output.data[i, -length:] = seq
    return output
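
# Illustrative behavior of pad_sequence (shapes only, assuming 1-D inputs):
#
#   a = torch.ones(3, dtype=torch.long)
#   b = torch.ones(5, dtype=torch.long)
#   pad_sequence([a, b], padding_side='right', padding_value=0).shape
#   # -> torch.Size([2, 5]); the shorter row is padded with zeros on the right
#   pad_sequence([a, b], padding_side='left', padding_value=0).shape
#   # -> torch.Size([2, 5]); padding is prepended instead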
					
						

def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in tensors[1:]
    ), 'All tensors must have the same number of dimensions'

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Copy each tensor into its slot along `dim`; every other dimension
        # keeps its original extent and the remainder stays at padding_value.
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        output[slices] = t
        index += t.shape[dim]

    return output
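
# Illustrative behavior of cat_with_pad (shapes only):
#
#   x = torch.zeros(2, 3, 4)
#   y = torch.zeros(1, 5, 4)
#   cat_with_pad([x, y], dim=0).shape
#   # -> torch.Size([3, 5, 4]); tensors are concatenated along dim 0 and
#   #    padded with padding_value up to the maximum size on every other dim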
					
						

def pmc_vqa_collate_fn(batch):
    input_ids_list = []
    labels_list = []
    input_image_embeds_list = []
    image_attention_mask_list = []
    image_sizes_list = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        labels_list.append(inputs['labels'][0])
        input_image_embeds_list.append(inputs['input_image_embeds'])
        image_attention_mask_list.append(inputs['image_attention_mask'])
        image_sizes_list.append(inputs['image_sizes'])

    input_ids = pad_sequence(input_ids_list, padding_side='right', padding_value=0)
    # Pad labels with _IGNORE_INDEX so padded positions are excluded from the loss.
    labels = pad_sequence(labels_list, padding_side='right', padding_value=_IGNORE_INDEX)
    attention_mask = (input_ids != 0).long()
    input_image_embeds = cat_with_pad(input_image_embeds_list, dim=0)
    image_attention_mask = cat_with_pad(image_attention_mask_list, dim=0)
    image_sizes = torch.cat(image_sizes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
            'input_image_embeds': input_image_embeds,
            'image_attention_mask': image_attention_mask,
            'image_sizes': image_sizes,
            'input_mode': 1,
        }
    )


def pmc_vqa_eval_collate_fn(batch):
    input_ids_list = []
    input_image_embeds_list = []
    image_attention_mask_list = []
    image_sizes_list = []
    all_unique_ids = []
    all_answers = []
    for inputs in batch:
        input_ids_list.append(inputs['input_ids'][0])
        input_image_embeds_list.append(inputs['input_image_embeds'])
        image_attention_mask_list.append(inputs['image_attention_mask'])
        image_sizes_list.append(inputs['image_sizes'])
        all_unique_ids.append(inputs['id'])
        all_answers.append(inputs['answer'])

    # Left-pad prompts for batched generation.
    input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
    attention_mask = (input_ids != 0).long()
    input_image_embeds = cat_with_pad(input_image_embeds_list, dim=0)
    image_attention_mask = cat_with_pad(image_attention_mask_list, dim=0)
    image_sizes = torch.cat(image_sizes_list)

    return (
        all_unique_ids,
        all_answers,
        BatchFeature(
            {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'input_image_embeds': input_image_embeds,
                'image_attention_mask': image_attention_mask,
                'image_sizes': image_sizes,
                'input_mode': 1,
            }
        ),
    )


def create_model(model_name_or_path, use_flash_attention=False):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16 if use_flash_attention else torch.float32,
        _attn_implementation='flash_attention_2' if use_flash_attention else 'sdpa',
        trust_remote_code=True,
    ).to('cuda')

    # This script only finetunes the vision path, so drop the audio encoder and
    # the speech LoRA adapters to save memory.
    del model.model.embed_tokens_extend.audio_embed
    for layer in model.model.layers:
        del layer.mlp.down_proj.lora_A.speech
        del layer.mlp.down_proj.lora_B.speech
        del layer.mlp.gate_up_proj.lora_A.speech
        del layer.mlp.gate_up_proj.lora_B.speech
        del layer.self_attn.o_proj.lora_A.speech
        del layer.self_attn.o_proj.lora_B.speech
        del layer.self_attn.qkv_proj.lora_A.speech
        del layer.self_attn.qkv_proj.lora_B.speech

    return model


@torch.no_grad()
def evaluate(
    model, processor, eval_dataset, save_path=None, disable_tqdm=False, eval_batch_size=1
):
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    model.eval()
    all_answers = []
    all_generated_texts = []

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        collate_fn=pmc_vqa_eval_collate_fn,
        shuffle=False,
        drop_last=False,
        num_workers=4,
        prefetch_factor=2,
        pin_memory=True,
    )
    for ids, answers, inputs in tqdm(
        eval_dataloader, disable=(rank != 0) or disable_tqdm, desc='running eval'
    ):
        all_answers.extend({'id': i, 'answer': a.strip().lower()} for i, a in zip(ids, answers))

        inputs = inputs.to(f'cuda:{local_rank}')
        generated_ids = model.generate(
            **inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=64
        )

        # Decode only the newly generated tokens (strip the prompt).
        input_len = inputs.input_ids.size(1)
        generated_texts = processor.batch_decode(
            generated_ids[:, input_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        all_generated_texts.extend(
            {'id': i, 'generated_text': g.strip().lower()} for i, g in zip(ids, generated_texts)
        )

    # Gather predictions and references from all ranks.
    all_answers = gather_object(all_answers)
    all_generated_texts = gather_object(all_generated_texts)

    if rank == 0:
        assert len(all_answers) == len(all_generated_texts)
        acc = sum(
            a['answer'] == g['generated_text'] for a, g in zip(all_answers, all_generated_texts)
        ) / len(all_answers)
        if save_path:
            with open(save_path, 'w') as f:
                save_dict = {
                    'answers_unique': all_answers,
                    'generated_texts_unique': all_generated_texts,
                    'accuracy': acc,
                }
                json.dump(save_dict, f)

        return acc
    return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        default='microsoft/Phi-4-multimodal-instruct',
        help='Model name or path to load from',
    )
    parser.add_argument('--use_flash_attention', action='store_true', help='Use Flash Attention')
    parser.add_argument('--output_dir', type=str, default='./output/', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument(
        '--batch_size_per_gpu',
        type=int,
        default=1,
        help='Batch size per GPU (adjust this to fit in GPU memory)',
    )
    parser.add_argument(
        '--dynamic_hd',
        type=int,
        default=36,
        help='Number of maximum image crops',
    )
    parser.add_argument(
        '--num_train_epochs', type=int, default=1, help='Number of training epochs'
    )
    parser.add_argument('--learning_rate', type=float, default=4.0e-5, help='Learning rate')
    parser.add_argument('--wd', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--no_tqdm', dest='tqdm', action='store_false', help='Disable tqdm')
    parser.add_argument('--full_run', action='store_true', help='Run the full training and eval')
    args = parser.parse_args()

    accelerator = Accelerator()

    with accelerator.local_main_process_first():
        processor = AutoProcessor.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            dynamic_hd=args.dynamic_hd,
        )
        model = create_model(
            args.model_name_or_path,
            use_flash_attention=args.use_flash_attention,
        )

    # Select the vision LoRA adapter and make the image embedding module trainable.
    model.set_lora_adapter('vision')
    for param in model.model.embed_tokens_extend.image_embed.parameters():
        param.requires_grad = True
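
    # Optional check (illustrative sketch): inspect how much of the model will
    # actually be trained after the adapter selection above; the exact set of
    # trainable tensors depends on the model's remote code.
    #
    #   n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    #   n_total = sum(p.numel() for p in model.parameters())
    #   print(f'trainable parameters: {n_trainable} / {n_total}')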
					
						

    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    train_dataset = PmcVqaTrainDataset(processor, data_size=None if args.full_run else _TRAIN_SIZE)
    eval_dataset = PmcVqaEvalDataset(
        processor,
        data_size=None if args.full_run else _EVAL_SIZE,
        rank=rank,
        world_size=world_size,
    )

    num_gpus = accelerator.num_processes
    print(f'training on {num_gpus} GPUs')
    assert (
        args.batch_size % (num_gpus * args.batch_size_per_gpu) == 0
    ), 'Batch size must be divisible by the number of GPUs times batch_size_per_gpu'
    gradient_accumulation_steps = args.batch_size // (num_gpus * args.batch_size_per_gpu)
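    # Worked example (defaults shown above, assuming 2 GPUs): with
    # --batch_size 16 and --batch_size_per_gpu 1, gradient_accumulation_steps
    # is 16 // (2 * 1) = 8, keeping the effective global batch size at 16.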
					
						

    # Use bf16 when flash attention is enabled, otherwise fall back to fp16.
    if args.use_flash_attention:
        fp16 = False
        bf16 = True
    else:
        fp16 = True
        bf16 = False

    training_args = TrainingArguments(
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size_per_gpu,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim='adamw_torch',
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-7,
        learning_rate=args.learning_rate,
        weight_decay=args.wd,
        max_grad_norm=1.0,
        lr_scheduler_type='linear',
        warmup_steps=50,
        logging_steps=10,
        output_dir=args.output_dir,
        save_strategy='no',
        save_total_limit=10,
        save_only_model=True,
        bf16=bf16,
        fp16=fp16,
        remove_unused_columns=False,
        report_to='none',
        deepspeed=None,
        disable_tqdm=not args.tqdm,
        dataloader_num_workers=4,
        ddp_find_unused_parameters=True,
    )

    out_path = Path(training_args.output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Evaluate the base model before finetuning.
    acc = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_before.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'Accuracy before finetuning: {acc}')

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=pmc_vqa_collate_fn,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model()
    accelerator.wait_for_everyone()

    # Free training resources before reloading the saved model for evaluation.
    del model
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    # Reload the finetuned weights and evaluate again.
    model = AutoModelForCausalLM.from_pretrained(
        training_args.output_dir,
        torch_dtype=torch.bfloat16 if args.use_flash_attention else torch.float32,
        trust_remote_code=True,
        _attn_implementation='flash_attention_2' if args.use_flash_attention else 'sdpa',
    ).to('cuda')

    acc = evaluate(
        model,
        processor,
        eval_dataset,
        save_path=out_path / 'eval_after.json',
        disable_tqdm=not args.tqdm,
        eval_batch_size=args.batch_size_per_gpu,
    )
    if accelerator.is_main_process:
        print(f'Accuracy after finetuning: {acc}')


if __name__ == '__main__':
    main()