Skip to content

Commit 541a1d5

Browse files
authored
Removed transactions (#765)
1 parent cba5448 commit 541a1d5

File tree

4 files changed

+126
-195
lines changed

4 files changed

+126
-195
lines changed

pgml-sdks/rust/pgml/src/collection.rs

Lines changed: 99 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use std::collections::HashMap;
1313
use crate::languages::javascript::*;
1414
use crate::models;
1515
use crate::queries;
16-
use crate::{query_builder, transaction_wrapper};
16+
use crate::query_builder;
1717

1818
/// A collection of documents
1919
#[derive(custom_derive, Debug, Clone)]
@@ -314,18 +314,16 @@ impl Collection {
314314
};
315315
let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
316316

317-
transaction_wrapper!(
318-
sqlx::query(&query_builder!(
319-
"INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5",
320-
self.documents_table_name
321-
))
322-
.bind(&text)
323-
.bind(source_uuid)
324-
.bind(&document_json)
325-
.bind(&text)
326-
.bind(&document_json),
327-
self.pool.borrow()
328-
);
317+
sqlx::query(&query_builder!(
318+
"INSERT INTO %s (text, source_uuid, metadata) VALUES ($1, $2, $3) ON CONFLICT (source_uuid) DO UPDATE SET text = $4, metadata = $5",
319+
self.documents_table_name
320+
))
321+
.bind(&text)
322+
.bind(source_uuid)
323+
.bind(&document_json)
324+
.bind(&text)
325+
.bind(&document_json)
326+
.execute(self.pool.borrow()).await?;
329327
}
330328
Ok(())
331329
}
@@ -363,18 +361,14 @@ impl Collection {
363361
None => serde_json::json!({}),
364362
};
365363

366-
let current_splitter;
367-
transaction_wrapper!(
368-
current_splitter,
369-
sqlx::query_as::<_, models::Splitter>(&query_builder!(
370-
"SELECT * from %s where name = $1 and parameters = $2;",
371-
self.splitters_table_name
372-
))
373-
.bind(&splitter_name)
374-
.bind(&splitter_params),
375-
self.pool.borrow(),
376-
fetch_optional
377-
);
364+
let current_splitter: Option<models::Splitter> = sqlx::query_as(&query_builder!(
365+
"SELECT * from %s where name = $1 and parameters = $2;",
366+
self.splitters_table_name
367+
))
368+
.bind(&splitter_name)
369+
.bind(&splitter_params)
370+
.fetch_optional(self.pool.borrow())
371+
.await?;
378372

379373
match current_splitter {
380374
Some(_splitter) => {
@@ -384,32 +378,27 @@ impl Collection {
384378
);
385379
}
386380
None => {
387-
transaction_wrapper!(
388-
sqlx::query(&query_builder!(
389-
"INSERT INTO %s (name, parameters) VALUES ($1, $2)",
390-
self.splitters_table_name
391-
))
392-
.bind(splitter_name)
393-
.bind(splitter_params),
394-
self.pool.borrow()
395-
);
381+
sqlx::query(&query_builder!(
382+
"INSERT INTO %s (name, parameters) VALUES ($1, $2)",
383+
self.splitters_table_name
384+
))
385+
.bind(splitter_name)
386+
.bind(splitter_params)
387+
.execute(self.pool.borrow())
388+
.await?;
396389
}
397390
}
398391
Ok(())
399392
}
400393

401394
/// Gets all registered text [models::Splitter]s
402395
pub async fn get_text_splitters(&self) -> anyhow::Result<Vec<models::Splitter>> {
403-
let splitters;
404-
transaction_wrapper!(
405-
splitters,
406-
sqlx::query_as::<_, models::Splitter>(&query_builder!(
407-
"SELECT * from %s",
408-
self.splitters_table_name
409-
)),
410-
self.pool.borrow(),
411-
fetch_all
412-
);
396+
let splitters: Vec<models::Splitter> = sqlx::query_as(&query_builder!(
397+
"SELECT * from %s",
398+
self.splitters_table_name
399+
))
400+
.fetch_all(self.pool.borrow())
401+
.await?;
413402
Ok(splitters)
414403
}
415404

@@ -443,17 +432,16 @@ impl Collection {
443432
/// ```
444433
pub async fn generate_chunks(&self, splitter_id: Option<i64>) -> anyhow::Result<()> {
445434
let splitter_id = splitter_id.unwrap_or(1);
446-
transaction_wrapper!(
447-
sqlx::query(&query_builder!(
448-
queries::GENERATE_CHUNKS,
449-
self.splitters_table_name,
450-
self.chunks_table_name,
451-
self.documents_table_name,
452-
self.chunks_table_name
453-
))
454-
.bind(splitter_id),
455-
self.pool.borrow()
456-
);
435+
sqlx::query(&query_builder!(
436+
queries::GENERATE_CHUNKS,
437+
self.splitters_table_name,
438+
self.chunks_table_name,
439+
self.documents_table_name,
440+
self.chunks_table_name
441+
))
442+
.bind(splitter_id)
443+
.execute(self.pool.borrow())
444+
.await?;
457445
Ok(())
458446
}
459447

@@ -492,19 +480,15 @@ impl Collection {
492480
None => serde_json::json!({}),
493481
};
494482

495-
let current_model;
496-
transaction_wrapper!(
497-
current_model,
498-
sqlx::query_as::<_, models::Model>(&query_builder!(
499-
"SELECT * from %s where task = $1 and name = $2 and parameters = $3;",
500-
self.models_table_name
501-
))
502-
.bind(&task)
503-
.bind(&model_name)
504-
.bind(&model_params),
505-
self.pool.borrow(),
506-
fetch_optional
507-
);
483+
let current_model: Option<models::Model> = sqlx::query_as(&query_builder!(
484+
"SELECT * from %s where task = $1 and name = $2 and parameters = $3;",
485+
self.models_table_name
486+
))
487+
.bind(&task)
488+
.bind(&model_name)
489+
.bind(&model_params)
490+
.fetch_optional(self.pool.borrow())
491+
.await?;
508492

509493
match current_model {
510494
Some(model) => {
@@ -515,37 +499,27 @@ impl Collection {
515499
Ok(model.id)
516500
}
517501
None => {
518-
let id;
519-
transaction_wrapper!(
520-
id,
521-
sqlx::query_as::<_, (i64,)>(&query_builder!(
522-
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id",
523-
self.models_table_name
524-
))
525-
.bind(task)
526-
.bind(model_name)
527-
.bind(model_params),
528-
self.pool.borrow(),
529-
fetch_one
530-
);
502+
let id: (i64,) = sqlx::query_as(&query_builder!(
503+
"INSERT INTO %s (task, name, parameters) VALUES ($1, $2, $3) RETURNING id",
504+
self.models_table_name
505+
))
506+
.bind(task)
507+
.bind(model_name)
508+
.bind(model_params)
509+
.fetch_one(self.pool.borrow())
510+
.await?;
531511
Ok(id.0)
532512
}
533513
}
534514
}
535515

536516
/// Gets all registered [models::Model]s
537517
pub async fn get_models(&self) -> anyhow::Result<Vec<models::Model>> {
538-
let models;
539-
transaction_wrapper!(
540-
models,
541-
sqlx::query_as::<_, models::Model>(&query_builder!(
542-
"SELECT * from %s",
543-
self.models_table_name
544-
)),
545-
self.pool.borrow(),
546-
fetch_all
547-
);
548-
Ok(models)
518+
Ok(
519+
sqlx::query_as(&query_builder!("SELECT * from %s", self.models_table_name))
520+
.fetch_all(self.pool.borrow())
521+
.await?,
522+
)
549523
}
550524

551525
async fn create_or_get_embeddings_table(
@@ -554,17 +528,13 @@ impl Collection {
554528
splitter_id: i64,
555529
) -> anyhow::Result<String> {
556530
let pool = self.pool.borrow();
557-
let table_name;
558-
transaction_wrapper!(
559-
table_name,
560-
sqlx::query_as::<_, (String,)>(&query_builder!(
531+
let table_name: Option<(String,)> =
532+
sqlx::query_as(&query_builder!(
561533
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2",
562534
self.transforms_table_name))
563535
.bind(model_id)
564-
.bind(splitter_id),
565-
pool,
566-
fetch_optional
567-
);
536+
.bind(splitter_id)
537+
.fetch_optional(pool).await?;
568538
match table_name {
569539
Some((name,)) => Ok(name),
570540
None => {
@@ -573,12 +543,11 @@ impl Collection {
573543
self.name,
574544
&uuid::Uuid::new_v4().to_string()[0..6]
575545
);
576-
let embedding;
577-
transaction_wrapper!(embedding, sqlx::query_as::<_, (Vec<f32>,)>(&query_builder!(
546+
let embedding: (Vec<f32>,) = sqlx::query_as(&query_builder!(
578547
"WITH model as (SELECT name, parameters from %s where id = $1) SELECT embedding from pgml.embed(transformer => (SELECT name FROM model), text => 'Hello, World!', kwargs => (SELECT parameters FROM model)) as embedding",
579548
self.models_table_name))
580-
.bind(model_id),
581-
pool, fetch_one);
549+
.bind(model_id)
550+
.fetch_one(pool).await?;
582551
let embedding = embedding.0;
583552
let embedding_length = embedding.len() as i64;
584553
pool.execute(
@@ -591,15 +560,13 @@ impl Collection {
591560
.as_str(),
592561
)
593562
.await?;
594-
transaction_wrapper!(
595-
sqlx::query(&query_builder!(
596-
"INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)",
597-
self.transforms_table_name))
598-
.bind(&table_name)
599-
.bind(model_id)
600-
.bind(splitter_id),
601-
pool
602-
);
563+
sqlx::query(&query_builder!(
564+
"INSERT INTO %s (table_name, task, model_id, splitter_id) VALUES ($1, 'embedding', $2, $3)",
565+
self.transforms_table_name))
566+
.bind(&table_name)
567+
.bind(model_id)
568+
.bind(splitter_id)
569+
.execute(pool).await?;
603570
pool.execute(
604571
query_builder!(
605572
queries::CREATE_INDEX,
@@ -677,18 +644,17 @@ impl Collection {
677644
.create_or_get_embeddings_table(model_id, splitter_id)
678645
.await?;
679646

680-
transaction_wrapper!(
681-
sqlx::query(&query_builder!(
682-
queries::GENERATE_EMBEDDINGS,
683-
self.models_table_name,
684-
embeddings_table_name,
685-
self.chunks_table_name,
686-
embeddings_table_name
687-
))
688-
.bind(model_id)
689-
.bind(splitter_id),
690-
self.pool.borrow()
691-
);
647+
sqlx::query(&query_builder!(
648+
queries::GENERATE_EMBEDDINGS,
649+
self.models_table_name,
650+
embeddings_table_name,
651+
self.chunks_table_name,
652+
embeddings_table_name
653+
))
654+
.bind(model_id)
655+
.bind(splitter_id)
656+
.execute(self.pool.borrow())
657+
.await?;
692658

693659
Ok(())
694660
}
@@ -751,17 +717,12 @@ impl Collection {
751717
let model_id = model_id.unwrap_or(1);
752718
let splitter_id = splitter_id.unwrap_or(1);
753719

754-
let embeddings_table_name;
755-
transaction_wrapper!(
756-
embeddings_table_name,
757-
sqlx::query_as::<_, (String,)>(&query_builder!(
720+
let embeddings_table_name: Option<(String,)> = sqlx::query_as(&query_builder!(
758721
"SELECT table_name from %s WHERE task = 'embedding' AND model_id = $1 and splitter_id = $2",
759722
self.transforms_table_name))
760723
.bind(model_id)
761-
.bind(splitter_id),
762-
self.pool.borrow(),
763-
fetch_optional
764-
);
724+
.bind(splitter_id)
725+
.fetch_optional(self.pool.borrow()).await?;
765726

766727
let embeddings_table_name = match embeddings_table_name {
767728
Some((table_name,)) => table_name,
@@ -770,10 +731,8 @@ impl Collection {
770731
}
771732
};
772733

773-
let results;
774-
transaction_wrapper!(
775-
results,
776-
sqlx::query_as::<_, (f64, String, Json<HashMap<String, String>>)>(&query_builder!(
734+
let results: Vec<(f64, String, Json<HashMap<String, String>>)> =
735+
sqlx::query_as(&query_builder!(
777736
queries::VECTOR_SEARCH,
778737
self.models_table_name,
779738
embeddings_table_name,
@@ -784,10 +743,9 @@ impl Collection {
784743
.bind(model_id)
785744
.bind(query)
786745
.bind(query_params)
787-
.bind(top_k),
788-
self.pool.borrow(),
789-
fetch_all
790-
);
746+
.bind(top_k)
747+
.fetch_all(self.pool.borrow())
748+
.await?;
791749
let results: Vec<(f64, String, HashMap<String, String>)> =
792750
results.into_iter().map(|r| (r.0, r.1, r.2 .0)).collect();
793751
Ok(results)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy