From da45f24db08300db0b97b4a5b7b7045e86e8ed13 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Sat, 6 May 2023 09:31:24 -0700 Subject: [PATCH 1/2] 4x speedup --- pgml-extension/src/vectors.rs | 510 ++++++++++++++++------------------ 1 file changed, 234 insertions(+), 276 deletions(-) diff --git a/pgml-extension/src/vectors.rs b/pgml-extension/src/vectors.rs index 5f335a4db..7be94ee33 100644 --- a/pgml-extension/src/vectors.rs +++ b/pgml-extension/src/vectors.rs @@ -1,331 +1,335 @@ use pgrx::*; +use pgrx::array::RawArray; #[pg_extern(immutable, parallel_safe, strict, name = "add")] -fn add_scalar_s(vector: Vec, addend: f32) -> Vec { - vector.as_slice().iter().map(|a| a + addend).collect() +fn add_scalar_s(vector: Array, addend: f32) -> Vec { + vector.iter_deny_null().map(|a| a + addend).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "add")] -fn add_scalar_d(vector: Vec, addend: f64) -> Vec { - vector.as_slice().iter().map(|a| a + addend).collect() +fn add_scalar_d(vector: Array, addend: f64) -> Vec { + vector.iter_deny_null().map(|a| a + addend).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] -fn subtract_scalar_s(vector: Vec, subtahend: f32) -> Vec { - vector.as_slice().iter().map(|a| a - subtahend).collect() +fn subtract_scalar_s(vector: Array, subtahend: f32) -> Vec { + vector.iter_deny_null().map(|a| a - subtahend).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] -fn subtract_scalar_d(vector: Vec, subtahend: f64) -> Vec { - vector.as_slice().iter().map(|a| a - subtahend).collect() +fn subtract_scalar_d(vector: Array, subtahend: f64) -> Vec { + vector.iter_deny_null().map(|a| a - subtahend).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] -fn multiply_scalar_s(vector: Vec, multiplicand: f32) -> Vec { - vector.as_slice().iter().map(|a| a * multiplicand).collect() +fn multiply_scalar_s(vector: Array, multiplicand: f32) -> Vec { + vector.iter_deny_null().map(|a| a * multiplicand).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] -fn multiply_scalar_d(vector: Vec, multiplicand: f64) -> Vec { - vector.as_slice().iter().map(|a| a * multiplicand).collect() +fn multiply_scalar_d(vector: Array, multiplicand: f64) -> Vec { + vector.iter_deny_null().map(|a| a * multiplicand).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "divide")] -fn divide_scalar_s(vector: Vec, dividend: f32) -> Vec { - vector.as_slice().iter().map(|a| a / dividend).collect() +fn divide_scalar_s(vector: Array, dividend: f32) -> Vec { + vector.iter_deny_null().map(|a| a / dividend).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "divide")] -fn divide_scalar_d(vector: Vec, dividend: f64) -> Vec { - vector.as_slice().iter().map(|a| a / dividend).collect() +fn divide_scalar_d(vector: Array, dividend: f64) -> Vec { + vector.iter_deny_null().map(|a| a / dividend).collect() } #[pg_extern(immutable, parallel_safe, strict, name = "add")] -fn add_vector_s(vector: Vec, addend: Vec) -> Vec { - vector - .as_slice() - .iter() - .zip(addend.as_slice().iter()) +fn add_vector_s(vector: Array, addend: Array) -> Vec { + vector.iter_deny_null() + .zip(addend.iter_deny_null()) .map(|(a, b)| a + b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "add")] -fn add_vector_d(vector: Vec, addend: Vec) -> Vec { +fn add_vector_d(vector: Array, addend: Array) -> Vec { vector - .as_slice() - .iter() - .zip(addend.as_slice().iter()) + .iter_deny_null() + .zip(addend.iter_deny_null()) .map(|(a, b)| a + b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] -fn subtract_vector_s(vector: Vec, subtahend: Vec) -> Vec { +fn subtract_vector_s(vector: Array, subtahend: Array) -> Vec { vector - .as_slice() - .iter() - .zip(subtahend.as_slice().iter()) + .iter_deny_null() + .zip(subtahend.iter_deny_null()) .map(|(a, b)| a - b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "subtract")] -fn subtract_vector_d(vector: Vec, subtahend: Vec) -> Vec { +fn subtract_vector_d(vector: Array, subtahend: Array) -> Vec { vector - .as_slice() - .iter() - .zip(subtahend.as_slice().iter()) + .iter_deny_null() + .zip(subtahend.iter_deny_null()) .map(|(a, b)| a - b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] -fn multiply_vector_s(vector: Vec, multiplicand: Vec) -> Vec { +fn multiply_vector_s(vector: Array, multiplicand: Array) -> Vec { vector - .as_slice() - .iter() - .zip(multiplicand.as_slice().iter()) + .iter_deny_null() + .zip(multiplicand.iter_deny_null()) .map(|(a, b)| a * b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "multiply")] -fn multiply_vector_d(vector: Vec, multiplicand: Vec) -> Vec { +fn multiply_vector_d(vector: Array, multiplicand: Array) -> Vec { vector - .as_slice() - .iter() - .zip(multiplicand.as_slice().iter()) + .iter_deny_null() + .zip(multiplicand.iter_deny_null()) .map(|(a, b)| a * b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "divide")] -fn divide_vector_s(vector: Vec, dividend: Vec) -> Vec { +fn divide_vector_s(vector: Array, dividend: Array) -> Vec { vector - .as_slice() - .iter() - .zip(dividend.as_slice().iter()) + .iter_deny_null() + .zip(dividend.iter_deny_null()) .map(|(a, b)| a / b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "divide")] -fn divide_vector_d(vector: Vec, dividend: Vec) -> Vec { +fn divide_vector_d(vector: Array, dividend: Array) -> Vec { vector - .as_slice() - .iter() - .zip(dividend.as_slice().iter()) + .iter_deny_null() + .zip(dividend.iter_deny_null()) .map(|(a, b)| a / b) .collect() } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] -fn norm_l0_s(vector: Vec) -> f32 { +fn norm_l0_s(vector: Array) -> f32 { vector - .as_slice() - .iter() - .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) + .iter_deny_null() + .map(|a| if a == 0.0 { 0.0 } else { 1.0 }) .sum() } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")] -fn norm_l0_d(vector: Vec) -> f64 { +fn norm_l0_d(vector: Array) -> f64 { vector - .as_slice() - .iter() - .map(|a| if *a == 0.0 { 0.0 } else { 1.0 }) + .iter_deny_null() + .map(|a| if a == 0.0 { 0.0 } else { 1.0 }) .sum() } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] -fn norm_l1_s(vector: Vec) -> f32 { - unsafe { blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +fn norm_l1_s(vector: Array) -> f32 { + unsafe { + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + blas::sasum(vector.len().try_into().unwrap(), vector, 1) + } } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")] -fn norm_l1_d(vector: Vec) -> f64 { - unsafe { blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +fn norm_l1_d(vector: Array) -> f64 { + unsafe { + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + blas::dasum(vector.len().try_into().unwrap(), vector, 1) + } } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] -fn norm_l2_s(vector: Vec) -> f32 { - unsafe { blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +fn norm_l2_s(vector: Array) -> f32 { + unsafe { + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + blas::snrm2(vector.len().try_into().unwrap(), vector, 1) + } } #[pg_extern(immutable, parallel_safe, strict, name = "norm_l2")] -fn norm_l2_d(vector: Vec) -> f64 { - unsafe { blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1) } +fn norm_l2_d(vector: Array) -> f64 { + unsafe { + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + blas::dnrm2(vector.len().try_into().unwrap(), vector, 1) + } } #[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] -fn norm_max_s(vector: Vec) -> f32 { +fn norm_max_s(vector: Array) -> f32 { unsafe { - let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + let index = blas::isamax(vector.len().try_into().unwrap(), vector, 1); vector[index - 1].abs() } } #[pg_extern(immutable, parallel_safe, strict, name = "norm_max")] -fn norm_max_d(vector: Vec) -> f64 { +fn norm_max_d(vector: Array) -> f64 { unsafe { - let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + let index = blas::idamax(vector.len().try_into().unwrap(), vector, 1); vector[index - 1].abs() } } #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] -fn normalize_l1_s(vector: Vec) -> Vec { +fn normalize_l1_s(vector: Array) -> Vec { let norm: f32; unsafe { - norm = blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + norm = blas::sasum(vector.len().try_into().unwrap(), vector, 1); + vector.iter().map(|a| a / norm).collect() } - divide_scalar_s(vector, norm) } #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l1")] -fn normalize_l1_d(vector: Vec) -> Vec { +fn normalize_l1_d(vector: Array) -> Vec { let norm: f64; unsafe { - norm = blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + norm = blas::dasum(vector.len().try_into().unwrap(), vector, 1); + vector.iter().map(|a| a / norm).collect() } - divide_scalar_d(vector, norm) } #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] -fn normalize_l2_s(vector: Vec) -> Vec { +fn normalize_l2_s(vector: Array) -> Vec { let norm: f32; unsafe { - norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + norm = blas::snrm2(vector.len().try_into().unwrap(), vector, 1); + vector.iter().map(|a| a / norm).collect() } - divide_scalar_s(vector, norm) } #[pg_extern(immutable, parallel_safe, strict, name = "normalize_l2")] -fn normalize_l2_d(vector: Vec) -> Vec { +fn normalize_l2_d(vector: Array) -> Vec { let norm: f64; unsafe { - norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + norm = blas::dnrm2(vector.len().try_into().unwrap(), vector, 1); + vector.iter().map(|a| a / norm).collect() } - divide_scalar_d(vector, norm) } #[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] -fn normalize_max_s(vector: Vec) -> Vec { +fn normalize_max_s(vector: Array) -> Vec { let norm; unsafe { - let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + let index = blas::isamax(vector.len().try_into().unwrap(), vector, 1); norm = vector[index - 1].abs(); + vector.iter().map(|a| a / norm).collect() } - divide_scalar_s(vector, norm) } #[pg_extern(immutable, parallel_safe, strict, name = "normalize_max")] -fn normalize_max_d(vector: Vec) -> Vec { +fn normalize_max_d(vector: Array) -> Vec { let norm; unsafe { - let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(), 1); + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + let index = blas::idamax(vector.len().try_into().unwrap(), vector, 1); norm = vector[index - 1].abs(); + vector.iter().map(|a| a / norm).collect() } - divide_scalar_d(vector, norm) } #[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] -fn distance_l1_s(vector: Vec, other: Vec) -> f32 { +fn distance_l1_s(vector: Array, other: Array) -> f32 { vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) + .iter_deny_null() + .zip(other.iter_deny_null()) .map(|(a, b)| (a - b).abs()) .sum() } #[pg_extern(immutable, parallel_safe, strict, name = "distance_l1")] -fn distance_l1_d(vector: Vec, other: Vec) -> f64 { +fn distance_l1_d(vector: Array, other: Array) -> f64 { vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) + .iter_deny_null() + .zip(other.iter_deny_null()) .map(|(a, b)| (a - b).abs()) .sum() } #[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] -fn distance_l2_s(vector: Vec, other: Vec) -> f32 { +fn distance_l2_s(vector: Array, other: Array) -> f32 { vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) + .iter_deny_null() + .zip(other.iter_deny_null()) .map(|(a, b)| (a - b).powf(2.0)) .sum::() .sqrt() } #[pg_extern(immutable, parallel_safe, strict, name = "distance_l2")] -fn distance_l2_d(vector: Vec, other: Vec) -> f64 { +fn distance_l2_d(vector: Array, other: Array) -> f64 { vector - .as_slice() - .iter() - .zip(other.as_slice().iter()) + .iter_deny_null() + .zip(other.iter_deny_null()) .map(|(a, b)| (a - b).powf(2.0)) .sum::() .sqrt() } #[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] -fn dot_product_s(vector: Vec, other: Vec) -> f32 { +fn dot_product_s(vector: Array, other: Array) -> f32 { unsafe { + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + let other: &[f32] = RawArray::from_array(other).unwrap().data().as_ref(); blas::sdot( vector.len().try_into().unwrap(), - vector.as_slice(), + vector, 1, - other.as_slice(), + other, 1, ) } } #[pg_extern(immutable, parallel_safe, strict, name = "dot_product")] -fn dot_product_d(vector: Vec, other: Vec) -> f64 { +fn dot_product_d(vector: Array, other: Array) -> f64 { unsafe { + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + let other: &[f64] = RawArray::from_array(other).unwrap().data().as_ref(); blas::ddot( vector.len().try_into().unwrap(), - vector.as_slice(), + vector, 1, - other.as_slice(), + other, 1, ) } } #[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] -fn cosine_similarity_s(vector: Vec, other: Vec) -> f32 { +fn cosine_similarity_s(vector: Array, other: Array) -> f32 { unsafe { - let dot = blas::sdot( - vector.len().try_into().unwrap(), - vector.as_slice(), - 1, - other.as_slice(), - 1, - ); - let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); - let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(), 1); + let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); + let other: &[f32] = RawArray::from_array(other).unwrap().data().as_ref(); + let len = vector.len() as i32; + let dot = blas::sdot(len, vector, 1, other, 1); + let a_norm = blas::snrm2(len, vector, 1); + let b_norm = blas::snrm2(len, other, 1); dot / (a_norm * b_norm) } } #[pg_extern(immutable, parallel_safe, strict, name = "cosine_similarity")] -fn cosine_similarity_d(vector: Vec, other: Vec) -> f64 { +fn cosine_similarity_d(vector: Array, other: Array) -> f64 { unsafe { - let dot = blas::ddot( - vector.len().try_into().unwrap(), - vector.as_slice(), - 1, - other.as_slice(), - 1, - ); - let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(), 1); - let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(), 1); + let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); + let other: &[f64] = RawArray::from_array(other).unwrap().data().as_ref(); + let len = vector.len() as i32; + let dot = blas::ddot(len, vector, 1, other, 1); + let a_norm = blas::dnrm2(len, vector, 1); + let b_norm = blas::dnrm2(len, other, 1); dot / (a_norm * b_norm) } } @@ -341,7 +345,7 @@ impl Aggregate for SumS { type Finalize = Vec; #[pgrx(immutable, parallel_safe)] - fn state( + fn state<'a>( mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo, @@ -999,299 +1003,253 @@ mod tests { #[pg_test] fn test_add_scalar_s() { - assert_eq!( - add_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), - [2.0, 3.0, 4.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float4[], 1)"); + assert_eq!(result, Ok(Some([2.0, 3.0, 4.0].to_vec()))); } #[pg_test] fn test_add_scalar_d() { - assert_eq!( - add_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), - [2.0, 3.0, 4.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float8[], 1)"); + assert_eq!(result, Ok(Some([2.0, 3.0, 4.0].to_vec()))); } #[pg_test] fn test_subtract_scalar_s() { - assert_eq!( - subtract_scalar_s([1.0, 2.0, 3.0].to_vec(), 1.0), - [0.0, 1.0, 2.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], 1)"); + assert_eq!(result, Ok(Some([0.0, 1.0, 2.0].to_vec()))); } #[pg_test] fn test_subtract_scalar_d() { - assert_eq!( - subtract_scalar_d([1.0, 2.0, 3.0].to_vec(), 1.0), - [0.0, 1.0, 2.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float8[], 1)"); + assert_eq!(result, Ok(Some([0.0, 1.0, 2.0].to_vec()))); } #[pg_test] fn test_multiply_scalar_s() { - assert_eq!( - multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), - [2.0, 4.0, 6.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.multiply(ARRAY[1,2,3]::float4[], 2)"); + assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_multiply_scalar_d() { - assert_eq!( - multiply_scalar_d([1.0, 2.0, 3.0].to_vec(), 2.0), - [2.0, 4.0, 6.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.multiply(ARRAY[1,2,3]::float8[], 2)"); + assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_divide_scalar_s() { - assert_eq!( - divide_scalar_s([2.0, 4.0, 6.0].to_vec(), 2.0), - [1.0, 2.0, 3.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float4[], 10)"); + assert_eq!(result, Ok(Some([0.1, 0.2, 0.3].to_vec()))); } #[pg_test] fn test_divide_scalar_d() { - assert_eq!( - divide_scalar_d([2.0, 4.0, 6.0].to_vec(), 2.0), - [1.0, 2.0, 3.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float8[], 10)"); + assert_eq!(result, Ok(Some([0.1, 0.2, 0.3].to_vec()))); } #[pg_test] fn test_add_vector_s() { - assert_eq!( - add_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [2.0, 4.0, 6.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_add_vector_d() { - assert_eq!( - add_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [2.0, 4.0, 6.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_subtract_vector_s() { - assert_eq!( - subtract_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [0.0, 0.0, 0.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_subtract_vector_d() { - assert_eq!( - subtract_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [0.0, 0.0, 0.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_multiply_vector_s() { - assert_eq!( - multiply_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 4.0, 9.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_multiply_vector_d() { - assert_eq!( - multiply_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 4.0, 9.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.multiply(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + assert_eq!(result, Ok(Some([1.0, 4.0, 9.0].to_vec()))); } #[pg_test] fn test_divide_vector_s() { - assert_eq!( - divide_vector_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 1.0, 1.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec()))); } #[pg_test] fn test_divide_vector_d() { - assert_eq!( - divide_vector_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - [1.0, 1.0, 1.0].to_vec() - ) + let result = Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec()))); } #[pg_test] fn test_norm_l0_s() { - assert_eq!(norm_l0_s([1.0, 2.0, 3.0].to_vec()), 3.0) + let result = Spi::get_one::("SELECT pgml.norm_l0(ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some(3.0))); } #[pg_test] fn test_norm_l0_d() { - assert_eq!(norm_l0_d([1.0, 2.0, 3.0].to_vec()), 3.0) + let result = Spi::get_one::("SELECT pgml.norm_l0(ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some(3.0))); } #[pg_test] fn test_norm_l1_s() { - assert_eq!(norm_l1_s([1.0, 2.0, 3.0].to_vec()), 6.0) + let result = Spi::get_one::("SELECT pgml.norm_l1(ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some(6.0))); } #[pg_test] fn test_norm_l1_d() { - assert_eq!(norm_l1_d([1.0, 2.0, 3.0].to_vec()), 6.0) + let result = Spi::get_one::("SELECT pgml.norm_l1(ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some(6.0))); } #[pg_test] fn test_norm_l2_s() { - assert_eq!(norm_l2_s([1.0, 2.0, 3.0].to_vec()), 3.7416575); + let result = Spi::get_one::("SELECT pgml.norm_l2(ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some(3.7416575))); } #[pg_test] fn test_norm_l2_d() { - assert_eq!(norm_l2_d([1.0, 2.0, 3.0].to_vec()), 3.7416573867739413); + let result = Spi::get_one::("SELECT pgml.norm_l2(ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some(3.7416573867739413))); } #[pg_test] fn test_norm_max_s() { - assert_eq!(norm_max_s([1.0, 2.0, 3.0].to_vec()), 3.0); - assert_eq!(norm_max_s([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); + let result = Spi::get_one::("SELECT pgml.norm_max(ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some(3.0))); + + let result = Spi::get_one::("SELECT pgml.norm_max(ARRAY[1,2,3,-4]::float4[])"); + assert_eq!(result, Ok(Some(4.0))); } #[pg_test] fn test_norm_max_d() { - assert_eq!(norm_max_d([1.0, 2.0, 3.0].to_vec()), 3.0); - assert_eq!(norm_max_d([1.0, 2.0, 3.0, -4.0].to_vec()), 4.0); + let result = Spi::get_one::("SELECT pgml.norm_max(ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some(3.0))); + + let result = Spi::get_one::("SELECT pgml.norm_max(ARRAY[1,2,3,-4]::float8[])"); + assert_eq!(result, Ok(Some(4.0))); } #[pg_test] fn test_normalize_l1_s() { - assert_eq!( - normalize_l1_s([1.0, 2.0, 3.0].to_vec()), - [0.16666667, 0.33333334, 0.5].to_vec() - ); + let result = Spi::get_one::>("SELECT pgml.normalize_l1(ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some([0.16666667, 0.33333334, 0.5].to_vec()))); } #[pg_test] fn test_normalize_l1_d() { - assert_eq!( - normalize_l1_d([1.0, 2.0, 3.0].to_vec()), - [0.16666666666666666, 0.3333333333333333, 0.5].to_vec() - ); + let result = Spi::get_one::>("SELECT pgml.normalize_l1(ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some([0.16666666666666666, 0.3333333333333333, 0.5].to_vec()))); } #[pg_test] fn test_normalize_l2_s() { - assert_eq!( - normalize_l2_s([1.0, 2.0, 3.0].to_vec()), - [0.26726124, 0.5345225, 0.8017837].to_vec() - ); + let result = Spi::get_one::>("SELECT pgml.normalize_l2(ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some([0.26726124, 0.5345225, 0.8017837].to_vec()))); } #[pg_test] fn test_normalize_l2_d() { - assert_eq!( - normalize_l2_d([1.0, 2.0, 3.0].to_vec()), - [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec() - ); + let result = Spi::get_one::>("SELECT pgml.normalize_l2(ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some([0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec()))); } #[pg_test] fn test_normalize_max_s() { - assert_eq!( - normalize_max_s([1.0, 2.0, 3.0].to_vec()), - [0.33333334, 0.6666667, 1.0].to_vec() - ); + let result = Spi::get_one::>("SELECT pgml.normalize_max(ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some([0.33333334, 0.6666667, 1.0].to_vec()))); } #[pg_test] fn test_normalize_max_d() { - assert_eq!( - normalize_max_d([1.0, 2.0, 3.0].to_vec()), - [0.3333333333333333, 0.6666666666666666, 1.0].to_vec() - ); + let result = Spi::get_one::>("SELECT pgml.normalize_max(ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some([0.3333333333333333, 0.6666666666666666, 1.0].to_vec()))); } #[pg_test] fn test_distance_l1_s() { - assert_eq!( - distance_l1_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); + let result = Spi::get_one::("SELECT pgml.distance_l1(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l1_d() { - assert_eq!( - distance_l1_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); + let result = Spi::get_one::("SELECT pgml.distance_l1(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l2_s() { - assert_eq!( - distance_l2_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); + let result = Spi::get_one::("SELECT pgml.distance_l2(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l2_d() { - assert_eq!( - distance_l2_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.0 - ); + let result = Spi::get_one::("SELECT pgml.distance_l2(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_dot_product_s() { - assert_eq!( - dot_product_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 14.0 - ); - assert_eq!( - dot_product_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 20.0 - ); + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); + assert_eq!(result, Ok(Some(14.0))); + + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[2,3,4]::float4[])"); + assert_eq!(result, Ok(Some(20.0))); } #[pg_test] fn test_dot_product_d() { - assert_eq!( - dot_product_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 14.0 - ); - assert_eq!( - dot_product_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 20.0 - ); + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); + assert_eq!(result, Ok(Some(14.0))); + + let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[2,3,4]::float8[])"); + assert_eq!(result, Ok(Some(20.0))); } #[pg_test] fn test_cosine_similarity_s() { - assert_eq!( - cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 0.99999994 - ); - assert_eq!( - cosine_similarity_s([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 0.9925833 - ); + let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + assert_eq!(result, Ok(Some(0.99999994))); + + let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float4[], ARRAY[2.0, 3.0, 4.0]::float4[])"); + assert_eq!(result, Ok(Some(0.9925833))); + + let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float4[], ARRAY[0,0,1,1,0,1,1]::float4[])"); + assert_eq!(result, Ok(Some(0.4472136))); } #[pg_test] fn test_cosine_similarity_d() { - assert_eq!( - cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [1.0, 2.0, 3.0].to_vec()), - 1.0 - ); - assert_eq!( - cosine_similarity_d([1.0, 2.0, 3.0].to_vec(), [2.0, 3.0, 4.0].to_vec()), - 0.9925833339709303 - ); + let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + assert_eq!(result, Ok(Some(1.0))); + + let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float8[], ARRAY[2.0, 3.0, 4.0]::float8[])"); + assert_eq!(result, Ok(Some(0.9925833339709303))); + + let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float8[], ARRAY[0,0,1,1,0,1,1]::float8[])"); + assert_eq!(result, Ok(Some(0.4472135954999579))); } } From c5f7751bfe0ae70d2b2de2a3e0d3a553e0b480be Mon Sep 17 00:00:00 2001 From: Montana Low Date: Sat, 6 May 2023 15:24:21 -0700 Subject: [PATCH 2/2] fmt --- pgml-extension/src/vectors.rs | 125 +++++++++++++++++++++++----------- 1 file changed, 85 insertions(+), 40 deletions(-) diff --git a/pgml-extension/src/vectors.rs b/pgml-extension/src/vectors.rs index 7be94ee33..b7cafb461 100644 --- a/pgml-extension/src/vectors.rs +++ b/pgml-extension/src/vectors.rs @@ -1,5 +1,5 @@ -use pgrx::*; use pgrx::array::RawArray; +use pgrx::*; #[pg_extern(immutable, parallel_safe, strict, name = "add")] fn add_scalar_s(vector: Array, addend: f32) -> Vec { @@ -43,7 +43,8 @@ fn divide_scalar_d(vector: Array, dividend: f64) -> Vec { #[pg_extern(immutable, parallel_safe, strict, name = "add")] fn add_vector_s(vector: Array, addend: Array) -> Vec { - vector.iter_deny_null() + vector + .iter_deny_null() .zip(addend.iter_deny_null()) .map(|(a, b)| a + b) .collect() @@ -283,13 +284,7 @@ fn dot_product_s(vector: Array, other: Array) -> f32 { unsafe { let vector: &[f32] = RawArray::from_array(vector).unwrap().data().as_ref(); let other: &[f32] = RawArray::from_array(other).unwrap().data().as_ref(); - blas::sdot( - vector.len().try_into().unwrap(), - vector, - 1, - other, - 1, - ) + blas::sdot(vector.len().try_into().unwrap(), vector, 1, other, 1) } } @@ -298,13 +293,7 @@ fn dot_product_d(vector: Array, other: Array) -> f64 { unsafe { let vector: &[f64] = RawArray::from_array(vector).unwrap().data().as_ref(); let other: &[f64] = RawArray::from_array(other).unwrap().data().as_ref(); - blas::ddot( - vector.len().try_into().unwrap(), - vector, - 1, - other, - 1, - ) + blas::ddot(vector.len().try_into().unwrap(), vector, 1, other, 1) } } @@ -1051,49 +1040,65 @@ mod tests { #[pg_test] fn test_add_vector_s() { - let result = Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + let result = Spi::get_one::>( + "SELECT pgml.add(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", + ); assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_add_vector_d() { - let result = Spi::get_one::>("SELECT pgml.add(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + let result = Spi::get_one::>( + "SELECT pgml.add(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", + ); assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec()))); } #[pg_test] fn test_subtract_vector_s() { - let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + let result = Spi::get_one::>( + "SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", + ); assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_subtract_vector_d() { - let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + let result = Spi::get_one::>( + "SELECT pgml.subtract(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", + ); assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_multiply_vector_s() { - let result = Spi::get_one::>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + let result = Spi::get_one::>( + "SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", + ); assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec()))); } #[pg_test] fn test_multiply_vector_d() { - let result = Spi::get_one::>("SELECT pgml.multiply(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + let result = Spi::get_one::>( + "SELECT pgml.multiply(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", + ); assert_eq!(result, Ok(Some([1.0, 4.0, 9.0].to_vec()))); } #[pg_test] fn test_divide_vector_s() { - let result = Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + let result = Spi::get_one::>( + "SELECT pgml.divide(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", + ); assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec()))); } #[pg_test] fn test_divide_vector_d() { - let result = Spi::get_one::>("SELECT pgml.divide(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + let result = Spi::get_one::>( + "SELECT pgml.divide(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", + ); assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec()))); } @@ -1160,19 +1165,32 @@ mod tests { #[pg_test] fn test_normalize_l1_d() { let result = Spi::get_one::>("SELECT pgml.normalize_l1(ARRAY[1,2,3]::float8[])"); - assert_eq!(result, Ok(Some([0.16666666666666666, 0.3333333333333333, 0.5].to_vec()))); + assert_eq!( + result, + Ok(Some( + [0.16666666666666666, 0.3333333333333333, 0.5].to_vec() + )) + ); } #[pg_test] fn test_normalize_l2_s() { let result = Spi::get_one::>("SELECT pgml.normalize_l2(ARRAY[1,2,3]::float4[])"); - assert_eq!(result, Ok(Some([0.26726124, 0.5345225, 0.8017837].to_vec()))); + assert_eq!( + result, + Ok(Some([0.26726124, 0.5345225, 0.8017837].to_vec())) + ); } #[pg_test] fn test_normalize_l2_d() { let result = Spi::get_one::>("SELECT pgml.normalize_l2(ARRAY[1,2,3]::float8[])"); - assert_eq!(result, Ok(Some([0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec()))); + assert_eq!( + result, + Ok(Some( + [0.2672612419124244, 0.5345224838248488, 0.8017837257372732].to_vec() + )) + ); } #[pg_test] @@ -1184,57 +1202,80 @@ mod tests { #[pg_test] fn test_normalize_max_d() { let result = Spi::get_one::>("SELECT pgml.normalize_max(ARRAY[1,2,3]::float8[])"); - assert_eq!(result, Ok(Some([0.3333333333333333, 0.6666666666666666, 1.0].to_vec()))); + assert_eq!( + result, + Ok(Some([0.3333333333333333, 0.6666666666666666, 1.0].to_vec())) + ); } #[pg_test] fn test_distance_l1_s() { - let result = Spi::get_one::("SELECT pgml.distance_l1(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); + let result = Spi::get_one::( + "SELECT pgml.distance_l1(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])", + ); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l1_d() { - let result = Spi::get_one::("SELECT pgml.distance_l1(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); + let result = Spi::get_one::( + "SELECT pgml.distance_l1(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])", + ); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l2_s() { - let result = Spi::get_one::("SELECT pgml.distance_l2(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); + let result = Spi::get_one::( + "SELECT pgml.distance_l2(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])", + ); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_distance_l2_d() { - let result = Spi::get_one::("SELECT pgml.distance_l2(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); + let result = Spi::get_one::( + "SELECT pgml.distance_l2(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])", + ); assert_eq!(result, Ok(Some(0.0))); } #[pg_test] fn test_dot_product_s() { - let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])"); + let result = Spi::get_one::( + "SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])", + ); assert_eq!(result, Ok(Some(14.0))); - let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[2,3,4]::float4[])"); + let result = Spi::get_one::( + "SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[2,3,4]::float4[])", + ); assert_eq!(result, Ok(Some(20.0))); } #[pg_test] fn test_dot_product_d() { - let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])"); + let result = Spi::get_one::( + "SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])", + ); assert_eq!(result, Ok(Some(14.0))); - let result = Spi::get_one::("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[2,3,4]::float8[])"); + let result = Spi::get_one::( + "SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[2,3,4]::float8[])", + ); assert_eq!(result, Ok(Some(20.0))); } #[pg_test] fn test_cosine_similarity_s() { - let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])"); + let result = Spi::get_one::( + "SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])", + ); assert_eq!(result, Ok(Some(0.99999994))); - let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float4[], ARRAY[2.0, 3.0, 4.0]::float4[])"); + let result = Spi::get_one::( + "SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float4[], ARRAY[2.0, 3.0, 4.0]::float4[])", + ); assert_eq!(result, Ok(Some(0.9925833))); let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float4[], ARRAY[0,0,1,1,0,1,1]::float4[])"); @@ -1243,10 +1284,14 @@ mod tests { #[pg_test] fn test_cosine_similarity_d() { - let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])"); + let result = Spi::get_one::( + "SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])", + ); assert_eq!(result, Ok(Some(1.0))); - let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float8[], ARRAY[2.0, 3.0, 4.0]::float8[])"); + let result = Spi::get_one::( + "SELECT pgml.cosine_similarity(ARRAY[1,2,3]::float8[], ARRAY[2.0, 3.0, 4.0]::float8[])", + ); assert_eq!(result, Ok(Some(0.9925833339709303))); let result = Spi::get_one::("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float8[], ARRAY[0,0,1,1,0,1,1]::float8[])"); pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy