NLPGenius committed
Commit a7270f3 · 1 Parent(s): a97117b

Hybrid retrieval: semantic + BM25-style keyword fusion, lazy index, dedupe, robust fallbacks

cve_factchecker/retriever.py CHANGED
@@ -1,6 +1,9 @@
 from __future__ import annotations
 import os
-from typing import List, Dict, Any
+import math
+import re
+import time
+from typing import List, Dict, Any, Tuple, Optional
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.schema import Document
 try:
@@ -44,6 +47,18 @@ class VectorNewsRetriever:
         self.persist_directory = env_dir or persist_directory
         self.embeddings = build_embeddings()
         self.vector_store = self._initialize_vector_store()
+        # Lightweight in-memory keyword index (lazy-built)
+        self._index_ready: bool = False
+        self._index_built_at: float = 0.0
+        self._N: int = 0 # number of docs
+        self._avgdl: float = 0.0
+        self._df: Dict[str, int] = {}
+        self._postings: Dict[str, Dict[str, int]] = {} # term -> {doc_id: tf}
+        self._doc_len: Dict[str, int] = {}
+        self._doc_meta: Dict[str, Dict[str, Any]] = {} # id -> {content, metadata}
+        self._stopwords = set(
+            "the a an and or of to in on for from by with without at as is are was were be been being this that those these it its their his her you your we our not no over under into about across more most least few many much may might should would could will can https http www com pk net org www.".split()
+        )
     def _initialize_vector_store(self) -> Chroma:
         """Initialize vector store with proper error handling for permission issues."""
         # If no persist directory (failed all write tests), use in-memory
@@ -97,6 +112,10 @@ class VectorNewsRetriever:
             print(f"⚠️ Could not clear vector store: {e}")
             # Fallback: create new in-memory store
             self.vector_store = Chroma(embedding_function=self.embeddings, collection_name="news_articles_fresh")
+            # Invalidate keyword index after clear
+            self._index_ready = False
+            self._df.clear(); self._postings.clear(); self._doc_len.clear(); self._doc_meta.clear()
+            self._N = 0; self._avgdl = 0.0
 
     def store_articles_in_vector_db(self, articles: List[NewsArticle], clear_first: bool = False) -> None:
         if not articles:
@@ -204,14 +223,216 @@ class VectorNewsRetriever:
         except Exception as e:
             print(f"⚠️ Could not persist vector store: {e}")
         print(f"✅ Stored {len(docs)} chunks from {len(articles)} articles")
-    def semantic_search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
+        # Invalidate index so it is rebuilt on next query
+        self._index_ready = False
+        self._df.clear(); self._postings.clear(); self._doc_len.clear(); self._doc_meta.clear()
+        self._N = 0; self._avgdl = 0.0
+    # -----------------------------
+    # Hybrid Retrieval Implementation
+    # -----------------------------
+    def _tokenize(self, text: str) -> List[str]:
+        text = text.lower()
+        # Keep alphanumerics as tokens
+        tokens = re.split(r"[^a-z0-9]+", text)
+        return [t for t in tokens if t and t not in self._stopwords and not t.isdigit()]
+
+    def _ensure_index(self) -> None:
+        if self._index_ready:
+            return
         try:
-            # Guardrails on k to avoid heavy loads
-            k = max(1, min(int(k or 5), 10))
-            docs = self.vector_store.similarity_search(query, k=k)
+            # Prefer direct collection access for efficiency
+            docs_data: Optional[Dict[str, Any]] = None
+            if hasattr(self.vector_store, "_collection") and self.vector_store._collection is not None: # type: ignore[attr-defined]
+                try:
+                    docs_data = self.vector_store._collection.get(include=["ids", "documents", "metadatas"]) # type: ignore[attr-defined]
+                except Exception as e:
+                    print(f"⚠️ Could not read collection directly: {e}")
+            if docs_data is None:
+                try:
+                    docs_data = self.vector_store.get()
+                except Exception as e:
+                    print(f"⚠️ Could not fetch documents for index: {e}")
+                    self._index_ready = False
+                    return
+            ids = docs_data.get("ids", []) or []
+            documents = docs_data.get("documents", []) or []
+            metadatas = docs_data.get("metadatas", []) or []
+            N = len(ids)
+            if N == 0:
+                self._index_ready = True
+                self._N = 0
+                self._avgdl = 0.0
+                return
+            df: Dict[str, int] = {}
+            postings: Dict[str, Dict[str, int]] = {}
+            doc_len: Dict[str, int] = {}
+            doc_meta: Dict[str, Dict[str, Any]] = {}
+            total_len = 0
+            for doc_id, content, meta in zip(ids, documents, metadatas):
+                content = content or ""
+                tokens = self._tokenize(content)
+                total_len += len(tokens)
+                doc_len[doc_id] = len(tokens)
+                # compute term frequencies
+                tf: Dict[str, int] = {}
+                for tok in tokens:
+                    tf[tok] = tf.get(tok, 0) + 1
+                # update postings and df
+                for tok, freq in tf.items():
+                    if tok not in postings:
+                        postings[tok] = {doc_id: freq}
+                        df[tok] = 1
+                    else:
+                        postings[tok][doc_id] = freq
+                        df[tok] = df.get(tok, 0) + 1
+                # store meta for reconstruction
+                doc_meta[doc_id] = {
+                    "content": content,
+                    "metadata": meta or {},
+                }
+            self._N = N
+            self._avgdl = (total_len / N) if N else 0.0
+            self._df = df
+            self._postings = postings
+            self._doc_len = doc_len
+            self._doc_meta = doc_meta
+            self._index_ready = True
+            self._index_built_at = time.time()
+            # print(f"🔎 Keyword index built for {N} docs (avgdl={self._avgdl:.1f})")
+        except Exception as e:
+            print(f"⚠️ Failed building keyword index: {e}")
+            self._index_ready = False
+
+    def _bm25_scores(self, query: str) -> Dict[str, float]:
+        self._ensure_index()
+        if not self._index_ready or self._N == 0:
+            return {}
+        q_tokens = self._tokenize(query)
+        if not q_tokens:
+            return {}
+        # collect candidate docs (union of postings for query tokens)
+        candidate_docs: Dict[str, float] = {}
+        k1, b = 1.5, 0.75
+        for tok in q_tokens:
+            df = self._df.get(tok, 0)
+            postings = self._postings.get(tok)
+            if not postings or df == 0:
+                continue
+            # IDF with +1 stabilizer
+            idf = math.log((self._N - df + 0.5) / (df + 0.5) + 1.0)
+            for doc_id, tf in postings.items():
+                dl = self._doc_len.get(doc_id, 0) or 1
+                denom = tf + k1 * (1 - b + b * (dl / (self._avgdl or 1.0)))
+                score = idf * (tf * (k1 + 1)) / denom
+                candidate_docs[doc_id] = candidate_docs.get(doc_id, 0.0) + score
+        return candidate_docs
+
+    def _semantic_candidates(self, query: str, n: int) -> List[Tuple[Any, float]]:
+        """Return list of (doc, score) for semantic candidates; fallback if scores not available."""
+        try:
+            if hasattr(self.vector_store, "similarity_search_with_score"):
+                docs_scores = self.vector_store.similarity_search_with_score(query, k=n)
+                # docs_scores -> List[Tuple[Document, float]] where lower score is closer for some stores; normalize later
+                return docs_scores
+            # fallback: without scores, get docs and synthesize decreasing scores
+            docs = self.vector_store.similarity_search(query, k=n)
+            return list(zip(docs, [1.0 - (i / max(1, n)) for i in range(len(docs))]))
         except Exception as e:
             print(f"❌ Vector search failed: {e}")
             return []
+
+    def _normalize_scores(self, scores: Dict[str, float]) -> Dict[str, float]:
+        if not scores:
+            return {}
+        vals = list(scores.values())
+        mx = max(vals)
+        mn = min(vals)
+        if mx == mn:
+            return {k: 1.0 for k in scores}
+        return {k: (v - mn) / (mx - mn) for k, v in scores.items()}
+
+    def semantic_search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
+        """Hybrid retrieval: fuse semantic and keyword (BM25-like) signals and return top-k results.
+        Maintains original signature and return shape for compatibility.
+        """
+        # Guardrails
+        k = max(1, min(int(k or 5), 10))
+        # Collect candidates
+        n_sem = max(k * 2, 10)
+        n_kw = max(k * 3, 20)
+
+        sem_pairs = self._semantic_candidates(query, n_sem)
+        # Build semantic score map keyed by (url or id)
+        sem_scores: Dict[str, float] = {}
+        sem_docs_map: Dict[str, Any] = {}
+        for d, score in sem_pairs:
+            meta = getattr(d, "metadata", {}) or {}
+            url = (meta.get("url") or "").strip()
+            key = url or getattr(d, "id", None) or id(d)
+            sem_scores[key] = float(score if score is not None else 0.0)
+            sem_docs_map[key] = d
+        # Normalize semantic scores to ascending relevance (higher better)
+        # For some stores, lower distance is better; invert appropriately
+        if sem_scores:
+            # Try to detect if lower is better (distance) and invert
+            vals = list(sem_scores.values())
+            lower_is_better = True if len(vals) > 1 and vals[0] > vals[-1] else False
+            if lower_is_better:
+                maxv = max(vals)
+                sem_scores = {k: (maxv - v) for k, v in sem_scores.items()}
+            sem_scores = self._normalize_scores(sem_scores)
+
+        # Keyword BM25 candidates
+        kw_raw_scores = self._bm25_scores(query)
+        # Keep top n_kw keyword docs
+        if kw_raw_scores:
+            kw_items = sorted(kw_raw_scores.items(), key=lambda x: x[1], reverse=True)[:n_kw]
+            kw_raw_scores = dict(kw_items)
+        kw_scores = self._normalize_scores(kw_raw_scores)

+        # Fusion: weighted sum
+        alpha = 0.6 # semantic weight
+        beta = 0.4 # keyword weight
+        fused: Dict[str, float] = {}
+
+        # Include all keys from both sets
+        keys = set(sem_scores.keys()) | set(kw_scores.keys())
+        for key in keys:
+            s = sem_scores.get(key, 0.0)
+            w = kw_scores.get(key, 0.0)
+            fused[key] = alpha * s + beta * w
+
+        if not fused and sem_docs_map:
+            # If keyword index not ready, fallback to semantic docs order
+            ordered = sorted(sem_docs_map.items(), key=lambda kv: sem_scores.get(kv[0], 0.0), reverse=True)
+            docs = [d for _, d in ordered[:k]]
+        elif not fused and kw_scores:
+            # If semantic failed, reconstruct docs from index metadata
+            ordered = sorted(kw_scores.items(), key=lambda kv: kv[1], reverse=True)[:k]
+            docs = []
+            for doc_id, _ in ordered:
+                meta_entry = self._doc_meta.get(doc_id) or {}
+                content = meta_entry.get("content", "")
+                meta = meta_entry.get("metadata", {})
+                docs.append(Document(page_content=content, metadata=meta))
+        else:
+            ordered = sorted(fused.items(), key=lambda kv: kv[1], reverse=True)[:max(k*2, 20)]
+            docs = []
+            seen_keys = set()
+            for key, _ in ordered:
+                if key in seen_keys:
+                    continue
+                seen_keys.add(key)
+                if key in sem_docs_map:
+                    docs.append(sem_docs_map[key])
+                else:
+                    # reconstruct from keyword index
+                    meta_entry = self._doc_meta.get(key) or {}
+                    content = meta_entry.get("content", "")
+                    meta = meta_entry.get("metadata", {})
+                    docs.append(Document(page_content=content, metadata=meta))
+
+        # Convert to results shape and dedupe by URL
         results: List[Dict[str, Any]] = []
         seen_urls = set()
         for d in docs:
@@ -221,9 +442,17 @@ class VectorNewsRetriever:
             if content.startswith("Title: "):
                 line = content.splitlines()[0]
                 title = line.replace("Title: ", "").strip() or title
-            url = meta.get("url", "")
+            url = (meta.get("url", "") or "").strip()
            if url and url in seen_urls:
                 continue
             seen_urls.add(url)
-            results.append({"title": title, "content": content, "url": url, "source": meta.get("source", "Unknown"), "metadata": meta})
+            results.append({
+                "title": title,
+                "content": content,
+                "url": url,
+                "source": meta.get("source", "Unknown"),
+                "metadata": meta,
+            })
+            if len(results) >= k:
+                break
         return results
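
For readers skimming the diff, the scoring that _bm25_scores, _normalize_scores, and semantic_search apply can be reproduced in isolation. The sketch below uses hypothetical toy numbers (none of them come from the repository) to walk one query term through the BM25 formula used above, and then fuses a normalized semantic score with a normalized keyword score using the same alpha=0.6 / beta=0.4 weights.

import math

# Toy index statistics (hypothetical values, for illustration only)
N = 4                 # documents in the keyword index
df = 2                # documents containing the query term
tf = 3                # occurrences of the term in one candidate chunk
dl, avgdl = 120, 100  # candidate length vs. average length, in tokens
k1, b = 1.5, 0.75     # same constants as _bm25_scores

# Per-term BM25 contribution, matching the diff:
# idf = log((N - df + 0.5) / (df + 0.5) + 1)
idf = math.log((N - df + 0.5) / (df + 0.5) + 1.0)
denom = tf + k1 * (1 - b + b * (dl / avgdl))
bm25 = idf * (tf * (k1 + 1)) / denom
print(f"BM25 contribution: {bm25:.3f}")  # ~1.10 with these numbers

# Fusion: both signals are min-max normalized to [0, 1] first,
# then combined with alpha=0.6 (semantic) and beta=0.4 (keyword).
sem_norm, kw_norm = 0.8, 0.5  # hypothetical normalized scores
fused = 0.6 * sem_norm + 0.4 * kw_norm
print(f"Fused score: {fused:.2f}")  # 0.6*0.8 + 0.4*0.5 = 0.68
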
test_hybrid_retriever.py ADDED
@@ -0,0 +1,80 @@
+import os
+import time
+from cve_factchecker.retriever import VectorNewsRetriever
+from cve_factchecker.models import NewsArticle
+
+# Keep this test lightweight and isolated
+os.environ["USE_DUMMY_EMBEDDINGS"] = "true"
+os.environ["VECTOR_PERSIST_DIR"] = os.path.abspath("./vector_db_hybrid_test")
+
+articles = [
+    NewsArticle(
+        title="Militants storm FC lines in Bannu",
+        content=(
+            "At least five militants attacked the Frontier Corps (FC) Lines in Bannu, Khyber-Pakhtunkhwa. "
+            "Security forces responded swiftly, and the situation is under control."
+        ),
+        url="https://tribune.com.pk/story/2564614/militants-storm-fc-lines-in-bannu",
+        source="The Express Tribune",
+        published_date="2025-09-15",
+        scraped_date=str(int(time.time())),
+        article_id="a1",
+        language="English",
+    ),
+    NewsArticle(
+        title="Six soldiers martyred; five terrorists killed in Bannu FC compound attack",
+        content=(
+            "An attack on the FC compound in Bannu resulted in the martyrdom of six soldiers."
+            "Reports indicate five terrorists were killed in the exchange."
+        ),
+        url="https://dailytimes.com.pk/1363459/six-soldiers-martyred-five-terrorists-killed-in-attack-on-bannu-fc-compound/",
+        source="Daily Times",
+        published_date="2025-09-15",
+        scraped_date=str(int(time.time())),
+        article_id="a2",
+        language="English",
+    ),
+    NewsArticle(
+        title="KP operations update: militants neutralized",
+        content=(
+            "Security operations in Khyber-Pakhtunkhwa neutralized multiple militants. The Frontier Corps participated "
+            "in the operations across the province."
+        ),
+        url="https://dailytimes.com.pk/1368975/31-indian-backed-militants-killed-in-kp-operations/",
+        source="Daily Times",
+        published_date="2025-09-16",
+        scraped_date=str(int(time.time())),
+        article_id="a3",
+        language="English",
+    ),
+    NewsArticle(
+        title="Sports: Cricket series announced",
+        content="Pakistan Cricket Board announced a new bilateral series in Lahore next month.",
+        url="https://example.com/sports/cricket-series",
+        source="Example Sports",
+        published_date="2025-09-10",
+        scraped_date=str(int(time.time())),
+        article_id="a4",
+        language="English",
+    ),
+]
+
+if __name__ == "__main__":
+    retriever = VectorNewsRetriever(persist_directory=os.environ["VECTOR_PERSIST_DIR"])
+    retriever.store_articles_in_vector_db(articles, clear_first=True)
+
+    query = (
+        "At least five militants attacked the Frontier Corps (FC) Lines in Bannu, Khyber-Pakhtunkhwa"
+    )
+
+    print("\n=== Hybrid Retrieval Results (k=5) ===")
+    results = retriever.semantic_search(query, k=5)
+    for i, r in enumerate(results, 1):
+        print(f"{i}. {r.get('title')} | {r.get('url')} | source={r.get('source')}")
+        snippet = (r.get('content','') or '')[:120].replace('\n', ' ')
+        print(f" Snippet: {snippet}...")
+
+    # Basic sanity checks
+    print("\nCounts:")
+    print("vector_count:", retriever.get_vector_count())
+    print("results_count:", len(results))