Spaces:
Sleeping
Sleeping
| # set path | |
| import glob, os, sys; | |
| sys.path.append('../utils') | |
| #import needed libraries | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from st_aggrid import AgGrid | |
| from st_aggrid.shared import ColumnsAutoSizeMode | |
| from utils.vulnerability_classifier import vulnerability_classification | |
| from utils.vulnerability_classifier import runPreprocessingPipeline, load_Classifier | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| from utils.checkconfig import getconfig | |
| # Declare all the necessary variables | |
| config = getconfig('paramconfig.cfg') | |
| model_name = config.get('vulnerability','MODEL') | |
| split_by = config.get('vulnerability','SPLIT_BY') | |
| split_length = int(config.get('vulnerability','SPLIT_LENGTH')) | |
| split_overlap = int(config.get('vulnerability','SPLIT_OVERLAP')) | |
| remove_punc = bool(int(config.get('vulnerability','REMOVE_PUNC'))) | |
| split_respect_sentence_boundary = bool(int(config.get('vulnerability','RESPECT_SENTENCE_BOUNDARY'))) | |
| threshold = float(config.get('vulnerability','THRESHOLD')) | |
| top_n = int(config.get('vulnerability','TOP_KEY')) | |
| def app(): | |
| #### APP INFO ##### | |
| with st.container(): | |
| st.markdown("<h1 style='text-align: center; color: black;'> Vulnerability Classification </h1>", unsafe_allow_html=True) | |
| st.write(' ') | |
| st.write(' ') | |
| with st.expander("ℹ️ - About this app", expanded=False): | |
| st.write( | |
| """ | |
| The *Vulnerability Indicator* app is an easy-to-use interface built \ | |
| in Streamlit for analyzing policy documents with respect to SDG \ | |
| Classification for the paragraphs/texts in the document and \ | |
| extracting the keyphrase per SDG label - developed by GIZ Data \ | |
| and the Sustainable Development Solution Network. \n | |
| """) | |
| st.write("""**Document Processing:** The Uploaded/Selected document is \ | |
| automatically cleaned and split into paragraphs with a maximum \ | |
| length of 120 words using a Haystack preprocessing pipeline. The \ | |
| length of 120 is an empirical value which should reflect the length \ | |
| of a “context” and should limit the paragraph length deviation. \ | |
| However, since we want to respect the sentence boundary the limit \ | |
| can breach and hence this limit of 120 is tentative. \n | |
| """) | |
| st.write("""**Vulnerability cLassification:** The application assigns paragraphs \ | |
| to 18 different vulnerable groups in the climate context.\ | |
| Each paragraph is assigned to one vulnerable group only. Again, the results are \ | |
| displayed in a summary table including the vulnerability label, a \ | |
| relevancy score highlighted through a green color shading, and the \ | |
| respective text of the analyzed paragraph. Additionally, a pie \ | |
| chart with a blue color shading is displayed which illustrates the \ | |
| three most prominent groups mentioned in the document. Training data has been \ | |
| collected manually from different policy documents and been assigned to the groups. \ | |
| The summary table only displays \ | |
| paragraphs with a calculated relevancy score above 85%. \n""") | |
| st.write("") | |
| st.write("") | |
| st.markdown("Some runtime metrics tested with cpu: Intel(R) Xeon(R) CPU @ 2.20GHz, memory: 13GB") | |
| col1,col2,col3,col4 = st.columns([2,2,4,4]) | |
| with col1: | |
| st.caption("Loading Time Classifier") | |
| # st.markdown('<div style="text-align: center;">12 sec</div>', unsafe_allow_html=True) | |
| st.write("12 sec") | |
| with col2: | |
| st.caption("OCR File processing") | |
| # st.markdown('<div style="text-align: center;">50 sec</div>', unsafe_allow_html=True) | |
| st.write("50 sec") | |
| with col3: | |
| st.caption("SDG Classification of 200 paragraphs(~ 35 pages)") | |
| # st.markdown('<div style="text-align: center;">120 sec</div>', unsafe_allow_html=True) | |
| st.write("120 sec") | |
| with col4: | |
| st.caption("Keyword extraction for 200 paragraphs(~ 35 pages)") | |
| # st.markdown('<div style="text-align: center;">3 sec</div>', unsafe_allow_html=True) | |
| st.write("3 sec") | |
| ### Main app code ### | |
| with st.container(): | |
| if st.button("RUN Vulnerability Analysis"): | |
| if 'filepath' in st.session_state: | |
| file_name = st.session_state['filename'] | |
| file_path = st.session_state['filepath'] | |
| st.write(file_name) | |
| st.write(file_path) | |
| classifier = load_Classifier(classifier_name=model_name) | |
| st.session_state['vulnerability_classifier'] = classifier | |
| all_documents = runPreprocessingPipeline(file_name= file_name, | |
| file_path= file_path, split_by= split_by, | |
| split_length= split_length, | |
| split_respect_sentence_boundary= split_respect_sentence_boundary, | |
| split_overlap= split_overlap, remove_punc= remove_punc) | |
| if len(all_documents['documents']) > 100: | |
| warning_msg = ": This might take sometime, please sit back and relax." | |
| else: | |
| warning_msg = "" | |
| with st.spinner("Running Classification{}".format(warning_msg)): | |
| df, x = vulnerability_classification(haystack_doc=all_documents['documents'], | |
| threshold= threshold) | |
| df = df.drop(['Relevancy'], axis = 1) | |
| vulnerability_labels = x.vulnerability.unique() | |
| textrank_keyword_list = [] | |
| for label in sdg_labels: | |
| vulnerability_data = " ".join(df[df.vulnerability == label].text.to_list()) | |
| textranklist_ = textrank(textdata=sdgdata, words= top_n) | |
| if len(textranklist_) > 0: | |
| textrank_keyword_list.append({'Vulnerability':label, 'TextRank Keywords':",".join(textranklist_)}) | |
| textrank_keywords_df = pd.DataFrame(textrank_keyword_list) | |
| plt.rcParams['font.size'] = 25 | |
| colors = plt.get_cmap('Blues')(np.linspace(0.2, 0.7, len(x))) | |
| # plot | |
| fig, ax = plt.subplots() | |
| ax.pie(x['count'], colors=colors, radius=2, center=(4, 4), | |
| wedgeprops={"linewidth": 1, "edgecolor": "white"}, | |
| textprops={'fontsize': 14}, | |
| frame=False,labels =list(x.SDG_Num), | |
| labeldistance=1.2) | |
| # fig.savefig('temp.png', bbox_inches='tight',dpi= 100) | |
| st.markdown("#### Anything related to Vulnerabilities? ####") | |
| c4, c5, c6 = st.columns([1,2,2]) | |
| with c5: | |
| st.pyplot(fig) | |
| with c6: | |
| labeldf = x['SDG_name'].values.tolist() | |
| labeldf = "<br>".join(labeldf) | |
| st.markdown(labeldf, unsafe_allow_html=True) | |
| st.write("") | |
| st.markdown("###### What keywords are present under vulnerability classified text? ######") | |
| AgGrid(textrank_keywords_df, reload_data = False, | |
| update_mode="value_changed", | |
| columns_auto_size_mode = ColumnsAutoSizeMode.FIT_CONTENTS) | |
| st.write("") | |
| st.markdown("###### Top few vulnerability Classified paragraph/text results ######") | |
| AgGrid(df, reload_data = False, update_mode="value_changed", | |
| columns_auto_size_mode = ColumnsAutoSizeMode.FIT_CONTENTS) | |
| else: | |
| st.info("🤔 No document found, please try to upload it at the sidebar!") | |
| logging.warning("Terminated as no document provided") | |