NightPrince committed on
Commit
1f32233
·
verified ·
1 Parent(s): 50a8a35

FIRST COMMIT

Browse files
Files changed (1) hide show
  1. launch_training.py +45 -0
launch_training.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
# Fetch the Arabic WikiHow summarization corpus from the Hugging Face Hub.
dataset = load_dataset("Abdelkareem/wikihow-arabic-summarization")

# Pull down the AraT5 v2 checkpoint together with its matching tokenizer.
checkpoint = "UBC-NLP/AraT5v2-base-1024"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# Preprocessing function to tokenize the dataset
def preprocess_function(examples):
    """Tokenize one batch of examples for seq2seq fine-tuning.

    Args:
        examples: A batched dict from ``datasets.Dataset.map`` with an
            "article" column (source text) and a "summarize" column
            (target summary).
            NOTE(review): the target column is spelled "summarize" here —
            confirm against the actual dataset schema; summarization
            datasets usually call it "summary".

    Returns:
        The tokenized model inputs with a "labels" key holding the
        target token ids (truncation only; padding is deferred to the
        data collator at batch time).
    """
    inputs = examples["article"]
    targets = examples["summarize"]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    # Tokenize the targets via ``text_target`` — the documented API for
    # label tokenization, which applies any target-side tokenizer handling.
    labels = tokenizer(text_target=targets, max_length=150, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# Apply preprocessing to the dataset (batched for tokenizer speed).
tokenized_datasets = dataset.map(preprocess_function, batched=True)

# Pad inputs AND labels dynamically per batch. Without a collator the
# Trainer falls back to the default collator, which tries to tensorize
# the ragged (truncation-only, unpadded) sequences produced above and
# crashes on the first batch. DataCollatorForSeq2Seq also pads labels
# with -100 so padded positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    # NOTE(review): this argument was renamed ``eval_strategy`` in newer
    # transformers releases — update if the installed version has dropped
    # the old alias.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
)

# Initialize the Trainer with the seq2seq collator wired in.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

# Start the training process
trainer.train()