4
4
import hashlib
5
5
import os
6
6
7
- class TestCollection (unittest .TestCase ):
8
7
8
+ class TestCollection (unittest .TestCase ):
9
9
def setUp (self ) -> None :
10
10
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
11
- conninfo = os .environ .get ("PGML_CONNECTION" ,local_pgml )
11
+ conninfo = os .environ .get ("PGML_CONNECTION" , local_pgml )
12
12
self .db = Database (conninfo )
13
13
self .collection_name = "test_collection_1"
14
14
self .documents = [
15
15
{
16
- "id" : hashlib .md5 (f"abcded-{ i } " .encode (' utf-8' )).hexdigest (),
17
- "text" :f"Lorem ipsum { i } " ,
18
- "metadata" : { "source" : "test_suite" }
16
+ "id" : hashlib .md5 (f"abcded-{ i } " .encode (" utf-8" )).hexdigest (),
17
+ "text" : f"Lorem ipsum { i } " ,
18
+ "source" : "test_suite" ,
19
19
}
20
20
for i in range (4 , 7 )
21
21
]
22
22
self .documents_no_ids = [
23
23
{
24
- "text" :f"Lorem ipsum { i } " ,
25
- "metadata" : { "source" : "test_suite_no_ids" }
24
+ "text" : f"Lorem ipsum { i } " ,
25
+ "source" : "test_suite_no_ids" ,
26
26
}
27
27
for i in range (1 , 4 )
28
28
]
29
-
29
+
30
+ self .documents_with_metadata = [
31
+ {
32
+ "text" : f"Lorem ipsum metadata" ,
33
+ "source" : f"url { i } " ,
34
+ "url" : f"/home { i } " ,
35
+ "user" : f"John Doe-{ i + 1 } " ,
36
+ }
37
+ for i in range (8 , 12 )
38
+ ]
39
+
40
+ self .documents_with_reviews = [
41
+ {
42
+ "text" : f"product is abc { i } " ,
43
+ "reviews" : i * 2 ,
44
+ }
45
+ for i in range (20 , 25 )
46
+ ]
47
+
48
+ self .documents_with_reviews_metadata = [
49
+ {
50
+ "text" : f"product is abc { i } " ,
51
+ "reviews" : i * 2 ,
52
+ "source" : "amazon" ,
53
+ "user" : "John Doe" ,
54
+ }
55
+ for i in range (20 , 25 )
56
+ ]
57
+
58
+ self .documents_with_reviews_metadata += [
59
+ {
60
+ "text" : f"product is abc { i } " ,
61
+ "reviews" : i * 2 ,
62
+ "source" : "ebay" ,
63
+ }
64
+ for i in range (20 , 25 )
65
+ ]
66
+
30
67
self .collection = self .db .create_or_get_collection (self .collection_name )
31
-
68
+
32
69
def test_create_collection (self ):
33
- assert isinstance (self .collection ,Collection )
34
-
70
+ assert isinstance (self .collection , Collection )
71
+
35
72
def test_documents_upsert (self ):
36
73
self .collection .upsert_documents (self .documents )
37
74
conn = self .db .pool .getconn ()
38
- results = run_select_statement (conn ,"SELECT id FROM %s" % self .collection .documents_table )
75
+ results = run_select_statement (
76
+ conn , "SELECT id FROM %s" % self .collection .documents_table
77
+ )
39
78
self .db .pool .putconn (conn )
40
79
assert len (results ) >= len (self .documents )
41
-
80
+
42
81
def test_documents_upsert_no_ids (self ):
43
82
self .collection .upsert_documents (self .documents_no_ids )
44
83
conn = self .db .pool .getconn ()
45
- results = run_select_statement (conn ,"SELECT id FROM %s" % self .collection .documents_table )
84
+ results = run_select_statement (
85
+ conn , "SELECT id FROM %s" % self .collection .documents_table
86
+ )
46
87
self .db .pool .putconn (conn )
47
88
assert len (results ) >= len (self .documents_no_ids )
48
89
@@ -52,23 +93,25 @@ def test_default_text_splitter(self):
52
93
53
94
assert splitter_id == 1
54
95
assert splitters [0 ]["name" ] == "RecursiveCharacterTextSplitter"
55
-
96
+
56
97
def test_default_embeddings_model (self ):
57
98
model_id = self .collection .register_model ()
58
99
models = self .collection .get_models ()
59
-
100
+
60
101
assert model_id == 1
61
102
assert models [0 ]["name" ] == "intfloat/e5-small"
62
-
103
+
63
104
def test_generate_chunks (self ):
64
105
self .collection .upsert_documents (self .documents )
65
106
self .collection .upsert_documents (self .documents_no_ids )
66
107
splitter_id = self .collection .register_text_splitter ()
67
108
self .collection .generate_chunks (splitter_id = splitter_id )
68
- splitter_params = {"chunk_size" : 100 , "chunk_overlap" :20 }
69
- splitter_id = self .collection .register_text_splitter (splitter_params = splitter_params )
109
+ splitter_params = {"chunk_size" : 100 , "chunk_overlap" : 20 }
110
+ splitter_id = self .collection .register_text_splitter (
111
+ splitter_params = splitter_params
112
+ )
70
113
self .collection .generate_chunks (splitter_id = splitter_id )
71
-
114
+
72
115
def test_generate_embeddings (self ):
73
116
self .collection .upsert_documents (self .documents )
74
117
self .collection .upsert_documents (self .documents_no_ids )
@@ -84,10 +127,42 @@ def test_vector_search(self):
84
127
self .collection .generate_embeddings ()
85
128
results = self .collection .vector_search ("Lorem ipsum 1" , top_k = 2 )
86
129
assert results [0 ]["score" ] == 1.0
87
-
88
- # def tearDown(self) -> None:
89
- # self.db.archive_collection(self.collection_name)
90
130
131
+ def test_vector_search_metadata_filter (self ):
132
+ self .collection .upsert_documents (self .documents )
133
+ self .collection .upsert_documents (self .documents_no_ids )
134
+ self .collection .upsert_documents (self .documents_with_metadata )
135
+ self .collection .generate_chunks ()
136
+ self .collection .generate_embeddings ()
137
+ results = self .collection .vector_search (
138
+ "Lorem ipsum metadata" ,
139
+ top_k = 2 ,
140
+ metadata_filter = {"url" : "/home 8" , "source" : "url 8" },
141
+ )
142
+ assert results [0 ]["metadata" ]["user" ] == "John Doe-9"
143
+
144
+ def test_vector_search_generic_filter (self ):
145
+ self .collection .upsert_documents (self .documents_with_reviews )
146
+ self .collection .generate_chunks ()
147
+ self .collection .generate_embeddings ()
148
+ results = self .collection .vector_search (
149
+ "product is abc 21" ,
150
+ top_k = 2 ,
151
+ generic_filter = "(documents.metadata->>'reviews')::int < 45" ,
152
+ )
153
+ assert results [0 ]["metadata" ]["reviews" ] == 42
91
154
92
-
93
-
155
+ def test_vector_search_generic_and_metadata_filter (self ):
156
+ self .collection .upsert_documents (self .documents_with_reviews_metadata )
157
+ self .collection .generate_chunks ()
158
+ self .collection .generate_embeddings ()
159
+ results = self .collection .vector_search (
160
+ "product is abc 21" ,
161
+ top_k = 2 ,
162
+ generic_filter = "(documents.metadata->>'reviews')::int < 45" ,
163
+ metadata_filter = {"source" : "amazon" },
164
+ )
165
+ assert results [0 ]["metadata" ]["user" ] == "John Doe"
166
+
167
+ # def tearDown(self) -> None:
168
+ # self.db.archive_collection(self.collection_name)
0 commit comments