danhtran2mind committed
Commit acf8533 · verified · 1 Parent(s): da40968

Update main.py

Files changed (1)
  1. main.py +103 -103
main.py CHANGED
@@ -1,104 +1,104 @@
- import tensorflow as tf
- from translator import Translator
- from utils import tokenizer_utils
- from utils.preprocessing import input_processing, output_processing
- from models.transformer import Transformer
- from models.encoder import Encoder
- from models.decoder import Decoder
- from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network
- from models.utils import masked_loss, masked_accuracy
- import argparse
-
- def main(sentences: list, model: tf.keras.Model, en_tokenizer, vi_tokenizer) -> None:
-     """
-     Translates input English sentences to Vietnamese using a pre-trained model.
-
-     Args:
-         sentences (list): List of English sentences to translate.
-         model (tf.keras.Model): The pre-trained translation model.
-         en_tokenizer: English tokenizer.
-         vi_tokenizer: Vietnamese tokenizer.
-     """
-     # Initialize the translator with tokenizers and the model
-     translator = Translator(en_tokenizer, vi_tokenizer, model)
-
-     # Process and translate each sentence
-     for sentence in sentences:
-         processed_sentence = input_processing(sentence)
-         translated_text = translator(processed_sentence)
-         translated_text = output_processing(translated_text)
-
-         # Display the input and translated text
-         print("Input:", processed_sentence)
-         print("Translated:", translated_text)
-         print("-" * 50)
-
- if __name__ == "__main__":
-     # Set up argument parser
-     parser = argparse.ArgumentParser(
-         description="Translate English sentences to Vietnamese using a pre-trained transformer model.",
-         epilog="Example: python translate.py --sentence 'Hello, world!' --sentence 'The sun is shining.'"
-     )
-     parser.add_argument(
-         "--sentence",
-         type=str,
-         nargs="*",
-         default=[
-             (
-                 "For at least six centuries, residents along a lake in the mountains of central Japan "
-                 "have marked the depth of winter by celebrating the return of a natural phenomenon "
-                 "once revered as the trail of a wandering god."
-             )
-         ],
-         help="One or more English sentences to translate (default: provided example sentence)"
-     )
-     parser.add_argument(
-         "--model_path",
-         type=str,
-         default="saved_models/en_vi_translation.keras",
-         help="Path to the pre-trained model file (default: saved_models/en_vi_translation.keras)"
-     )
-
-     # Parse arguments
-     args = parser.parse_args()
-
-     # Define custom objects required for loading the model
-     custom_objects = {
-         "Transformer": Transformer,
-         "Encoder": Encoder,
-         "Decoder": Decoder,
-         "EncoderLayer": EncoderLayer,
-         "DecoderLayer": DecoderLayer,
-         "MultiHeadAttention": MultiHeadAttention,
-         "point_wise_feed_forward_network": point_wise_feed_forward_network,
-         "masked_loss": masked_loss,
-         "masked_accuracy": masked_accuracy,
-     }
-
-     # Load the pre-trained model once
-     print("Loading model from:", args.model_path)
-     loaded_model = tf.keras.models.load_model(
-         args.model_path, custom_objects=custom_objects
-     )
-     print("Model loaded successfully.")
-
-     # Load English and Vietnamese tokenizers once
-     en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()
-
-     # Run the translation for all provided sentences
-     main(sentences=args.sentence, model=loaded_model, en_tokenizer=en_tokenizer, vi_tokenizer=vi_tokenizer)
-
-     # Interactive loop for additional translations
-     while True:
-         choice = input("Would you like to translate another sentence? (Y/n): ").strip().lower()
-         if choice in ['no', 'n', 'quit', 'q']:
-             print("Exiting the program.")
-             break
-         elif choice in ['yes', 'y']:
-             new_sentence = input("Enter an English sentence to translate: ").strip()
-             if new_sentence:
-                 main(sentences=[new_sentence], model=loaded_model, en_tokenizer=en_tokenizer, vi_tokenizer=vi_tokenizer)
-             else:
-                 print("No sentence provided. Please try again.")
-         else:
              print("Invalid input. Please enter 'y' or 'n'.")
 
+ import tensorflow as tf
+ from translator import Translator
+ from utils import tokenizer_utils
+ from utils.preprocessing import input_processing, output_processing
+ from models.transformer import Transformer
+ from models.encoder import Encoder
+ from models.decoder import Decoder
+ from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network
+ from models.utils import masked_loss, masked_accuracy
+ import argparse
+
+ def main(sentences: list, model: tf.keras.Model, en_tokenizer, vi_tokenizer) -> None:
+     """
+     Translates input English sentences to Vietnamese using a pre-trained model.
+
+     Args:
+         sentences (list): List of English sentences to translate.
+         model (tf.keras.Model): The pre-trained translation model.
+         en_tokenizer: English tokenizer.
+         vi_tokenizer: Vietnamese tokenizer.
+     """
+     # Initialize the translator with tokenizers and the model
+     translator = Translator(en_tokenizer, vi_tokenizer, model)
+
+     # Process and translate each sentence
+     for sentence in sentences:
+         processed_sentence = input_processing(sentence)
+         translated_text = translator(processed_sentence)
+         translated_text = output_processing(translated_text)
+
+         # Display the input and translated text
+         print("Input:", processed_sentence)
+         print("Translated:", translated_text)
+         print("-" * 50)
+
+ if __name__ == "__main__":
+     # Set up argument parser
+     parser = argparse.ArgumentParser(
+         description="Translate English sentences to Vietnamese using a pre-trained transformer model.",
+         epilog="Example: python translate.py --sentence 'Hello, world!' --sentence 'The sun is shining.'"
+     )
+     parser.add_argument(
+         "--sentence",
+         type=str,
+         nargs="*",
+         default=[
+             (
+                 "For at least six centuries, residents along a lake in the mountains of central Japan "
+                 "have marked the depth of winter by celebrating the return of a natural phenomenon "
+                 "once revered as the trail of a wandering god."
+             )
+         ],
+         help="One or more English sentences to translate (default: provided example sentence)"
+     )
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="ckpts/en_vi_translation.keras",
+         help="Path to the pre-trained model file (default: ckpts/en_vi_translation.keras)"
+     )
+
+     # Parse arguments
+     args = parser.parse_args()
+
+     # Define custom objects required for loading the model
+     custom_objects = {
+         "Transformer": Transformer,
+         "Encoder": Encoder,
+         "Decoder": Decoder,
+         "EncoderLayer": EncoderLayer,
+         "DecoderLayer": DecoderLayer,
+         "MultiHeadAttention": MultiHeadAttention,
+         "point_wise_feed_forward_network": point_wise_feed_forward_network,
+         "masked_loss": masked_loss,
+         "masked_accuracy": masked_accuracy,
+     }
+
+     # Load the pre-trained model once
+     print("Loading model from:", args.model_path)
+     loaded_model = tf.keras.models.load_model(
+         args.model_path, custom_objects=custom_objects
+     )
+     print("Model loaded successfully.")
+
+     # Load English and Vietnamese tokenizers once
+     en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()
+
+     # Run the translation for all provided sentences
+     main(sentences=args.sentence, model=loaded_model, en_tokenizer=en_tokenizer, vi_tokenizer=vi_tokenizer)
+
+     # Interactive loop for additional translations
+     while True:
+         choice = input("Would you like to translate another sentence? (Y/n): ").strip().lower()
+         if choice in ['no', 'n', 'quit', 'q']:
+             print("Exiting the program.")
+             break
+         elif choice in ['yes', 'y']:
+             new_sentence = input("Enter an English sentence to translate: ").strip()
+             if new_sentence:
+                 main(sentences=[new_sentence], model=loaded_model, en_tokenizer=en_tokenizer, vi_tokenizer=vi_tokenizer)
+             else:
+                 print("No sentence provided. Please try again.")
+         else:
              print("Invalid input. Please enter 'y' or 'n'.")