# Usage Examples ## 1. Basic Example ```python from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity model = SentenceTransformer('ThanhLe0125/ebd-math') # Single query example query = "query: Cách tính đạo hàm của hàm số" chunks = [ "passage: Đạo hàm của hàm số f(x) tại điểm x0 được định nghĩa...", "passage: Các quy tắc tính đạo hàm cơ bản: (x^n)' = nx^(n-1)...", "passage: Phương trình tích phân là phương trình chứa hàm số..." ] query_emb = model.encode([query]) chunk_embs = model.encode(chunks) similarities = cosine_similarity(query_emb, chunk_embs)[0] print("Rankings:") for i, sim in enumerate(similarities): print(f"Chunk {i+1}: {sim:.4f}") ``` ## 2. Batch Processing ```python queries = [ "query: Định nghĩa hàm số đồng biến", "query: Cách giải phương trình bậc hai", "query: Công thức tính thể tích hình cầu" ] # Encode all at once for efficiency query_embs = model.encode(queries) chunk_embs = model.encode(chunks) # Calculate similarities for all queries for i, query in enumerate(queries): similarities = cosine_similarity([query_embs[i]], chunk_embs)[0] best_idx = similarities.argmax() print(f"Best match for '{query}': {chunks[best_idx]} (score: {similarities[best_idx]:.4f})") ``` ## 3. Production Usage ```python class MathRetriever: def __init__(self, model_name='ThanhLe0125/ebd-math'): self.model = SentenceTransformer(model_name) def retrieve(self, query, chunks, top_k=5): # Format inputs formatted_query = f"query: {query}" if not query.startswith("query:") else query formatted_chunks = [f"passage: {chunk}" if not chunk.startswith("passage:") else chunk for chunk in chunks] # Encode and rank query_emb = self.model.encode([formatted_query]) chunk_embs = self.model.encode(formatted_chunks) similarities = cosine_similarity(query_emb, chunk_embs)[0] # Get top K results top_indices = similarities.argsort()[::-1][:top_k] results = [ { 'chunk': chunks[i], 'similarity': float(similarities[i]), 'rank': rank + 1 } for rank, i in enumerate(top_indices) ] return results # Usage retriever = MathRetriever() results = retriever.retrieve( "Định nghĩa hàm số liên tục", mathematical_chunks, top_k=3 ) ```