zhangj726 committed
Commit 9cddb79 · 0 Parent(s)

Duplicate from zhangj726/poem_generation


Co-authored-by: Jing Zhang <[email protected]>

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +34 -0
  2. .gitignore +22 -0
  3. .idea/.gitignore +3 -0
  4. .idea/.name +1 -0
  5. .idea/ea_lstm.iml +8 -0
  6. .idea/inspectionProfiles/Project_Default.xml +20 -0
  7. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  8. .idea/misc.xml +4 -0
  9. .idea/modules.xml +8 -0
  10. .idea/nlp.iml +12 -0
  11. .idea/vcs.xml +6 -0
  12. .idea/workspace.xml +44 -0
  13. README.md +26 -0
  14. __pycache__/inference.cpython-38.pyc +0 -0
  15. app.py +14 -0
  16. data/org_poetry.txt +0 -0
  17. data/poetry.txt +0 -0
  18. data/poetry_7.txt +0 -0
  19. data/split_poetry.txt +0 -0
  20. data/word_vec.pkl +3 -0
  21. example.jpg +0 -0
  22. inference.py +108 -0
  23. requirements.txt +3 -0
  24. save_models/.keep +0 -0
  25. save_models/GRU_25.pth +3 -0
  26. save_models/GRU_50.pth +3 -0
  27. save_models/lstm_25.pth +3 -0
  28. save_models/lstm_50.pth +3 -0
  29. save_models/transformer_100.pth +3 -0
  30. scripts/lstm_infer.sh +0 -0
  31. scripts/lstm_train.sh +0 -0
  32. src/__init__.py +0 -0
  33. src/__pycache__/__init__.cpython-38.pyc +0 -0
  34. src/__pycache__/__init__.cpython-39.pyc +0 -0
  35. src/apis/__init__.py +0 -0
  36. src/apis/__pycache__/__init__.cpython-39.pyc +0 -0
  37. src/apis/__pycache__/inference.cpython-39.pyc +0 -0
  38. src/apis/__pycache__/train.cpython-39.pyc +0 -0
  39. src/apis/evaluate.py +23 -0
  40. src/apis/train.py +68 -0
  41. src/datasets/__init__.py +0 -0
  42. src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  43. src/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  44. src/datasets/__pycache__/dataloader.cpython-38.pyc +0 -0
  45. src/datasets/__pycache__/dataloader.cpython-39.pyc +0 -0
  46. src/datasets/dataloader.py +115 -0
  47. src/models/LSTM/__init__.py +0 -0
  48. src/models/LSTM/__pycache__/__init__.cpython-38.pyc +0 -0
  49. src/models/LSTM/__pycache__/__init__.cpython-39.pyc +0 -0
  50. src/models/LSTM/__pycache__/algorithm.cpython-39.pyc +0 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,22 @@
+ # PyTorch
+ /.torch
+
+ # Data files
+ *.csv
+ *.json
+ *.tsv
+
+ # Model files
+ *.ckpt
+ *.pth
+ *.pkl
+
+ # Logs and checkpoints
+ logs/
+ checkpoints/
+
+ # Secondary files
+ *.pyc
+ __pycache__/
+ .DS_Store
+
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
.idea/.name ADDED
@@ -0,0 +1 @@
+ inference.py
.idea/ea_lstm.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,20 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredPackages">
+         <value>
+           <list size="7">
+             <item index="0" class="java.lang.String" itemvalue="easydict" />
+             <item index="1" class="java.lang.String" itemvalue="pandas" />
+             <item index="2" class="java.lang.String" itemvalue="matplotlib" />
+             <item index="3" class="java.lang.String" itemvalue="pillow" />
+             <item index="4" class="java.lang.String" itemvalue="mindspore" />
+             <item index="5" class="java.lang.String" itemvalue="setuptools" />
+             <item index="6" class="java.lang.String" itemvalue="numpy" />
+           </list>
+         </value>
+       </option>
+     </inspection_tool>
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (pytorch)" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/ea_lstm.iml" filepath="$PROJECT_DIR$/.idea/ea_lstm.iml" />
+     </modules>
+   </component>
+ </project>
.idea/nlp.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="inheritedJdk" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+   <component name="PyDocumentationSettings">
+     <option name="format" value="PLAIN" />
+     <option name="myDocStringFormat" value="Plain" />
+   </component>
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="$PROJECT_DIR$" vcs="Git" />
+   </component>
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,44 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ChangeListManager">
+     <list default="true" id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
+     <option name="SHOW_DIALOG" value="false" />
+     <option name="HIGHLIGHT_CONFLICTS" value="true" />
+     <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+     <option name="LAST_RESOLUTION" value="IGNORE" />
+   </component>
+   <component name="MarkdownSettingsMigration">
+     <option name="stateVersion" value="1" />
+   </component>
+   <component name="ProjectId" id="2OyFWrJQpFYHFKgf87OgmRH5Jtu" />
+   <component name="ProjectViewState">
+     <option name="hideEmptyMiddlePackages" value="true" />
+     <option name="showLibraryContents" value="true" />
+   </component>
+   <component name="PropertiesComponent"><![CDATA[{
+   "keyToString": {
+     "RunOnceActivity.OpenProjectViewOnStart": "true",
+     "RunOnceActivity.ShowReadmeOnStart": "true",
+     "last_opened_file_path": "C:/Users/LENOVO/PycharmProjects/lstm"
+   }
+ }]]></component>
+   <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
+   <component name="TaskManager">
+     <task active="true" id="Default" summary="Default task">
+       <changelist id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
+       <created>1682524950142</created>
+       <option name="number" value="Default" />
+       <option name="presentableId" value="Default" />
+       <updated>1682524950142</updated>
+     </task>
+     <servers />
+   </component>
+   <component name="XDebuggerManager">
+     <watches-manager>
+       <configuration name="PythonConfigurationType">
+         <watch expression="input_eval" />
+         <watch expression="word_2_index" />
+       </configuration>
+     </watches-manager>
+   </component>
+ </project>
README.md ADDED
@@ -0,0 +1,26 @@
+ ---
+ duplicated_from: zhangj726/poem_generation
+ ---
+ # NLP Final Project
+
+ ```shell
+ ├── configs
+ ├── data
+ │   └── poetry.txt
+ ├── inference.py
+ ├── src
+ │   ├── apis
+ │   │   ├── evaluate.py
+ │   │   ├── inference.py
+ │   │   └── train.py
+ │   ├── datasets
+ │   │   └── dataloader.py
+ │   ├── models
+ │   │   └── EA-LSTM
+ │   │       ├── algorithm.py
+ │   │       └── model.py
+ │   └── utils
+ │       └── utils.py
+ ├── test.py
+ └── train.py
+ ```
__pycache__/inference.cpython-38.pyc ADDED
Binary file (2.88 kB).
 
app.py ADDED
@@ -0,0 +1,14 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ import gradio
+ from inference import infer
+
+
+ INTERFACE = gradio.Interface(fn=infer,
+                              inputs=[gradio.Radio(["lstm", "GRU"]), "text"],
+                              outputs=["text"], title="Poetry Generation",
+                              description="Choose a model and input the poetic head to generate an acrostic",
+                              thumbnail="https://github.com/gradio-app/gpt-2/raw/master/screenshots/interface.png?raw=true")
+
+ INTERFACE.launch(inbrowser=True)
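Gradio passes the Radio selection and the text box to `fn` positionally, so `infer` receives `(model_name, head_string)` in that order. For headless hosting, a minimal sketch of the same interface using the standard `launch()` server parameters (the host and port values are illustrative assumptions, not part of this commit):

```python
# Sketch: headless launch variant. server_name/server_port are standard
# gradio launch() parameters; the values below are illustrative assumptions.
import gradio
from inference import infer

demo = gradio.Interface(fn=infer,
                        inputs=[gradio.Radio(["lstm", "GRU"]), "text"],
                        outputs=["text"])
demo.launch(server_name="0.0.0.0", server_port=7860)
```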
data/org_poetry.txt ADDED
The diff for this file is too large to render.
 
data/poetry.txt ADDED
The diff for this file is too large to render.
 
data/poetry_7.txt ADDED
The diff for this file is too large to render.
 
data/split_poetry.txt ADDED
The diff for this file is too large to render.
 
data/word_vec.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1164cfc2e28ef6ecbb1a04734e7268238b4841667f13d6cb4c42e27717dd4575
+ size 6339344
example.jpg ADDED
inference.py ADDED
@@ -0,0 +1,108 @@
+ import torch
+ import argparse
+ import numpy as np
+ from src.models.LSTM.model import Poetry_Model_lstm
+ from src.datasets.dataloader import train_vec
+ from src.utils.utils import make_cuda
+
+
+ def parse_arguments():
+     # argument parsing
+     parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
+     parser.add_argument('--model', type=str, default='lstm',
+                         help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
+     parser.add_argument('--Word2Vec', default=True)
+     parser.add_argument('--strict_dataset', default=False, help="strict dataset")
+     parser.add_argument('--n_hidden', type=int, default=128)
+
+     parser.add_argument('--save_path', type=str, default='save_models/lstm_50.pth')
+
+     return parser.parse_args()
+
+
+ def generate_poetry(args, model, head_string, w1, word_2_index, index_2_word):
+     print("Generating an acrostic poem for {}...".format(head_string))
+     poem = ""
+     # generate one line of verse starting with each character of the head string
+     for head in head_string:
+         if head not in word_2_index:
+             print("Sorry, cannot generate a poem starting with {}".format(head))
+             return
+
+         sentence = head
+         max_sent_len = 20
+
+         h_0 = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32))
+         c_0 = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32))
+
+         input_eval = word_2_index[head]
+         for i in range(max_sent_len):
+             if args.Word2Vec:
+                 word_embedding = torch.tensor(w1[input_eval][None][None])
+             else:
+                 word_embedding = torch.tensor([input_eval]).unsqueeze(dim=0)
+             pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
+             char_generated = index_2_word[int(torch.argmax(pre))]
+
+             if char_generated == '。':
+                 break
+             # feed the newly generated character back in as the next input
+             input_eval = word_2_index[char_generated]
+             sentence += char_generated
+
+         poem += '\n' + sentence
+
+     return poem
+
+
+ def infer(model, string):
+     args = parse_arguments()
+     all_data, (w1, word_2_index, index_2_word) = train_vec()
+     args.word_size, args.embedding_num = w1.shape
+     # string = input("Poem head: ")
+     # string = '自然语言'
+     args.model = model
+     if args.model == 'lstm':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+         args.save_path = 'save_models/lstm_50.pth'
+     elif args.model == 'GRU':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+         args.save_path = 'save_models/GRU_50.pth'
+     elif args.model == 'Seq2Seq':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     elif args.model == 'Transformer':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     elif args.model == 'GPT-2':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     else:
+         print("Please choose a model!\n")
+
+     model.load_state_dict(torch.load(args.save_path))
+     model = make_cuda(model)
+     poem = generate_poetry(args, model, string, w1, word_2_index, index_2_word)
+     return poem
+
+
+ if __name__ == '__main__':
+     args = parse_arguments()
+     all_data, (w1, word_2_index, index_2_word) = train_vec()
+     args.word_size, args.embedding_num = w1.shape
+     # string = input("Poem head: ")
+     string = '自然语言'
+
+     if args.model == 'lstm':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     elif args.model == 'GRU':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     elif args.model == 'Seq2Seq':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     elif args.model == 'Transformer':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     elif args.model == 'GPT-2':
+         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+     else:
+         print("Please choose a model!\n")
+
+     model.load_state_dict(torch.load(args.save_path))
+     model = make_cuda(model)
+     poem = generate_poetry(args, model, string, w1, word_2_index, index_2_word)
+     print(poem)
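Run as a script, the module uses the argparse defaults above (`--model lstm`, `--save_path save_models/lstm_50.pth`). A programmatic smoke test of the same path, assuming the LFS checkpoints under `save_models/` and the `data/` files in this commit have been pulled:

```python
# Rough programmatic equivalent of: python inference.py --model GRU
# Assumes save_models/GRU_50.pth and the data/ files are present on disk.
from inference import infer

poem = infer("GRU", "自然语言")  # loads save_models/GRU_50.pth internally
print(poem if poem is not None else "head character not in vocabulary")
```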
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==1.13.0
+ gradio==3.34.0
+ gensim==4.3.1
save_models/.keep ADDED
File without changes
save_models/GRU_25.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bacf9a7ec329c6185098c1309ab28239b4c087b53832b3d18e5323831bfead23
+ size 10727391
save_models/GRU_50.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a8e83a733c023b35c44020e014bb72e2c1d05698eb782669c0e4d5a76d4590d
+ size 10727391
save_models/lstm_25.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b064666ce02c63541dee4b6146d31ee8f7e784ee9c2811c9b9266aba6cc4193
+ size 10727391
save_models/lstm_50.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa157d970149c32b53b024a23ef8428e7b7e1702ed72d44152b568b085b1bfaa
+ size 10727391
save_models/transformer_100.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bbe153237c20ba6ec5e8ac8b55c8b420ec4cdf5bf0f46a8a5b68094a54996c3
+ size 26125257
scripts/lstm_infer.sh ADDED
File without changes
scripts/lstm_train.sh ADDED
File without changes
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (166 Bytes).
 
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes).
 
src/apis/__init__.py ADDED
File without changes
src/apis/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (151 Bytes).
 
src/apis/__pycache__/inference.cpython-39.pyc ADDED
Binary file (1.44 kB).
 
src/apis/__pycache__/train.cpython-39.pyc ADDED
Binary file (1.68 kB).
 
src/apis/evaluate.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ import numpy as np
+ from src.models.EA_LSTM.model import weightedLSTM
+ from src.datasets.dataloader import MyDataset, create_vocab
+
+
+ def test(args):
+     vocab, poetrys = create_vocab(args.data)
+     # vocabulary size
+     args.vocab_size = len(vocab)
+     int2char = np.array(vocab)
+     valid_dataset = MyDataset(vocab, poetrys, args, train=False)
+
+     model = weightedLSTM(6110, 256, 128, 2, [1.0] * 80, False)
+     model.load_state_dict(torch.load(args.save_path))
+
+     input_example_batch, target_example_batch = valid_dataset[0]
+     example_batch_predictions = model(input_example_batch)
+     predicted_id = torch.distributions.Categorical(example_batch_predictions).sample()
+     predicted_id = torch.squeeze(predicted_id, -1).numpy()
+     print("Input: \n", repr("".join(int2char[input_example_batch])))
+     print()
+     print("Predictions: \n", repr("".join(int2char[predicted_id])))
src/apis/train.py ADDED
@@ -0,0 +1,68 @@
+ import math
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import torch.optim as optim
+ from src.utils.utils import make_cuda
+ from torch.nn import functional as F
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
+
+
+ def train(args, model, data_loader, initial=False):
+     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+     model.train()
+     num_epochs = args.initial_epochs if initial else args.num_epochs
+
+     for epoch in range(num_epochs):
+         loss = 0
+         for step, (features, targets) in enumerate(data_loader):
+             features = make_cuda(features)
+             targets = make_cuda(targets)
+
+             optimizer.zero_grad()
+
+             pre, _ = model(features)
+             crs_loss = model.cross_entropy(pre, targets.reshape(-1))
+             loss += crs_loss.item()
+             crs_loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+             optimizer.step()
+
+             # print step info
+             if (step + 1) % args.log_step == 0:
+                 print("Epoch [%.3d/%.3d] Step [%.3d/%.3d]: CROSS_loss=%.4f, RCROSS_loss=%.4f"
+                       % (epoch + 1,
+                          num_epochs,
+                          step + 1,
+                          len(data_loader),
+                          loss / args.log_step,
+                          math.sqrt(loss / args.log_step)))
+                 loss = 0
+
+     # Loss = []
+     # for step, (features, targets) in enumerate(valid_data_loader):
+     #     features = make_cuda(features)
+     #     targets = make_cuda(targets)
+     #     model.eval()
+     #     preds = model(features)
+     #     valid_loss = CrossLoss(preds, targets)
+     #     Loss.append(valid_loss)
+     # print("Valid loss: %.3d\n" % (np.mean(Loss)))
+
+     return model
+
+
+ def evaluate(args, model, data_loader):
+     model.eval()
+     loss = []
+     for step, (features, targets) in enumerate(data_loader):
+         features = make_cuda(features)
+         targets = make_cuda(targets)
+
+         pre, _ = model(features)
+         crs_loss = model.cross_entropy(pre, targets.reshape(-1))
+         loss.append(crs_loss.item())
+         torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+     print("loss=%.4f" % (np.mean(loss)))
src/datasets/__init__.py ADDED
File without changes
src/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (175 Bytes).
 
src/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (155 Bytes).
 
src/datasets/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (4.09 kB).
 
src/datasets/__pycache__/dataloader.cpython-39.pyc ADDED
Binary file (4.12 kB).
 
src/datasets/dataloader.py ADDED
@@ -0,0 +1,115 @@
+ import numpy as np
+ import pickle
+ import os
+ import torch
+ import torch.nn as nn
+ from gensim.models.word2vec import Word2Vec
+ from torch.utils.data import Dataset
+
+
+ def padding(poetries, maxlen, pad):
+     batch_seq = [poetry + pad * (maxlen - len(poetry)) for poetry in poetries]
+     return batch_seq
+
+
+ # slide the input one character ahead to get the target, i.e. predict the next character
+ def split_input_target(seq):
+     inputs = seq[:-1]
+     targets = seq[1:]
+     return inputs, targets
+
+
+ # load and filter the poems
+ def get_poetry(arg):
+     poetrys = []
+     if arg.Augmented_dataset:
+         path = arg.Augmented_data
+     else:
+         path = arg.data
+     with open(path, "r", encoding='UTF-8') as f:
+         for line in f:
+             try:
+                 # line = line.decode('UTF-8')
+                 line = line.strip(u'\n')
+                 if arg.Augmented_dataset:
+                     content = line.strip(u' ')
+                 else:
+                     title, content = line.strip(u' ').split(u':')
+                 content = content.replace(u' ', u'')
+                 if u'_' in content or u'(' in content or u'(' in content or u'《' in content or u'[' in content:
+                     continue
+                 if arg.strict_dataset:
+                     if len(content) < 12 or len(content) > 79:
+                         continue
+                 else:
+                     if len(content) < 5 or len(content) > 79:
+                         continue
+                 content = u'[' + content + u']'
+                 poetrys.append(content)
+             except Exception as e:
+                 pass
+
+     # sort the poems by length
+     poetrys = sorted(poetrys, key=lambda line: len(line))
+
+     with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
+         for poetry in poetrys:
+             poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
+             f.write(poetry)
+
+     return poetrys
+
+
+ # split the corpus into space-separated characters
+ def split_text(poetrys):
+     with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
+         for poetry in poetrys:
+             poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
+             split_data = " ".join(poetry)
+             f.write(split_data)
+     return open("data/split_poetry.txt", "r", encoding='UTF-8').read()
+
+
+ # train the word vectors
+ def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
+     param_file = "data/word_vec.pkl"
+     org_data = open(org_file, "r", encoding="utf-8").read().split("\n")
+     if os.path.exists(split_file):
+         all_data_split = open(split_file, "r", encoding="utf-8").read().split("\n")
+     else:
+         all_data_split = split_text().split("\n")
+
+     if os.path.exists(param_file):
+         return org_data, pickle.load(open(param_file, "rb"))
+
+     models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
+     pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
+     return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
+
+
+ class Poetry_Dataset(Dataset):
+     def __init__(self, w1, word_2_index, all_data, Word2Vec):
+         self.Word2Vec = Word2Vec
+         self.w1 = w1
+         self.word_2_index = word_2_index
+         word_size, embedding_num = w1.shape
+         self.embedding = nn.Embedding(word_size, embedding_num)
+         # length of the longest poem
+         maxlen = max([len(seq) for seq in all_data])
+         pad = ' '
+         self.all_data = padding(all_data[:-1], maxlen, pad)
+
+     def __getitem__(self, index):
+         a_poetry = self.all_data[index]
+
+         a_poetry_index = [self.word_2_index[i] for i in a_poetry]
+         xs, ys = split_input_target(a_poetry_index)
+         if self.Word2Vec:
+             xs_embedding = self.w1[xs]
+         else:
+             xs_embedding = np.array(xs)
+
+         return xs_embedding, np.array(ys).astype(np.int64)
+
+     def __len__(self):
+         return len(self.all_data)
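The next-character objective comes down to `split_input_target`, which shifts a sequence by one position; a tiny self-contained illustration (the verse is an arbitrary example, not taken from the dataset):

```python
# Illustration of the one-character shift used to build training pairs.
from src.datasets.dataloader import split_input_target

seq = list("春眠不觉晓。")  # an arbitrary example verse
inputs, targets = split_input_target(seq)
print(inputs)   # ['春', '眠', '不', '觉', '晓']
print(targets)  # ['眠', '不', '觉', '晓', '。']
```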
src/models/LSTM/__init__.py ADDED
File without changes
src/models/LSTM/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (178 Bytes).
 
src/models/LSTM/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (161 Bytes).
 
src/models/LSTM/__pycache__/algorithm.cpython-39.pyc ADDED
Binary file (4.99 kB).