Spaces:
Configuration error
Configuration error
Commit
·
9cddb79
0
Parent(s):
Duplicate from zhangj726/poem_generation
Browse filesCo-authored-by: Jing Zhang <[email protected]>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +34 -0
- .gitignore +22 -0
- .idea/.gitignore +3 -0
- .idea/.name +1 -0
- .idea/ea_lstm.iml +8 -0
- .idea/inspectionProfiles/Project_Default.xml +20 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/nlp.iml +12 -0
- .idea/vcs.xml +6 -0
- .idea/workspace.xml +44 -0
- README.md +26 -0
- __pycache__/inference.cpython-38.pyc +0 -0
- app.py +14 -0
- data/org_poetry.txt +0 -0
- data/poetry.txt +0 -0
- data/poetry_7.txt +0 -0
- data/split_poetry.txt +0 -0
- data/word_vec.pkl +3 -0
- example.jpg +0 -0
- inference.py +108 -0
- requirements.txt +3 -0
- save_models/.keep +0 -0
- save_models/GRU_25.pth +3 -0
- save_models/GRU_50.pth +3 -0
- save_models/lstm_25.pth +3 -0
- save_models/lstm_50.pth +3 -0
- save_models/transformer_100.pth +3 -0
- scripts/lstm_infer.sh +0 -0
- scripts/lstm_train.sh +0 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-38.pyc +0 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/apis/__init__.py +0 -0
- src/apis/__pycache__/__init__.cpython-39.pyc +0 -0
- src/apis/__pycache__/inference.cpython-39.pyc +0 -0
- src/apis/__pycache__/train.cpython-39.pyc +0 -0
- src/apis/evaluate.py +23 -0
- src/apis/train.py +68 -0
- src/datasets/__init__.py +0 -0
- src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- src/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- src/datasets/__pycache__/dataloader.cpython-38.pyc +0 -0
- src/datasets/__pycache__/dataloader.cpython-39.pyc +0 -0
- src/datasets/dataloader.py +115 -0
- src/models/LSTM/__init__.py +0 -0
- src/models/LSTM/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/LSTM/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/LSTM/__pycache__/algorithm.cpython-39.pyc +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PyTorch
|
| 2 |
+
/.torch
|
| 3 |
+
|
| 4 |
+
# Data files
|
| 5 |
+
*.csv
|
| 6 |
+
*.json
|
| 7 |
+
*.tsv
|
| 8 |
+
|
| 9 |
+
# Model files
|
| 10 |
+
*.ckpt
|
| 11 |
+
*.pth
|
| 12 |
+
*.pkl
|
| 13 |
+
|
| 14 |
+
# Logs and checkpoints
|
| 15 |
+
logs/
|
| 16 |
+
checkpoints/
|
| 17 |
+
|
| 18 |
+
# Secondary files
|
| 19 |
+
*.pyc
|
| 20 |
+
__pycache__/
|
| 21 |
+
.DS_Store
|
| 22 |
+
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
.idea/.name
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
inference.py
|
.idea/ea_lstm.iml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="inheritedJdk" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<profile version="1.0">
|
| 3 |
+
<option name="myName" value="Project Default" />
|
| 4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
| 5 |
+
<option name="ignoredPackages">
|
| 6 |
+
<value>
|
| 7 |
+
<list size="7">
|
| 8 |
+
<item index="0" class="java.lang.String" itemvalue="easydict" />
|
| 9 |
+
<item index="1" class="java.lang.String" itemvalue="pandas" />
|
| 10 |
+
<item index="2" class="java.lang.String" itemvalue="matplotlib" />
|
| 11 |
+
<item index="3" class="java.lang.String" itemvalue="pillow" />
|
| 12 |
+
<item index="4" class="java.lang.String" itemvalue="mindspore" />
|
| 13 |
+
<item index="5" class="java.lang.String" itemvalue="setuptools" />
|
| 14 |
+
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
| 15 |
+
</list>
|
| 16 |
+
</value>
|
| 17 |
+
</option>
|
| 18 |
+
</inspection_tool>
|
| 19 |
+
</profile>
|
| 20 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (pytorch)" project-jdk-type="Python SDK" />
|
| 4 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/ea_lstm.iml" filepath="$PROJECT_DIR$/.idea/ea_lstm.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/nlp.iml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="inheritedJdk" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
<component name="PyDocumentationSettings">
|
| 9 |
+
<option name="format" value="PLAIN" />
|
| 10 |
+
<option name="myDocStringFormat" value="Plain" />
|
| 11 |
+
</component>
|
| 12 |
+
</module>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
.idea/workspace.xml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ChangeListManager">
|
| 4 |
+
<list default="true" id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
|
| 5 |
+
<option name="SHOW_DIALOG" value="false" />
|
| 6 |
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
| 7 |
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
| 8 |
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
| 9 |
+
</component>
|
| 10 |
+
<component name="MarkdownSettingsMigration">
|
| 11 |
+
<option name="stateVersion" value="1" />
|
| 12 |
+
</component>
|
| 13 |
+
<component name="ProjectId" id="2OyFWrJQpFYHFKgf87OgmRH5Jtu" />
|
| 14 |
+
<component name="ProjectViewState">
|
| 15 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
| 16 |
+
<option name="showLibraryContents" value="true" />
|
| 17 |
+
</component>
|
| 18 |
+
<component name="PropertiesComponent"><![CDATA[{
|
| 19 |
+
"keyToString": {
|
| 20 |
+
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
| 21 |
+
"RunOnceActivity.ShowReadmeOnStart": "true",
|
| 22 |
+
"last_opened_file_path": "C:/Users/LENOVO/PycharmProjects/lstm"
|
| 23 |
+
}
|
| 24 |
+
}]]></component>
|
| 25 |
+
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
| 26 |
+
<component name="TaskManager">
|
| 27 |
+
<task active="true" id="Default" summary="Default task">
|
| 28 |
+
<changelist id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
|
| 29 |
+
<created>1682524950142</created>
|
| 30 |
+
<option name="number" value="Default" />
|
| 31 |
+
<option name="presentableId" value="Default" />
|
| 32 |
+
<updated>1682524950142</updated>
|
| 33 |
+
</task>
|
| 34 |
+
<servers />
|
| 35 |
+
</component>
|
| 36 |
+
<component name="XDebuggerManager">
|
| 37 |
+
<watches-manager>
|
| 38 |
+
<configuration name="PythonConfigurationType">
|
| 39 |
+
<watch expression="input_eval" />
|
| 40 |
+
<watch expression="word_2_index" />
|
| 41 |
+
</configuration>
|
| 42 |
+
</watches-manager>
|
| 43 |
+
</component>
|
| 44 |
+
</project>
|
README.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
duplicated_from: zhangj726/poem_generation
|
| 3 |
+
---
|
| 4 |
+
# NLP Final Project
|
| 5 |
+
|
| 6 |
+
```shell
|
| 7 |
+
├── configs
|
| 8 |
+
├── data
|
| 9 |
+
│ └── poetry.txt
|
| 10 |
+
├── inference.py
|
| 11 |
+
├── src
|
| 12 |
+
│ ├── apis
|
| 13 |
+
│ │ ├── evaluate.py
|
| 14 |
+
│ │ ├── inference.py
|
| 15 |
+
│ │ └── train.py
|
| 16 |
+
│ ├── datasets
|
| 17 |
+
│ │ └── dataloader.py
|
| 18 |
+
│ ├── models
|
| 19 |
+
│ │ └── EA-LSTM
|
| 20 |
+
│ │ ├── algorithm.py
|
| 21 |
+
│ │ └── model.py
|
| 22 |
+
│ └── utils
|
| 23 |
+
│ └── utils.py
|
| 24 |
+
├── test.py
|
| 25 |
+
└── train.py
|
| 26 |
+
```
|
__pycache__/inference.cpython-38.pyc
ADDED
|
Binary file (2.88 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# !/user/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import gradio
|
| 5 |
+
from inference import infer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
INTERFACE = gradio.Interface(fn=infer, inputs=[gradio.Radio(["lstm","GRU"]),"text"], outputs=["text"], title="Poetry Generation",
|
| 11 |
+
description="Choose a model and input the poetic head to generate a acrostic",
|
| 12 |
+
thumbnail="https://github.com/gradio-app/gpt-2/raw/master/screenshots/interface.png?raw=true")
|
| 13 |
+
|
| 14 |
+
INTERFACE.launch(inbrowser=True)
|
data/org_poetry.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/poetry.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/poetry_7.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/split_poetry.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/word_vec.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1164cfc2e28ef6ecbb1a04734e7268238b4841667f13d6cb4c42e27717dd4575
|
| 3 |
+
size 6339344
|
example.jpg
ADDED
|
inference.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
import numpy as np
|
| 4 |
+
from src.models.LSTM.model import Poetry_Model_lstm
|
| 5 |
+
from src.datasets.dataloader import train_vec
|
| 6 |
+
from src.utils.utils import make_cuda
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_arguments():
|
| 10 |
+
# argument parsing
|
| 11 |
+
parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
|
| 12 |
+
parser.add_argument('--model', type=str, default='lstm',
|
| 13 |
+
help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
|
| 14 |
+
parser.add_argument('--Word2Vec', default=True)
|
| 15 |
+
parser.add_argument('--strict_dataset', default=False, help="strict dataset")
|
| 16 |
+
parser.add_argument('--n_hidden', type=int, default=128)
|
| 17 |
+
|
| 18 |
+
parser.add_argument('--save_path', type=str, default='save_models/lstm_50.pth')
|
| 19 |
+
|
| 20 |
+
return parser.parse_args()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def generate_poetry(model, head_string, w1, word_2_index, index_2_word):
|
| 24 |
+
print("藏头诗生成中...., {}".format(head_string))
|
| 25 |
+
poem = ""
|
| 26 |
+
# 以句子的每一个字为开头生成诗句
|
| 27 |
+
for head in head_string:
|
| 28 |
+
if head not in word_2_index:
|
| 29 |
+
print("抱歉,不能生成以{}开头的诗".format(head))
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
sentence = head
|
| 33 |
+
max_sent_len = 20
|
| 34 |
+
|
| 35 |
+
h_0 = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32))
|
| 36 |
+
c_0 = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32))
|
| 37 |
+
|
| 38 |
+
input_eval = word_2_index[head]
|
| 39 |
+
for i in range(max_sent_len):
|
| 40 |
+
if args.Word2Vec:
|
| 41 |
+
word_embedding = torch.tensor(w1[input_eval][None][None])
|
| 42 |
+
else:
|
| 43 |
+
word_embedding = torch.tensor([input_eval]).unsqueeze(dim=0)
|
| 44 |
+
pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
|
| 45 |
+
char_generated = index_2_word[int(torch.argmax(pre))]
|
| 46 |
+
|
| 47 |
+
if char_generated == '。':
|
| 48 |
+
break
|
| 49 |
+
# 以新生成的字为输入继续向下生成
|
| 50 |
+
input_eval = word_2_index[char_generated]
|
| 51 |
+
sentence += char_generated
|
| 52 |
+
|
| 53 |
+
poem += '\n' + sentence
|
| 54 |
+
|
| 55 |
+
return poem
|
| 56 |
+
|
| 57 |
+
def infer(model,string):
|
| 58 |
+
args = parse_arguments()
|
| 59 |
+
all_data, (w1, word_2_index, index_2_word) = train_vec()
|
| 60 |
+
args.word_size, args.embedding_num = w1.shape
|
| 61 |
+
# string = input("诗头:")
|
| 62 |
+
# string = '自然语言'
|
| 63 |
+
args.model=model
|
| 64 |
+
if args.model == 'lstm':
|
| 65 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 66 |
+
args.save_path = 'save_models/lstm_50.pth'
|
| 67 |
+
elif args.model == 'GRU':
|
| 68 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 69 |
+
args.save_path = 'save_models/GRU_50.pth'
|
| 70 |
+
elif args.model == 'Seq2Seq':
|
| 71 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 72 |
+
elif args.model == 'Transformer':
|
| 73 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 74 |
+
elif args.model == 'GPT-2':
|
| 75 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 76 |
+
else:
|
| 77 |
+
print("Please choose a model!\n")
|
| 78 |
+
|
| 79 |
+
model.load_state_dict(torch.load(args.save_path))
|
| 80 |
+
model = make_cuda(model)
|
| 81 |
+
poem = generate_poetry(model, string, w1, word_2_index, index_2_word)
|
| 82 |
+
return poem
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == '__main__':
|
| 86 |
+
args = parse_arguments()
|
| 87 |
+
all_data, (w1, word_2_index, index_2_word) = train_vec()
|
| 88 |
+
args.word_size, args.embedding_num = w1.shape
|
| 89 |
+
# string = input("诗头:")
|
| 90 |
+
string = '自然语言'
|
| 91 |
+
|
| 92 |
+
if args.model == 'lstm':
|
| 93 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 94 |
+
elif args.model == 'GRU':
|
| 95 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 96 |
+
elif args.model == 'Seq2Seq':
|
| 97 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 98 |
+
elif args.model == 'Transformer':
|
| 99 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 100 |
+
elif args.model == 'GPT-2':
|
| 101 |
+
model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
|
| 102 |
+
else:
|
| 103 |
+
print("Please choose a model!\n")
|
| 104 |
+
|
| 105 |
+
model.load_state_dict(torch.load(args.save_path))
|
| 106 |
+
model = make_cuda(model)
|
| 107 |
+
poem = generate_poetry(model, string, w1, word_2_index, index_2_word)
|
| 108 |
+
print(poem)
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==1.13.0
|
| 2 |
+
gradio==3.34.0
|
| 3 |
+
gensim==4.3.1
|
save_models/.keep
ADDED
|
File without changes
|
save_models/GRU_25.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bacf9a7ec329c6185098c1309ab28239b4c087b53832b3d18e5323831bfead23
|
| 3 |
+
size 10727391
|
save_models/GRU_50.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a8e83a733c023b35c44020e014bb72e2c1d05698eb782669c0e4d5a76d4590d
|
| 3 |
+
size 10727391
|
save_models/lstm_25.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b064666ce02c63541dee4b6146d31ee8f7e784ee9c2811c9b9266aba6cc4193
|
| 3 |
+
size 10727391
|
save_models/lstm_50.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa157d970149c32b53b024a23ef8428e7b7e1702ed72d44152b568b085b1bfaa
|
| 3 |
+
size 10727391
|
save_models/transformer_100.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0bbe153237c20ba6ec5e8ac8b55c8b420ec4cdf5bf0f46a8a5b68094a54996c3
|
| 3 |
+
size 26125257
|
scripts/lstm_infer.sh
ADDED
|
File without changes
|
scripts/lstm_train.sh
ADDED
|
File without changes
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (166 Bytes). View file
|
|
|
src/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
src/apis/__init__.py
ADDED
|
File without changes
|
src/apis/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (151 Bytes). View file
|
|
|
src/apis/__pycache__/inference.cpython-39.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
src/apis/__pycache__/train.cpython-39.pyc
ADDED
|
Binary file (1.68 kB). View file
|
|
|
src/apis/evaluate.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from src.models.EA_LSTM.model import weightedLSTM
|
| 4 |
+
from src.datasets.dataloader import MyDataset, create_vocab
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test(args):
|
| 8 |
+
vocab, poetrys = create_vocab(args.data)
|
| 9 |
+
# 词汇表长度
|
| 10 |
+
args.vocab_size = len(vocab)
|
| 11 |
+
int2char = np.array(vocab)
|
| 12 |
+
valid_dataset = MyDataset(vocab, poetrys, args, train=False)
|
| 13 |
+
|
| 14 |
+
model = weightedLSTM(6110, 256, 128, 2, [1.0] * 80, False)
|
| 15 |
+
model.load_state_dict(torch.load(args.save_path))
|
| 16 |
+
|
| 17 |
+
input_example_batch, target_example_batch = valid_dataset[0]
|
| 18 |
+
example_batch_predictions = model(input_example_batch)
|
| 19 |
+
predicted_id = torch.distributions.Categorical(example_batch_predictions).sample()
|
| 20 |
+
predicted_id = torch.squeeze(predicted_id, -1).numpy()
|
| 21 |
+
print("Input: \n", repr("".join(int2char[input_example_batch])))
|
| 22 |
+
print()
|
| 23 |
+
print("Predictions: \n", repr("".join(int2char[predicted_id])))
|
src/apis/train.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from src.utils.utils import make_cuda
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def train(args, model, data_loader, initial=False):
|
| 12 |
+
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
|
| 13 |
+
|
| 14 |
+
model.train()
|
| 15 |
+
num_epochs = args.initial_epochs if initial else args.num_epochs
|
| 16 |
+
|
| 17 |
+
for epoch in range(num_epochs):
|
| 18 |
+
loss = 0
|
| 19 |
+
for step, (features, targets) in enumerate(data_loader):
|
| 20 |
+
features = make_cuda(features)
|
| 21 |
+
targets = make_cuda(targets)
|
| 22 |
+
|
| 23 |
+
optimizer.zero_grad()
|
| 24 |
+
|
| 25 |
+
pre, _ = model(features)
|
| 26 |
+
crs_loss = model.cross_entropy(pre, targets.reshape(-1))
|
| 27 |
+
loss += crs_loss.item()
|
| 28 |
+
crs_loss.backward()
|
| 29 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
| 30 |
+
optimizer.step()
|
| 31 |
+
|
| 32 |
+
# print step info
|
| 33 |
+
if (step + 1) % args.log_step == 0:
|
| 34 |
+
print("Epoch [%.3d/%.3d] Step [%.3d/%.3d]: CROSS_loss=%.4f, RCROSS_loss=%.4f"
|
| 35 |
+
% (epoch + 1,
|
| 36 |
+
num_epochs,
|
| 37 |
+
step + 1,
|
| 38 |
+
len(data_loader),
|
| 39 |
+
loss / args.log_step,
|
| 40 |
+
math.sqrt(loss / args.log_step)))
|
| 41 |
+
loss = 0
|
| 42 |
+
|
| 43 |
+
# Loss = []
|
| 44 |
+
# for step, (features, targets) in enumerate(valid_data_loader):
|
| 45 |
+
# features = make_cuda(features)
|
| 46 |
+
# targets = make_cuda(targets)
|
| 47 |
+
# model.eval()
|
| 48 |
+
# preds = model(features)
|
| 49 |
+
# valid_loss = CrossLoss(preds, targets)
|
| 50 |
+
# Loss.append(valid_loss)
|
| 51 |
+
# print("Valid loss: %.3d\n" % (np.mean(Loss)))
|
| 52 |
+
|
| 53 |
+
return model
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def evaluate(args, model, data_loader):
|
| 57 |
+
model.eval()
|
| 58 |
+
loss = []
|
| 59 |
+
for step, (features, targets) in enumerate(data_loader):
|
| 60 |
+
features = make_cuda(features)
|
| 61 |
+
targets = make_cuda(targets)
|
| 62 |
+
|
| 63 |
+
pre, _ = model(features)
|
| 64 |
+
crs_loss = model.cross_entropy(pre, targets.reshape(-1))
|
| 65 |
+
loss.append(crs_loss.item())
|
| 66 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
| 67 |
+
|
| 68 |
+
print("loss=%.4f" % (np.mean(loss)))
|
src/datasets/__init__.py
ADDED
|
File without changes
|
src/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (175 Bytes). View file
|
|
|
src/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (155 Bytes). View file
|
|
|
src/datasets/__pycache__/dataloader.cpython-38.pyc
ADDED
|
Binary file (4.09 kB). View file
|
|
|
src/datasets/__pycache__/dataloader.cpython-39.pyc
ADDED
|
Binary file (4.12 kB). View file
|
|
|
src/datasets/dataloader.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pickle
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from gensim.models.word2vec import Word2Vec
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def padding(poetries, maxlen, pad):
|
| 11 |
+
batch_seq = [poetry + pad * (maxlen - len(poetry)) for poetry in poetries]
|
| 12 |
+
return batch_seq
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# 输入向后滑一字符为target,即预测下一个字
|
| 16 |
+
def split_input_target(seq):
|
| 17 |
+
inputs = seq[:-1]
|
| 18 |
+
targets = seq[1:]
|
| 19 |
+
return inputs, targets
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# 创建词汇表
|
| 23 |
+
def get_poetry(arg):
|
| 24 |
+
poetrys = []
|
| 25 |
+
if arg.Augmented_dataset:
|
| 26 |
+
path = arg.Augmented_data
|
| 27 |
+
else:
|
| 28 |
+
path = arg.data
|
| 29 |
+
with open(path, "r", encoding='UTF-8') as f:
|
| 30 |
+
for line in f:
|
| 31 |
+
try:
|
| 32 |
+
# line = line.decode('UTF-8')
|
| 33 |
+
line = line.strip(u'\n')
|
| 34 |
+
if arg.Augmented_dataset:
|
| 35 |
+
content = line.strip(u' ')
|
| 36 |
+
else:
|
| 37 |
+
title, content = line.strip(u' ').split(u':')
|
| 38 |
+
content = content.replace(u' ', u'')
|
| 39 |
+
if u'_' in content or u'(' in content or u'(' in content or u'《' in content or u'[' in content:
|
| 40 |
+
continue
|
| 41 |
+
if arg.strict_dataset:
|
| 42 |
+
if len(content) < 12 or len(content) > 79:
|
| 43 |
+
continue
|
| 44 |
+
else:
|
| 45 |
+
if len(content) < 5 or len(content) > 79:
|
| 46 |
+
continue
|
| 47 |
+
content = u'[' + content + u']'
|
| 48 |
+
poetrys.append(content)
|
| 49 |
+
except Exception as e:
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
# 按诗的字数排序
|
| 53 |
+
poetrys = sorted(poetrys, key=lambda line: len(line))
|
| 54 |
+
|
| 55 |
+
with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
|
| 56 |
+
for poetry in poetrys:
|
| 57 |
+
poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
|
| 58 |
+
f.write(poetry)
|
| 59 |
+
|
| 60 |
+
return poetrys
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# 切分文档
|
| 64 |
+
def split_text(poetrys):
|
| 65 |
+
with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
|
| 66 |
+
for poetry in poetrys:
|
| 67 |
+
poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
|
| 68 |
+
split_data = " ".join(poetry)
|
| 69 |
+
f.write(split_data)
|
| 70 |
+
return open("data/split_poetry.txt", "r", encoding='UTF-8').read()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# 训练词向量
|
| 74 |
+
def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
|
| 75 |
+
param_file = "data/word_vec.pkl"
|
| 76 |
+
org_data = open(org_file, "r", encoding="utf-8").read().split("\n")
|
| 77 |
+
if os.path.exists(split_file):
|
| 78 |
+
all_data_split = open(split_file, "r", encoding="utf-8").read().split("\n")
|
| 79 |
+
else:
|
| 80 |
+
all_data_split = split_text().split("\n")
|
| 81 |
+
|
| 82 |
+
if os.path.exists(param_file):
|
| 83 |
+
return org_data, pickle.load(open(param_file, "rb"))
|
| 84 |
+
|
| 85 |
+
models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
|
| 86 |
+
pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
|
| 87 |
+
return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Poetry_Dataset(Dataset):
|
| 91 |
+
def __init__(self, w1, word_2_index, all_data, Word2Vec):
|
| 92 |
+
self.Word2Vec = Word2Vec
|
| 93 |
+
self.w1 = w1
|
| 94 |
+
self.word_2_index = word_2_index
|
| 95 |
+
word_size, embedding_num = w1.shape
|
| 96 |
+
self.embedding = nn.Embedding(word_size, embedding_num)
|
| 97 |
+
# 最长句子长度
|
| 98 |
+
maxlen = max([len(seq) for seq in all_data])
|
| 99 |
+
pad = ' '
|
| 100 |
+
self.all_data = padding(all_data[:-1], maxlen, pad)
|
| 101 |
+
|
| 102 |
+
def __getitem__(self, index):
|
| 103 |
+
a_poetry = self.all_data[index]
|
| 104 |
+
|
| 105 |
+
a_poetry_index = [self.word_2_index[i] for i in a_poetry]
|
| 106 |
+
xs, ys = split_input_target(a_poetry_index)
|
| 107 |
+
if self.Word2Vec:
|
| 108 |
+
xs_embedding = self.w1[xs]
|
| 109 |
+
else:
|
| 110 |
+
xs_embedding = np.array(xs)
|
| 111 |
+
|
| 112 |
+
return xs_embedding, np.array(ys).astype(np.int64)
|
| 113 |
+
|
| 114 |
+
def __len__(self):
|
| 115 |
+
return len(self.all_data)
|
src/models/LSTM/__init__.py
ADDED
|
File without changes
|
src/models/LSTM/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (178 Bytes). View file
|
|
|
src/models/LSTM/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (161 Bytes). View file
|
|
|
src/models/LSTM/__pycache__/algorithm.cpython-39.pyc
ADDED
|
Binary file (4.99 kB). View file
|
|
|