mahimairaja committed
Commit
46b73f5
1 Parent(s): 92aafdd

feat: initialize the project and add qdrant client to push the vectors

Files changed (5)
  1. .python-version +1 -0
  2. README.md +0 -0
  3. main.py +104 -0
  4. pyproject.toml +11 -0
  5. uv.lock +0 -0
.python-version ADDED
@@ -0,0 +1 @@
+ 3.10
README.md ADDED
File without changes
main.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ from datetime import datetime
+
+ import numpy as np
+ from datasets import load_dataset
+ from dotenv import load_dotenv
+ from qdrant_client import QdrantClient, models
+
+ load_dotenv()
+
+ repo_id = "mahimairaja/ibm-hls-burn-vectorized"
+ ds_from_hub = load_dataset(repo_id)
+
+
+ # 2. Initialize Qdrant client and create collection
+ client = QdrantClient(
+     url=os.getenv("QDRANT_URL"),
+     api_key=os.getenv("QDRANT_API_KEY"),
+ )
+ collection_name = "hls_burn_scars_vectorized"
+
+ # Recreate the collection with specified vector configurations and payload indexing
+ if client.collection_exists(collection_name=collection_name):
+     client.delete_collection(collection_name=collection_name)
+
+ client.create_collection(
+     collection_name=collection_name,
+     vectors_config={
+         "dense": models.VectorParams(
+             size=384,
+             distance=models.Distance.COSINE,
+         ),
+         "colbert": models.VectorParams(
+             size=128,
+             distance=models.Distance.COSINE,
+             multivector_config=models.MultiVectorConfig(
+                 comparator=models.MultiVectorComparator.MAX_SIM
+             ),
+             hnsw_config=models.HnswConfigDiff(m=0),  # Disable HNSW for reranking
+         ),
+     },
+     # Optimizer hint: keep the collection data in two segments
+     optimizers_config=models.OptimizersConfigDiff(default_segment_number=2),
+ )
+
+ # Create payload indexes for filtering
+ client.create_payload_index(
+     collection_name=collection_name,
+     field_name="centroid_lat",
+     field_schema=models.PayloadSchemaType.FLOAT,
+ )
+ client.create_payload_index(
+     collection_name=collection_name,
+     field_name="centroid_lon",
+     field_schema=models.PayloadSchemaType.FLOAT,
+ )
+ client.create_payload_index(
+     collection_name=collection_name,
+     field_name="acquisition_date",
+     field_schema=models.PayloadSchemaType.DATETIME,
+ )
+
+
+ # 3. Prepare and ingest data into Qdrant
+ def generate_qdrant_points(dataset_split):
+     points = []
+     for i, item in enumerate(dataset_split):
+         # Ensure embeddings are numpy arrays for Qdrant, then convert to list
+         dense_vec = np.array(item["dense_embedding"], dtype=np.float32).tolist()
+         colbert_vec = np.array(item["colbert_embedding"], dtype=np.float32).tolist()
+
+         point = models.PointStruct(
+             id=i,
+             vector={
+                 "dense": dense_vec,
+                 "colbert": colbert_vec,
+             },
+             payload={
+                 "centroid_lat": item["centroid_lat"],
+                 "centroid_lon": item["centroid_lon"],
+                 "acquisition_date": datetime.strptime(
+                     item["acquisition_date"], "%Y-%m-%d"
+                 ),
+             },
+         )
+         points.append(point)
+     return points
+
+
+ # Ingest data for each split
+ for split_name, dataset_split in ds_from_hub.items():
+     print(f"Ingesting {len(dataset_split)} points from '{split_name}' split...")
+     qdrant_points = generate_qdrant_points(dataset_split)
+
+     client.upsert(collection_name=collection_name, points=qdrant_points, wait=True)
+     print(
+         f"Finished ingesting {len(qdrant_points)} points into Qdrant for '{split_name}' split."
+     )
+
+ print("Data ingestion complete for all splits.")
+
+ # Verify ingestion by counting points
+ count_result = client.count(collection_name=collection_name, exact=True)
+ print(f"Total points in Qdrant collection: {count_result.count}")
pyproject.toml ADDED
@@ -0,0 +1,11 @@
+ [project]
+ name = "geo-spatial-chat-qdrant"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "datasets==2.21.0",
+     "python-dotenv>=1.2.1",
+     "qdrant-client>=1.16.1",
+ ]
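Since a uv.lock is committed alongside this pyproject.toml, the intended workflow is presumably `uv sync` followed by `uv run main.py`; `requires-python = ">=3.10"` matches the 3.10 pinned in .python-version, and `QDRANT_URL` / `QDRANT_API_KEY` would be supplied through the environment or a local `.env` file for python-dotenv to load. The `.env` file itself is, sensibly, not part of the commit.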
uv.lock ADDED
The diff for this file is too large to render. See raw diff