Spaces:
Runtime error
Runtime error
| # https://github.com/langchain-ai/langchain/issues/8623 | |
| from langchain.schema.retriever import BaseRetriever, Document | |
| from langchain.vectorstores import VectorStore | |
| from langchain.vectorstores import Chroma | |
| from typing import List | |
# Values of the ``report_type`` metadata field that mark a document as a
# summary. Summaries condense a report, so they tend to be easier to exploit
# than full text; QARetriever can reserve part of its result budget for them.
# When this list is empty, no ``report_type`` filtering is applied.
SUMMARY_TYPES: list = list()
class QARetriever(BaseRetriever):
    """Retrieve QA context from a vector store, optionally favouring summaries.

    When ``k_summary > 0`` and ``SUMMARY_TYPES`` is non-empty, up to
    ``k_summary`` hits are drawn from summary documents (``report_type`` in
    ``SUMMARY_TYPES``). The remaining budget, ``k_total`` minus the summary
    hits actually kept, is filled from non-summary documents. Every hit whose
    score is not strictly above ``threshold`` is dropped.
    """

    # Vector store that is queried for candidates.
    vectorstore: VectorStore
    # Restrict hits to documents whose "domain" metadata is in this list;
    # an empty list applies no domain filter.
    domains: list = []
    # Hits are kept only when their score is strictly greater than this.
    # NOTE(review): assumes a higher score means more relevant, which holds
    # for similarity scores but not for distance-based scores; confirm
    # against the configured vector store.
    threshold: float = 22
    # Result slots reserved for summary documents (0 disables them).
    k_summary: int = 0
    # Total number of results requested across both searches.
    k_total: int = 10
    # Forwarded as ``namespace`` to every search call.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return documents relevant to ``query``.

        Args:
            query: Free-text query passed to the vector store search.

        Returns:
            Summary hits followed by full-report hits. For each document,
            ``similarity_score`` and ``content`` (the original text) are
            added to its metadata, and ``page_number`` is converted to
            ``int``. ``page_content`` is rewritten as
            ``"Doc <n> - <short_name>: <text>"``, where ``n`` starts at 1.

        Raises:
            AssertionError: if ``domains`` is not a list, or ``k_total`` is
                not greater than ``k_summary``.
            KeyError: if a hit lacks ``page_number`` or ``short_name``
                metadata.
        """
        assert isinstance(self.domains, list)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        base_filters = {}
        if self.domains:
            base_filters["domain"] = {"$in": self.domains}

        docs_summaries = []
        # Without SUMMARY_TYPES the summary and full searches would use the
        # same filters, so the top hits would be returned twice. In that case
        # the whole budget goes to the single full search below.
        if self.k_summary > 0 and SUMMARY_TYPES:
            summary_filters = {**base_filters, "report_type": {"$in": SUMMARY_TYPES}}
            docs_summaries = self._scored_search(query, summary_filters, self.k_summary)

        full_filters = dict(base_filters)
        if SUMMARY_TYPES:
            full_filters["report_type"] = {"$nin": SUMMARY_TYPES}
        # Summary slots whose hits failed the threshold go back to full reports.
        k_full = self.k_total - len(docs_summaries)
        docs_full = self._scored_search(query, full_filters, k_full)

        results = []
        for rank, (doc, score) in enumerate(docs_summaries + docs_full, start=1):
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = f"Doc {rank} - {doc.metadata['short_name']}: {doc.page_content}"
            results.append(doc)
        return results

    def _scored_search(self, query: str, filters: dict, k: int) -> list:
        """Run one scored search; keep (doc, score) pairs scoring above ``threshold``."""
        hits = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=self.format_filter(filters),
            k=k,
        )
        return [hit for hit in hits if hit[1] > self.threshold]

    def format_filter(self, filters):
        """Adapt a flat ``{field: condition}`` dict to the store's filter syntax.

        Chroma needs several conditions combined explicitly with ``$and``
        (https://docs.trychroma.com/usage-guide#using-logical-operators).
        For a Chroma store, a dict with more than one field therefore becomes
        ``{"$and": [{field: condition}, ...]}``. Any other input is returned
        unchanged.
        """
        if isinstance(self.vectorstore, Chroma) and len(filters) > 1:
            return {"$and": [{field: condition} for field, condition in filters.items()]}
        return filters