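# Streamlit demo: ensemble (dense FAISS + sparse BM25) document retrieval,
# with and without contextual compression, plus side-by-side result analysis.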
import streamlit as st
import pickle
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain.text_splitter import CharacterTextSplitter
from analysis import (
    calculate_word_overlaps,
    calculate_duplication_rate,
    cosine_similarity_score,
    jaccard_similarity_score,
    display_similarity_results,
)
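# Load the pickled corpus. Assumption: docs_data.pkl holds a list of raw text
# strings, since BM25Retriever.from_texts() below expects plain strings.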
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)

metadata_list = []
unique_metadata_list = []
seen = set()
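# Dense retriever: embed with the library's default HuggingFace sentence-transformer
# model and search a FAISS index that was built and saved beforehand.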
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity")
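# Compression pipeline applied to retrieved documents, in order:
#   1. re-split each document into ~300-character sentence chunks,
#   2. drop chunks that are near-duplicates of each other (by embedding similarity),
#   3. keep only chunks whose embedding similarity to the query is >= 0.5.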
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)

bm25_retriever = BM25Retriever.from_texts(docs)
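# Streamlit UI: the two k values control how many documents the dense (FAISS)
# and sparse (BM25) retrievers each return.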
st.title("Document Retrieval App")

vectorstore_k = st.number_input("Set k value for dense retriever:", value=5, min_value=1, step=1)
bm25_k = st.number_input("Set k value for sparse retriever:", value=2, min_value=1, step=1)
retriever.search_kwargs["k"] = vectorstore_k
bm25_retriever.k = bm25_k
compressed_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)
bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever)

query = st.text_input("Enter a query:", "what is a horizontal conflict")
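# On click, build two equally weighted dense + sparse ensembles, one over the raw
# retrievers and one over their compressed counterparts, and show both result sets
# side by side.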
if st.button("Retrieve Documents"):
    compressed_ensemble_retriever = EnsembleRetriever(retrievers=[compressed_retriever, bm25_compression_retriever], weights=[0.5, 0.5])
    ensemble_retriever = EnsembleRetriever(retrievers=[retriever, bm25_retriever], weights=[0.5, 0.5])

    with st.expander("Retrieved Documents"):
        col1, col2 = st.columns(2)
        with col1:
            st.header("Without Compression")
            normal_results = ensemble_retriever.get_relevant_documents(query)
            for doc in normal_results:
                st.write(doc.page_content)
                st.write("---")
        with col2:
            st.header("With Compression")
            compressed_results = compressed_ensemble_retriever.get_relevant_documents(query)
            for doc in compressed_results:
                st.write(doc.page_content)
                st.write("---")
                if hasattr(doc, "metadata"):
                    metadata_list.append(doc.metadata)
            # Deduplicate metadata dicts (dicts are unhashable, so compare as item tuples).
            for metadata in metadata_list:
                metadata_tuple = tuple(metadata.items())
                if metadata_tuple not in seen:
                    unique_metadata_list.append(metadata)
                    seen.add(metadata_tuple)
            st.write(unique_metadata_list)
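    # Compare the two result sets: total length, query-word overlap, duplication,
    # and cosine/Jaccard similarity (helpers imported from the local analysis module).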
| with st.expander("Analysis"): | |
| st.write("Analysis of Retrieval Results") | |
| total_words_normal = sum(len(doc.page_content.split()) for doc in normal_results) | |
| total_words_compressed = sum(len(doc.page_content.split()) for doc in compressed_results) | |
| reduction_percentage = ((total_words_normal - total_words_compressed) / total_words_normal) * 100 | |
| col1, col2 = st.columns(2) | |
| st.write(f"Total words in documents (Normal): {total_words_normal}") | |
| st.write(f"Total words in documents (Compressed): {total_words_compressed}") | |
| st.write(f"Reduction Percentage: {reduction_percentage:.2f}%") | |
        normal_texts = [doc.page_content for doc in normal_results]
        compressed_texts = [doc.page_content for doc in compressed_results]
        average_word_overlap_normal = calculate_word_overlaps(normal_texts, query)
        average_word_overlap_compressed = calculate_word_overlaps(compressed_texts, query)
        duplication_rate_normal = calculate_duplication_rate(normal_texts)
        duplication_rate_compressed = calculate_duplication_rate(compressed_texts)
        cosine_scores_normal = cosine_similarity_score(normal_texts, query)
        jaccard_scores_normal = jaccard_similarity_score(normal_texts, query)
        cosine_scores_compressed = cosine_similarity_score(compressed_texts, query)
        jaccard_scores_compressed = jaccard_similarity_score(compressed_texts, query)
        with col1:
            st.subheader("Normal")
            st.write(f"Average Word Overlap: {average_word_overlap_normal:.2f}")
            st.write(f"Duplication Rate: {duplication_rate_normal:.2%}")
            st.write("Results without Compression:")
            display_similarity_results(cosine_scores_normal, jaccard_scores_normal, "")
        with col2:
            st.subheader("Compressed")
            st.write(f"Average Word Overlap: {average_word_overlap_compressed:.2f}")
            st.write(f"Duplication Rate: {duplication_rate_compressed:.2%}")
            st.write("Results with Compression:")
            display_similarity_results(cosine_scores_compressed, jaccard_scores_compressed, "")
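
# --- Sketch of the local `analysis` module (would live in analysis.py; its source
# is not shown on this page). A minimal, hypothetical implementation matching the
# call sites above: the overlap, duplication, and similarity helpers take a list of
# document texts (plus the query where relevant), and display_similarity_results
# renders per-document scores. The real module may compute these differently. ---

import streamlit as st
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def calculate_word_overlaps(texts, query):
    """Average number of query words that also appear in each document."""
    query_words = set(query.lower().split())
    if not texts:
        return 0.0
    overlaps = [len(query_words & set(t.lower().split())) for t in texts]
    return sum(overlaps) / len(overlaps)


def calculate_duplication_rate(texts):
    """Fraction of word occurrences across all documents that are repeats."""
    words = [w for t in texts for w in t.lower().split()]
    return 1 - len(set(words)) / len(words) if words else 0.0


def cosine_similarity_score(texts, query):
    """TF-IDF cosine similarity of each document against the query."""
    if not texts:
        return []
    matrix = TfidfVectorizer().fit_transform([query] + texts)
    return cosine_similarity(matrix[0:1], matrix[1:])[0].tolist()


def jaccard_similarity_score(texts, query):
    """Jaccard similarity between each document's word set and the query's."""
    q = set(query.lower().split())
    scores = []
    for t in texts:
        s = set(t.lower().split())
        union = q | s
        scores.append(len(q & s) / len(union) if union else 0.0)
    return scores


def display_similarity_results(cosine_scores, jaccard_scores, label):
    """Write per-document cosine and Jaccard scores to the Streamlit page."""
    for i, (c, j) in enumerate(zip(cosine_scores, jaccard_scores), start=1):
        st.write(f"{label}Doc {i}: cosine = {c:.2f}, Jaccard = {j:.2f}")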