import os
import json
import sys
from pathlib import Path

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

DUMP_PATH = "/home/ubuntu/output"
FAISS_OUT = "wiki_faiss.index"
STATE_FILE = "progress.json"
PAUSE_FLAG = "PAUSE"
CHUNK_SIZE = 200          # words per chunk
BATCH_SIZE = 1000         # chunks embedded per encode() call
CHECKPOINT_BATCHES = 5    # checkpoint the index/state every N batches

# Load model and FAISS index
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
dim = embedder.get_sentence_embedding_dimension()

if Path(FAISS_OUT).exists():
    index = faiss.read_index(FAISS_OUT)
else:
    index = faiss.IndexFlatIP(dim)

# Gather all extracted dump files; sort them so the order is stable across runs,
# otherwise the saved file_idx would not point at the same file after a restart.
files = sorted(
    os.path.join(r, f)
    for r, _, fs in os.walk(DUMP_PATH)
    for f in fs
    if f.startswith("wiki_")
)
total_files = len(files)

# Load progress
if Path(STATE_FILE).exists():
    with open(STATE_FILE) as f:
        state = json.load(f)
    file_idx = state.get("file_idx", 0)
    batch_idx = state.get("batch_idx", 0)
    print(f"▶ Resuming from file {file_idx}, batch {batch_idx}")
else:
    file_idx = 0
    batch_idx = 0

# Helper: split text into fixed-size word chunks
def chunk_text(text, size=CHUNK_SIZE):
    words = text.split()
    for i in range(0, len(words), size):
        yield " ".join(words[i:i + size])

# --- Precompute total chunks and already-processed chunks for the overall progress bar ---
file_chunk_counts = []
total_chunks = 0
for f in files:
    cnt = 0
    try:
        with open(f, "r", encoding="utf-8") as file:
            for line in file:
                data = json.loads(line)
                text = data.get("text", "").strip()
                if text:
                    cnt += len(list(chunk_text(text)))
    except Exception:
        pass
    file_chunk_counts.append(cnt)
    total_chunks += cnt

# Chunks already embedded in previous runs
processed_chunks = sum(file_chunk_counts[:file_idx]) + batch_idx

# Overall progress bar
pbar = tqdm(total=total_chunks, initial=processed_chunks,
            desc="Embedding chunks", unit="chunk")

def save_state(f_idx, b_idx):
    """Persist the FAISS index and the resume position."""
    faiss.write_index(index, FAISS_OUT)
    with open(STATE_FILE, "w") as f:
        json.dump({"file_idx": f_idx, "batch_idx": b_idx}, f)

# --- Main loop ---
for f_idx in range(file_idx, total_files):
    file_path = files[f_idx]

    # Pause check before starting a file
    if Path(PAUSE_FLAG).exists():
        print("\n⏸ Pause requested. Saving state...")
        save_state(f_idx, batch_idx)
        sys.exit(0)

    # Read file into chunks
    chunks = []
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                text = data.get("text", "").strip()
                if text:
                    chunks.extend(chunk_text(text))
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        continue

    # Resume mid-file only for the file we stopped in
    start = batch_idx if f_idx == file_idx else 0
    total_chunks_in_file = len(chunks)

    # Process chunks in batches
    for b_idx in range(start, total_chunks_in_file, BATCH_SIZE):
        # Pause check before embedding the next batch
        if Path(PAUSE_FLAG).exists():
            print("\n⏸ Pause requested. Saving state...")
            save_state(f_idx, b_idx)
            sys.exit(0)

        batch_texts = chunks[b_idx:b_idx + BATCH_SIZE]
        # encode() has no dtype argument; cast to float32 for FAISS
        embeddings = embedder.encode(batch_texts, convert_to_numpy=True).astype(np.float32)
        faiss.normalize_L2(embeddings)  # normalize so inner product == cosine similarity
        index.add(embeddings)

        # Update overall progress bar
        pbar.update(len(batch_texts))

        # Periodic checkpoint
        if (b_idx // BATCH_SIZE + 1) % CHECKPOINT_BATCHES == 0:
            save_state(f_idx, b_idx + BATCH_SIZE)

    # Finished this file
    batch_idx = 0
    save_state(f_idx + 1, 0)

pbar.close()
print("✅ All files processed.")
if Path(PAUSE_FLAG).exists():
    os.remove(PAUSE_FLAG)
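
# Usage sketch (assumed workflow; the script filename below is hypothetical):
#   python embed_wiki_dump.py   # start embedding, or resume from progress.json if it exists
#   touch PAUSE                 # ask the running job to checkpoint the index/state and exit cleanly
#   rm PAUSE                    # clear the flag, then rerun the script to resume where it stopped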