# NOTE: page residue captured with this file ("Spaces: Runtime error", shown twice); not part of the program.
import os
import traceback
import uuid

import gradio as gr
import torch
from PIL import Image
from tcvectordb.model.document import SearchParams
from transformers import AutoProcessor, CLIPModel

from vector_db.vector_db_client import VectorDB
# Keys looked up (via ``config.get``) in the configuration object handed to ChatSearch.
LOCAL_MODEL_PATH = "download_model.local_model_path"
MODEL_NAME = "download_model.model_name"
LOCAL_GRAPH_PATH = "graph_upload.local_graph_path"
class ChatSearch:
    """Image-to-image search: embed an uploaded picture with CLIP and query a vector DB.

    The model/processor are read from a local directory derived from the
    configuration, the uploaded picture is stored under the configured graph
    directory, and the nearest stored pictures are returned for display in a
    Gradio gallery.
    """

    def __init__(self, config, vdb: "VectorDB"):
        """Store the vector DB client and resolve the model / upload directories.

        Args:
            config: Object exposing ``get(key)``; read with the module-level
                ``MODEL_NAME``, ``LOCAL_MODEL_PATH`` and ``LOCAL_GRAPH_PATH`` keys.
            vdb: Vector database client; only ``get_collection()`` is used here.
        """
        self.vdb = vdb
        self.model_name = config.get(MODEL_NAME)
        self.local_model_path = config.get(LOCAL_MODEL_PATH)
        self.local_graph_path = config.get(LOCAL_GRAPH_PATH)

        # Both directories are resolved relative to the parent of this file's directory.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_cache_directory = os.path.join(
            project_root, self.local_model_path, self.model_name
        )
        self.graph_cache_directory = os.path.join(project_root, self.local_graph_path)

        # Filled lazily by _get_model_and_processor() on the first search.
        self._model_and_processor = None

    def initial_model(self):
        """Load the CLIP model and its processor from ``model_cache_directory``.

        Returns:
            tuple: ``(model, processor)``, freshly loaded on every call.
        """
        model = CLIPModel.from_pretrained(self.model_cache_directory)
        processor = AutoProcessor.from_pretrained(self.model_cache_directory)
        return model, processor

    def _get_model_and_processor(self):
        """Return the ``(model, processor)`` pair, loading it only on first use."""
        # getattr keeps this safe for instances created without running __init__.
        cached = getattr(self, "_model_and_processor", None)
        if cached is None:
            cached = self.initial_model()
            self._model_and_processor = cached
        return cached

    def search_result(self, image):
        """Search the vector DB for pictures similar to ``image``.

        Args:
            image: The uploaded picture (a PIL image, per ``get_chart``), or
                ``None`` when nothing was uploaded.

        Returns:
            list: PIL images for the matches that could be loaded; an empty
            list when the search itself fails (the error is printed).

        Raises:
            gr.Error: If no picture was uploaded or the model directory is
                missing. The output component is a gallery, so a plain string
                return value could not be rendered as a message.
        """
        if image is None:
            raise gr.Error("请先上传图片...")
        if not os.path.exists(self.model_cache_directory):
            raise gr.Error(f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。")

        # Loading from disk is expensive; reuse the pair across requests.
        model, processor = self._get_model_and_processor()

        try:
            # Store the upload under a unique file name in the graph directory.
            os.makedirs(self.graph_cache_directory, exist_ok=True)
            upload_path = os.path.join(
                self.graph_cache_directory, f"{uuid.uuid4().hex}.png"
            )
            image.save(upload_path)

            # Flatten the embedding to a one-dimensional list for the query.
            features = self._process_image(upload_path, model, processor)
            image_vector = features.squeeze().tolist()

            # Assumes the VectorDB collection supports vector search.
            collection = self.vdb.get_collection()
            res = collection.search(
                vectors=[image_vector],
                params=SearchParams(ef=200),
                limit=10,
                output_fields=['local_graph_path'],
            )
            return self._load_result_images(res)
        except Exception as e:
            # Best effort: report the failure and show an empty gallery.
            print(f"问题:{e}\n")
            print(traceback.format_exc())
            return []

    def _load_result_images(self, search_results):
        """Open the picture referenced by every hit, skipping unreadable files."""
        loaded = []
        for docs in search_results:
            for doc in docs:
                result_path = doc['local_graph_path']
                try:
                    # Decode eagerly so the file handle is released and a
                    # broken file is caught here instead of at render time.
                    with Image.open(result_path) as picture:
                        picture.load()
                    loaded.append(picture)
                except Exception as e:
                    print(f"无法加载图片 {result_path}: {e}")
        return loaded

    def _process_image(self, image_path, emb_model, processor):
        """Convert a single image file into its embedding vector.

        Args:
            image_path (str): Path of the image file.
            emb_model: Model exposing ``get_image_features(**inputs)``.
            processor: Callable taking ``images=...`` and ``return_tensors="pt"``.

        Returns:
            torch.Tensor: The image's vector representation.
        """
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        # Inference only: no gradients are needed for the embedding.
        with torch.no_grad():
            image_features = emb_model.get_image_features(**inputs)
        return image_features

    def get_chart(self):
        """Build the Gradio interface: one image input, one gallery of results."""
        return gr.Interface(
            fn=self.search_result,
            inputs=gr.Image(type="pil", label="上传图片"),
            outputs=gr.Gallery(label="检索结果"),
            theme="soft",
            description="上传图片进行检索",
            allow_flagging="never",
        )