@@ -48,35 +48,11 @@ fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
48
48
) ?;
49
49
}
50
50
51
- let sql = "
52
- WITH semantic_search AS (
53
- SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
54
- FROM documents
55
- ORDER BY embedding <=> $2
56
- LIMIT 20
57
- ),
58
- keyword_search AS (
59
- SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
60
- FROM documents, plainto_tsquery('english', $1) query
61
- WHERE to_tsvector('english', content) @@ query
62
- ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
63
- LIMIT 20
64
- )
65
- SELECT
66
- COALESCE(semantic_search.id, keyword_search.id) AS id,
67
- COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
68
- COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
69
- FROM semantic_search
70
- FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
71
- ORDER BY score DESC
72
- LIMIT 5
73
- " ;
74
-
75
51
let query = "growling bear" ;
76
52
let query_embedding = model. embed ( query) ?;
77
53
let k = 60.0 ;
78
54
79
- for row in client. query ( sql , & [ & query, & Vector :: from ( query_embedding) , & k] ) ? {
55
+ for row in client. query ( HYBRID_SQL , & [ & query, & Vector :: from ( query_embedding) , & k] ) ? {
80
56
let id: i32 = row. get ( 0 ) ;
81
57
let score: f64 = row. get ( 1 ) ;
82
58
println ! ( "document: {}, RRF score: {}" , id, score) ;
@@ -85,6 +61,30 @@ fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
85
61
Ok ( ( ) )
86
62
}
87
63
64
/// Hybrid-search query that combines a semantic ranking and a keyword ranking
/// into a single RRF score per document.
///
/// Bind parameters:
/// * `$1` - plain-text query, turned into a tsquery via `plainto_tsquery('english', $1)`.
/// * `$2` - query embedding, compared against `documents.embedding` with `<=>`.
/// * `$3` - the constant `k` added to each rank before taking the reciprocal
///   (cast to `double precision` inside the query).
///
/// Each CTE keeps at most 20 rows; the final result is at most 5 rows of
/// `(id, score)` ordered by descending score. A document found by only one of
/// the two searches contributes `0.0` for the missing side (FULL OUTER JOIN
/// plus COALESCE).
const HYBRID_SQL: &str = "
WITH semantic_search AS (
    SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
    FROM documents
    ORDER BY embedding <=> $2
    LIMIT 20
),
keyword_search AS (
    -- Alias the window column explicitly: the outer SELECT reads
    -- keyword_search.rank, which otherwise only resolves through
    -- PostgreSQL's implicit function-name column naming.
    SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC) AS rank
    FROM documents, plainto_tsquery('english', $1) query
    WHERE to_tsvector('english', content) @@ query
    ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
    LIMIT 20
)
SELECT
    COALESCE(semantic_search.id, keyword_search.id) AS id,
    COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
    COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
FROM semantic_search
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
ORDER BY score DESC
LIMIT 5
";
87
+
88
88
struct EmbeddingModel {
89
89
tokenizer : Tokenizer ,
90
90
model : BertModel ,
0 commit comments