Spaces:
Runtime error
Runtime error
| import time | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import os | |
| import zipfile | |
| from PIL import Image, UnidentifiedImageError | |
| from transformers import AutoProcessor, CLIPModel | |
| from vector_db.vector_db_client import VectorDB | |
| from tcvectordb.model.document import Document | |
| import uuid | |
| import traceback | |
| import numpy as np | |
| # Dotted config keys read via config.get() in Initial_and_Upload.__init__, the Hugging Face endpoint override, and the CSS snippet injected into the init panel. | |
| LOCAL_MODEL_PATH = "download_model.local_model_path" | |
| MODEL_NAME = "download_model.model_name" | |
| LOCAL_GRAPH_PATH="graph_upload.local_graph_path" | |
| os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" | |  # NOTE(review): set after `huggingface_hub` is imported above; presumably the library reads HF_ENDPOINT at import time, in which case this override is too late - TODO confirm
| init_css=""" | |
| <style> | |
| .equal-height-row { | |
| display: flex; | |
| } | |
| .equal-height-column { | |
| flex: 1; | |
| display: flex; | |
| flex-direction: column; | |
| } | |
| .equal-height-column > * { | |
| flex: 1; | |
| } | |
| </style> | |
| """ | |
class Initial_and_Upload:
    """Gradio panel for model download, vector-DB initialisation and image upload.

    The panel offers three actions:

    * download a Hugging Face model snapshot into a local cache directory,
    * initialise the database and the image-embedding collection on the
      configured ``VectorDB``,
    * upload a single image or a ZIP archive of images, embed every image
      with the cached CLIP model and upsert the vectors into the collection.
    """

    def __init__(self, config, vdb: VectorDB):
        """Read the model/graph locations from *config* and remember *vdb*.

        Args:
            config: Object exposing ``get(key)``; it is queried with the
                ``MODEL_NAME``, ``LOCAL_MODEL_PATH`` and ``LOCAL_GRAPH_PATH``
                keys.
            vdb: Vector database wrapper used for initialisation and upserts.
        """
        self.vdb = vdb
        self.model_name = config.get(MODEL_NAME)
        self.local_model_path = config.get(LOCAL_MODEL_PATH)
        self.local_graph_path = config.get(LOCAL_GRAPH_PATH)
        # Both cache directories are anchored at the parent of the directory
        # that contains this file.
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.model_cache_directory = os.path.join(
            project_root, self.local_model_path, self.model_name
        )
        self.graph_cache_directory = os.path.join(project_root, self.local_graph_path)

    def initial_model(self):
        """Load the CLIP model and its processor from the model cache directory.

        Returns:
            tuple: ``(model, processor)`` as returned by
            ``CLIPModel.from_pretrained`` and ``AutoProcessor.from_pretrained``.
        """
        model = CLIPModel.from_pretrained(self.model_cache_directory)
        processor = AutoProcessor.from_pretrained(self.model_cache_directory)
        return model, processor

    def _download_model(self, model_name, progress=gr.Progress()):
        """Download a Hugging Face model snapshot into the model cache directory.

        Args:
            model_name (str): Repository id of the model on Hugging Face.
            progress: Gradio progress tracker (injected by Gradio).

        Returns:
            str: Human-readable log describing success or the failure reason.
        """
        os.environ['TRANSFORMERS_CACHE'] = self.model_cache_directory
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.model_cache_directory, exist_ok=True)
        text = f"[正在尝试下载] 模型 {model_name},因为涉及到模型相关的多个文件下载,进度仅在后台显示。\n"
        progress(0.5, desc=text)
        try:
            snapshot_download(
                repo_id=model_name,
                local_dir=self.model_cache_directory,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            # UI boundary: report the failure in the textbox instead of raising,
            # and print the traceback like _initialize_vector_db does.
            text += f"[下载失败] 失败原因:{e}"
            print(traceback.format_exc())
            return text
        done_message = f"模型 {model_name} 已下载并保存在 {self.model_cache_directory}"
        progress(1, done_message)
        text += done_message
        # Short pause kept from the original implementation before returning.
        time.sleep(0.3)
        return text

    def _process_image(self, image_path, emb_model, processor):
        """Turn a single image file into its embedding.

        Args:
            image_path (str): Path of the image file.
            emb_model: Model exposing ``get_image_features(**inputs)``.
            processor: Callable preprocessing the image into model inputs.

        Returns:
            The value returned by ``emb_model.get_image_features``.

        Raises:
            PIL.UnidentifiedImageError: If the file cannot be opened as an image.
        """
        # The context manager releases the file handle once the pixel data has
        # been consumed by the processor.
        with Image.open(image_path) as image:
            inputs = processor(images=image, return_tensors="pt")
        return emb_model.get_image_features(**inputs)

    def _embed_and_store(self, image_path, model, processor, collection):
        """Embed one image, upsert it under a fresh UUID, and return the vector."""
        image_vector = self._process_image(image_path, model, processor).squeeze().tolist()
        document = Document(
            id=str(uuid.uuid4()),
            vector=image_vector,
            local_graph_path=image_path,
        )
        collection.upsert(documents=[document], build_index=True)
        return image_vector

    def _handle_upload(self, file, progress=gr.Progress()):
        """Process an uploaded image or ZIP archive and store the embeddings.

        Args:
            file: Uploaded file object exposing ``.name`` (a filesystem path),
                or ``None`` when nothing was selected.
            progress: Gradio progress tracker (injected by Gradio).

        Returns:
            tuple: ``(output_text, image_vectors)`` where ``output_text`` is the
            log shown to the user and ``image_vectors`` is the list of vectors
            that were stored during this call.
        """
        output_text = ""
        image_vectors = []

        if file is None:
            # Clicking the button without selecting a file would otherwise
            # fail on ``file.name``.
            output_text += "未选择文件,请先选择图片或ZIP压缩包。"
            return output_text, image_vectors

        if not os.path.exists(self.model_cache_directory):
            # Without a cached model nothing below can run, so stop here.
            output_text += f"缓存目录 {self.model_cache_directory} 不存在,无法初始化模型。"
            return output_text, image_vectors

        model, processor = self.initial_model()
        collection = self.vdb.get_collection()

        if zipfile.is_zipfile(file.name):
            image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
            with zipfile.ZipFile(file.name, 'r') as zip_ref:
                # NOTE(review): extraction uses the raw configured path, while
                # the single-image branch uses graph_cache_directory; kept
                # as-is because the stored local_graph_path values depend on it.
                zip_ref.extractall(self.local_graph_path)
                image_files = [
                    member
                    for member in zip_ref.namelist()
                    if member.lower().endswith(image_suffixes)
                    and not member.startswith('__MACOSX')
                    # macOS resource forks ("._name") can sit in any folder,
                    # so test the basename rather than the full member path.
                    and not os.path.basename(member).startswith('._')
                ]
            total_files = len(image_files)
            for done, file_name in enumerate(image_files, start=1):
                image_path = os.path.join(self.local_graph_path, file_name)
                try:
                    image_vectors.append(
                        self._embed_and_store(image_path, model, processor, collection)
                    )
                    output_text += f"处理图片: {file_name}\n"
                except UnidentifiedImageError:
                    output_text += f"无法识别图片文件: {file_name}\n"
                progress(done / total_files)
            output_text += "上传的是ZIP压缩包,已解压缩并处理所有图片。"
        else:
            try:
                # Copy the upload into the graph cache directory, creating the
                # directory first so the write cannot fail on a missing folder.
                os.makedirs(self.graph_cache_directory, exist_ok=True)
                image_path = os.path.join(
                    self.graph_cache_directory, os.path.basename(file.name)
                )
                with open(file.name, "rb") as f_src, open(image_path, "wb") as f_dst:
                    f_dst.write(f_src.read())
                image_vectors.append(
                    self._embed_and_store(image_path, model, processor, collection)
                )
                output_text += "上传的是图片文件,并已处理。\n"
                progress(1.0)
            except (IOError, SyntaxError) as e:
                output_text += f"无法识别文件类型:{e}\n"

        return output_text, image_vectors

    def _initialize_vector_db(self, progress=gr.Progress()):
        """Check connectivity, then initialise the database and graph collection.

        Args:
            progress: Gradio progress tracker (injected by Gradio).

        Returns:
            str: Log of every step, ending with the failure reason if any step
            raised.
        """
        log_parts = []

        def report(fraction, message):
            # Mirror each step to the progress bar and to the returned log.
            progress(fraction, desc=message)
            log_parts.append(message)

        report(0, f"[正在尝试连接] VectorDB {self.vdb.address}\n")
        try:
            client = self.vdb.create_client()
            try:
                client.list_databases()
                report(0.05, f"[连接成功] VectorDB {self.vdb.address}\n")
            finally:
                # Close the probe client even when the connectivity check fails.
                client.close()
            report(0.1, f"[正在初始化] ai database '{self.vdb.db_name}'\n")
            self.vdb.init_database()
            report(0.3, f"[初始化完成] ai database '{self.vdb.db_name}'\n")
            report(0.5, f"[正在初始化] ai collection '{self.vdb.ai_graph_emb_collection}'\n")
            self.vdb.init_graph_collection()
            report(0.9, f"[初始化完成] ai collection '{self.vdb.ai_graph_emb_collection}'\n")
            report(1, "您可以去图片上传栏目上传图片或ZIP压缩包,然后进一步进行[图片搜索]")
            # Short pause kept from the original implementation.
            time.sleep(0.3)
        except Exception as e:
            # UI boundary: surface the error in the textbox and print the
            # traceback for the operator.
            log_parts.append(f"[数据库访问失败] 失败原因:{e}")
            print(traceback.format_exc())
        return "".join(log_parts)

    def get_init_panel(self):
        """Build the Gradio Blocks UI for download, DB initialisation and upload.

        Returns:
            gr.Blocks: The assembled panel with its three buttons wired to
            ``_download_model``, ``_initialize_vector_db`` and
            ``_handle_upload``.
        """
        with gr.Blocks() as demo:
            gr.HTML(init_css)
            with gr.Row():
                with gr.Column():
                    model_name_input = gr.Textbox(
                        lines=1,
                        label="模型名称",
                        placeholder="请输入Hugging Face模型名称...",
                        value=self.model_name,
                    )
                    output = gr.Textbox(
                        lines=10,
                        label="下载进度",
                        placeholder="下载进度将在这里显示...",
                    )
                    init_button = gr.Button("开始下载模型")
                    init_button.click(
                        fn=self._download_model,
                        inputs=[model_name_input],
                        outputs=output,
                    )
                with gr.Column():
                    # NOTE(review): Gradio documents `lines` as an int; the
                    # fractional value is kept unchanged - confirm it is intended.
                    db_init_output = gr.Textbox(
                        lines=14.5,
                        label="数据库初始化结果",
                        placeholder="数据库初始化结果将在这里显示...",
                    )
                    db_init_button = gr.Button("初始化向量数据库")
                    db_init_button.click(
                        fn=self._initialize_vector_db,
                        inputs=[],
                        outputs=db_init_output,
                    )
            with gr.Row():
                upload_file = gr.File(label="上传图片或ZIP压缩包")
            with gr.Row():
                upload_output = gr.Textbox(
                    lines=10,
                    label="上传结果",
                    placeholder="上传结果将在这里显示...",
                )
            with gr.Row():
                upload_button = gr.Button("上传文件")
                # The second output receives the list of stored vectors.
                upload_button.click(
                    fn=self._handle_upload,
                    inputs=[upload_file],
                    outputs=[upload_output, gr.State()],
                )
        return demo