Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import transformers | |
| from transformers import Idefics2ForConditionalGeneration | |
| from peft import LoraConfig, get_peft_model | |
| from joint_inference import IdeficsJointInferenceModel | |
def _lora_config(target_modules):
    """Return the LoRA configuration shared by the "initial" and "final" adapters.

    Both adapters use identical hyperparameters; building them in one place
    keeps the two configs from drifting apart.
    NOTE(review): these presumably must match the settings the published
    adapters were trained with -- confirm against the adapter configs in the
    hub repo.

    Args:
        target_modules: Regex string or list of module names, passed straight
            through to ``LoraConfig(target_modules=...)``.

    Returns:
        A ``peft.LoraConfig``.
    """
    return LoraConfig(
        r=16,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=target_modules,
        init_lora_weights="gaussian",
    )


def _lora_module_names(parameter_names, prefix="base_model.model."):
    """Return the sorted, de-duplicated module paths that own LoRA parameters.

    Expected names look like ``base_model.model.<module path>.lora_A...``.
    The PEFT wrapper prefix is removed when present, and everything from
    ``.lora_`` onwards is dropped, leaving ``<module path>``. Names without
    ``.lora_`` are ignored.

    Args:
        parameter_names: Iterable of parameter names, e.g. from
            ``model.named_parameters()``.
        prefix: Wrapper prefix to strip when a module path starts with it.
            The previous implementation sliced off a fixed 17 characters,
            which is ``len("base_model.model.")``.

    Returns:
        A sorted list of module paths. Sorting makes the result deterministic
        across runs, unlike ``list(set(...))``.
    """
    marker = ".lora_"
    modules = set()
    for name in parameter_names:
        cut = name.find(marker)
        if cut == -1:
            continue  # not a LoRA parameter
        module = name[:cut]
        if module.startswith(prefix):
            module = module[len(prefix):]
        modules.add(module)
    return sorted(modules)


def get_model(
    repo="lil-lab/cogen",
    checkpoint="HuggingFaceM4/idefics2-8b",
    initial_revision="r0_full",
    final_revision="r3_full",
):
    """Build the Idefics2 joint-inference model with both LoRA adapters loaded.

    Steps:
      1. Load the base Idefics2 checkpoint in bfloat16.
      2. Wrap it with a LoRA adapter named ``"initial"``, whose targets are
         chosen by regex. Load its weights from ``repo`` at
         ``initial_revision``.
      3. Add a second adapter ``"final"`` on exactly the modules that carry
         the ``"initial"`` LoRA layers. Load its weights from ``repo`` at
         ``final_revision``.
      4. Wrap the PEFT model in ``IdeficsJointInferenceModel`` and switch it
         to eval mode.

    Every argument defaults to the value previously hard-coded here, so
    ``get_model()`` builds the same model as before.

    Args:
        repo: Hub repo holding both LoRA adapters.
        checkpoint: Hub id of the base Idefics2 model.
        initial_revision: Repo revision to load the "initial" adapter from.
        final_revision: Repo revision to load the "final" adapter from.

    Returns:
        The ``IdeficsJointInferenceModel`` wrapper, after ``.eval()``.
    """
    model = Idefics2ForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16
    )

    # Projection/MLP layers inside vision_model, modality_projection or
    # perceiver_resampler, plus any module whose name contains k_proj,
    # q_proj or v_proj.
    target_modules = (
        r"(.*(vision_model|modality_projection|perceiver_resampler)"
        r".*(out_proj|fc1|fc2|down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)"
        r"|(.*(k_proj|q_proj|v_proj).*$)"
    )
    model = get_peft_model(model, _lora_config(target_modules), adapter_name="initial")
    model.load_adapter(repo, "initial", revision=initial_revision)

    # Only "initial" exists at this point, so the LoRA parameters identify
    # exactly the modules "final" must be attached to.
    final_targets = _lora_module_names(name for name, _ in model.named_parameters())
    model.add_adapter("final", _lora_config(final_targets))
    model.load_adapter(repo, "final", revision=final_revision)

    # NOTE(review): the meaning of the positional arguments (0.5, 0) is defined
    # by joint_inference.IdeficsJointInferenceModel -- document them there.
    model = IdeficsJointInferenceModel(0.5, 0, model=model)
    model.eval()
    return model