jenbenarye committed
Commit 056b95d · Parent(s): 801c17a
changed file name

ml/{kto_lora.py → trainer.py}  RENAMED  (+85 -23)
@@ -9,6 +9,7 @@ from datetime import datetime
 import wandb
 from enum import Enum
 from typing import Optional
+from pathlib import Path
 
 
 # PEFT library: attach and load adapters
@@ -104,6 +105,48 @@ def load_model_and_tokenizer(model_args):
 
     return model, tokenizer
 
+def get_adapter_path(model_name: str, language: str, timestamp: str = None) -> Path:
+    """
+    Generate standardized adapter path.
+    If timestamp is None, returns the base language directory.
+    Otherwise, returns specific adapter version path.
+
+    Format: adapters/{model_name}/{language}/version_{timestamp}
+    """
+    # Clean model name (remove slashes, etc.)
+    clean_model_name = model_name.replace('/', '_')
+
+    base_path = Path("adapters") / clean_model_name / language
+    if timestamp:
+        return base_path / f"version_{timestamp}"
+    return base_path
+
+def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]:
+    """
+    Load the most recent adapter for given model and language.
+    Returns: (loaded_model, timestamp of loaded adapter)
+    """
+    adapter_base = get_adapter_path(model_name, language)
+
+    if not adapter_base.exists():
+        return None, None
+
+    # Get all version directories and sort by timestamp
+    versions = sorted(
+        [d for d in adapter_base.glob("version_*")],
+        key=lambda x: x.name,
+        reverse=True
+    )
+
+    if not versions:
+        return None, None
+
+    latest_version = versions[0]
+    timestamp = latest_version.name.replace("version_", "")
+
+    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
+    return model, timestamp
+
 ####################################
 #  MAIN LOGIC
 ####################################
@@ -112,26 +155,29 @@ def main():
     # Initialize wandb for logging
     wandb.init(project="kto")
 
+    # Get timestamp at start of training
+    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
     print("Loading base model and tokenizer...")
     model, tokenizer = load_model_and_tokenizer(model_args)
     ref_model, _ = load_model_and_tokenizer(model_args)
     print("Models and tokenizer loaded.")
 
-    # …
-    …
-    …
-    …
-    …
-    …
-    …
-    if …:
-        …
-        …
-        print(f"Loaded existing adapter for language '{script_args.language}' from {adapter_dir}.")
+    # Load existing adapter or create new one
+    loaded_model, previous_timestamp = load_latest_adapter(
+        model,
+        model_args.model_name,
+        script_args.language
+    )
+
+    if loaded_model is not None:
+        model = loaded_model
+        print(f"Loaded existing adapter trained at {previous_timestamp}")
     else:
-        # …
+        # Initialize new LoRA adapter
+        peft_config = get_peft_config(model_args)
         model = get_peft_model(model, peft_config)
-        print(…)
+        print("Initialized new adapter")
 
     # -----------------------------
     # Data Preparation and Training
@@ -180,16 +226,32 @@ def main():
         "step": metrics.get("step")
     })
 
-    # …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
+    # Save the adapter
+    adapter_path = get_adapter_path(
+        model_args.model_name,
+        script_args.language,
+        training_timestamp
+    )
+    adapter_path.parent.mkdir(parents=True, exist_ok=True)
+
+    print(f"Saving adapter to: {adapter_path}")
+    model.save_pretrained(adapter_path)
+
+    # Save metadata
+    metadata = AdapterMetadata(
+        training_timestamp=training_timestamp,
+        dataset_entries=[entry["id"] for entry in dataset],
+        training_params={
+            "max_weight": script_args.max_weight,
+            "min_weight": script_args.min_weight,
+            "decay_factor": script_args.decay_factor,
+            "training_mode": script_args.training_mode
+        },
+        model_name=model_args.model_name,
+        language=script_args.language,
+        version=training_timestamp
+    )
+    metadata.save(adapter_path / "metadata.json")
 
     if script_args.push_to_hub:
         # Using a consistent naming pattern that links to the FEEL project
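For orientation, the new path helper maps a Hub-style model id onto a flat on-disk tree with one directory per training run. A minimal sketch of what it returns, using a hypothetical model id and language rather than values from this commit:

    # Hypothetical inputs, for illustration only
    get_adapter_path("meta-llama/Llama-3.1-8B", "en")
    # -> Path("adapters/meta-llama_Llama-3.1-8B/en")
    get_adapter_path("meta-llama/Llama-3.1-8B", "en", "2025-01-31_12-00-00")
    # -> Path("adapters/meta-llama_Llama-3.1-8B/en/version_2025-01-31_12-00-00")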
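A note on load_latest_adapter: it selects the newest version by sorting directory names in reverse. That is only correct because the timestamp format written at save time (%Y-%m-%d_%H-%M-%S) is zero-padded with the most significant field first, so lexicographic order coincides with chronological order:

    # String order equals time order for this zero-padded format
    names = ["version_2025-01-31_12-00-00", "version_2025-02-01_09-30-00"]
    assert sorted(names, reverse=True)[0] == "version_2025-02-01_09-30-00"

Passing is_trainable=True to PeftModel.from_pretrained keeps the loaded adapter weights unfrozen, so a subsequent run continues training the previous adapter rather than starting from a freshly initialized one.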
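On the save step: because model is a PeftModel at this point, model.save_pretrained(adapter_path) writes only the adapter artifacts, not the full base model, which keeps each version directory small. With a recent PEFT version the resulting layout would look roughly like this (file names may differ on older releases, which write adapter_model.bin):

    adapters/meta-llama_Llama-3.1-8B/en/version_2025-01-31_12-00-00/
        adapter_config.json
        adapter_model.safetensors
        metadata.json

save_pretrained creates the leaf directory itself; the preceding mkdir(parents=True, exist_ok=True) on the parent guards against a missing adapters/{model}/{language} tree on the first run.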
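AdapterMetadata is referenced above but not defined in this diff. Working only from the call site, a minimal compatible sketch could look like the following; the field types and the JSON save format are assumptions, and the class the project actually uses may differ:

    import json
    from dataclasses import dataclass, asdict
    from pathlib import Path

    @dataclass
    class AdapterMetadata:
        # Shape inferred from the call site in main(); types are assumptions
        training_timestamp: str
        dataset_entries: list      # ids of the feedback entries used for this run
        training_params: dict      # weighting/decay settings from script_args
        model_name: str
        language: str
        version: str

        def save(self, path: Path) -> None:
            # Persist next to the adapter weights as human-readable JSON
            path.write_text(json.dumps(asdict(self), indent=2))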