Skip to content

Commit 3fb05eb

Browse files
ankaneNgalstyan4dqii
committed
Added casts for arrays to sparsevec - pgvector#604
Co-authored-by: Narek Galstyan <narekg@berkeley.edu> Co-authored-by: Di Qi <di@lantern.dev>
1 parent b738ffe commit 3fb05eb

File tree

6 files changed

+245
-0
lines changed

6 files changed

+245
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
## 0.8.0 (unreleased)
22

3+
- Added casts for arrays to `sparsevec`
34
- Reduced memory usage for HNSW index scans
45
- Dropped support for Postgres 12
56

sql/vector--0.7.4--0.8.0.sql

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
-- complain if script is sourced in psql, rather than via CREATE EXTENSION
2+
\echo Use "ALTER EXTENSION vector UPDATE TO '0.8.0'" to load this file. \quit
3+
4+
CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec
5+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
6+
7+
CREATE FUNCTION array_to_sparsevec(real[], integer, boolean) RETURNS sparsevec
8+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
9+
10+
CREATE FUNCTION array_to_sparsevec(double precision[], integer, boolean) RETURNS sparsevec
11+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
12+
13+
CREATE FUNCTION array_to_sparsevec(numeric[], integer, boolean) RETURNS sparsevec
14+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
15+
16+
CREATE CAST (integer[] AS sparsevec)
17+
WITH FUNCTION array_to_sparsevec(integer[], integer, boolean) AS ASSIGNMENT;
18+
19+
CREATE CAST (real[] AS sparsevec)
20+
WITH FUNCTION array_to_sparsevec(real[], integer, boolean) AS ASSIGNMENT;
21+
22+
CREATE CAST (double precision[] AS sparsevec)
23+
WITH FUNCTION array_to_sparsevec(double precision[], integer, boolean) AS ASSIGNMENT;
24+
25+
CREATE CAST (numeric[] AS sparsevec)
26+
WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT;

sql/vector.sql

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,18 @@ CREATE FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) RETURNS sparseve
782782
CREATE FUNCTION sparsevec_to_halfvec(sparsevec, integer, boolean) RETURNS halfvec
783783
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
784784

785+
CREATE FUNCTION array_to_sparsevec(integer[], integer, boolean) RETURNS sparsevec
786+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
787+
788+
CREATE FUNCTION array_to_sparsevec(real[], integer, boolean) RETURNS sparsevec
789+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
790+
791+
CREATE FUNCTION array_to_sparsevec(double precision[], integer, boolean) RETURNS sparsevec
792+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
793+
794+
CREATE FUNCTION array_to_sparsevec(numeric[], integer, boolean) RETURNS sparsevec
795+
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
796+
785797
-- sparsevec casts
786798

787799
CREATE CAST (sparsevec AS sparsevec)
@@ -799,6 +811,18 @@ CREATE CAST (sparsevec AS halfvec)
799811
CREATE CAST (halfvec AS sparsevec)
800812
WITH FUNCTION halfvec_to_sparsevec(halfvec, integer, boolean) AS IMPLICIT;
801813

814+
CREATE CAST (integer[] AS sparsevec)
815+
WITH FUNCTION array_to_sparsevec(integer[], integer, boolean) AS ASSIGNMENT;
816+
817+
CREATE CAST (real[] AS sparsevec)
818+
WITH FUNCTION array_to_sparsevec(real[], integer, boolean) AS ASSIGNMENT;
819+
820+
CREATE CAST (double precision[] AS sparsevec)
821+
WITH FUNCTION array_to_sparsevec(double precision[], integer, boolean) AS ASSIGNMENT;
822+
823+
CREATE CAST (numeric[] AS sparsevec)
824+
WITH FUNCTION array_to_sparsevec(numeric[], integer, boolean) AS ASSIGNMENT;
825+
802826
-- sparsevec operators
803827

804828
CREATE OPERATOR <-> (

src/sparsevec.c

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <limits.h>
44
#include <math.h>
55

6+
#include "catalog/pg_type.h"
67
#include "common/string.h"
78
#include "fmgr.h"
89
#include "halfutils.h"
@@ -11,6 +12,7 @@
1112
#include "sparsevec.h"
1213
#include "utils/array.h"
1314
#include "utils/builtins.h"
15+
#include "utils/lsyscache.h"
1416
#include "vector.h"
1517

1618
#if PG_VERSION_NUM >= 120000
@@ -670,6 +672,126 @@ halfvec_to_sparsevec(PG_FUNCTION_ARGS)
670672
PG_RETURN_POINTER(result);
671673
}
672674

675+
/*
676+
* Convert array to sparse vector
677+
*/
678+
FUNCTION_PREFIX PG_FUNCTION_INFO_V1(array_to_sparsevec);
679+
Datum
680+
array_to_sparsevec(PG_FUNCTION_ARGS)
681+
{
682+
ArrayType *array = PG_GETARG_ARRAYTYPE_P(0);
683+
int32 typmod = PG_GETARG_INT32(1);
684+
SparseVector *result;
685+
int16 typlen;
686+
bool typbyval;
687+
char typalign;
688+
Datum *elemsp;
689+
int nelemsp;
690+
int nnz = 0;
691+
float *values;
692+
int j = 0;
693+
694+
if (ARR_NDIM(array) > 1)
695+
ereport(ERROR,
696+
(errcode(ERRCODE_DATA_EXCEPTION),
697+
errmsg("array must be 1-D")));
698+
699+
if (ARR_HASNULL(array) && array_contains_nulls(array))
700+
ereport(ERROR,
701+
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
702+
errmsg("array must not contain nulls")));
703+
704+
get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
705+
deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp);
706+
707+
CheckDim(nelemsp);
708+
CheckExpectedDim(typmod, nelemsp);
709+
710+
if (ARR_ELEMTYPE(array) == INT4OID)
711+
{
712+
for (int i = 0; i < nelemsp; i++)
713+
nnz += ((float) DatumGetInt32(elemsp[i])) != 0;
714+
}
715+
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
716+
{
717+
for (int i = 0; i < nelemsp; i++)
718+
nnz += ((float) DatumGetFloat8(elemsp[i])) != 0;
719+
}
720+
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
721+
{
722+
for (int i = 0; i < nelemsp; i++)
723+
nnz += (DatumGetFloat4(elemsp[i]) != 0);
724+
}
725+
else if (ARR_ELEMTYPE(array) == NUMERICOID)
726+
{
727+
for (int i = 0; i < nelemsp; i++)
728+
nnz += (DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])) != 0);
729+
}
730+
else
731+
{
732+
ereport(ERROR,
733+
(errcode(ERRCODE_DATA_EXCEPTION),
734+
errmsg("unsupported array type")));
735+
}
736+
737+
result = InitSparseVector(nelemsp, nnz);
738+
values = SPARSEVEC_VALUES(result);
739+
740+
#define PROCESS_ARRAY_ELEM(elem) \
741+
do { \
742+
float v = (float) (elem); \
743+
if (v != 0) { \
744+
/* Safety check */ \
745+
if (j >= result->nnz) \
746+
elog(ERROR, "safety check failed"); \
747+
result->indices[j] = i; \
748+
values[j] = v; \
749+
j++; \
750+
} \
751+
} while (0)
752+
753+
if (ARR_ELEMTYPE(array) == INT4OID)
754+
{
755+
for (int i = 0; i < nelemsp; i++)
756+
PROCESS_ARRAY_ELEM(DatumGetInt32(elemsp[i]));
757+
}
758+
else if (ARR_ELEMTYPE(array) == FLOAT8OID)
759+
{
760+
for (int i = 0; i < nelemsp; i++)
761+
PROCESS_ARRAY_ELEM(DatumGetFloat8(elemsp[i]));
762+
}
763+
else if (ARR_ELEMTYPE(array) == FLOAT4OID)
764+
{
765+
for (int i = 0; i < nelemsp; i++)
766+
PROCESS_ARRAY_ELEM(DatumGetFloat4(elemsp[i]));
767+
}
768+
else if (ARR_ELEMTYPE(array) == NUMERICOID)
769+
{
770+
for (int i = 0; i < nelemsp; i++)
771+
PROCESS_ARRAY_ELEM(DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])));
772+
}
773+
else
774+
{
775+
ereport(ERROR,
776+
(errcode(ERRCODE_DATA_EXCEPTION),
777+
errmsg("unsupported array type")));
778+
}
779+
780+
#undef PROCESS_ARRAY_ELEM
781+
782+
/*
783+
* Free allocation from deconstruct_array. Do not free individual elements
784+
* when pass-by-reference since they point to original array.
785+
*/
786+
pfree(elemsp);
787+
788+
/* Check elements */
789+
for (int i = 0; i < result->nnz; i++)
790+
CheckElement(values[i]);
791+
792+
PG_RETURN_POINTER(result);
793+
}
794+
673795
/*
674796
* Get the L2 squared distance between sparse vectors
675797
*/

test/expected/cast.out

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,62 @@ SELECT '{1:1e-8}/1'::sparsevec::halfvec;
208208
[0]
209209
(1 row)
210210

211+
SELECT ARRAY[1,0,2,0,3,0]::sparsevec;
212+
array
213+
-----------------
214+
{1:1,3:2,5:3}/6
215+
(1 row)
216+
217+
SELECT ARRAY[1.0,0.0,2.0,0.0,3.0,0.0]::sparsevec;
218+
array
219+
-----------------
220+
{1:1,3:2,5:3}/6
221+
(1 row)
222+
223+
SELECT ARRAY[1,0,2,0,3,0]::float4[]::sparsevec;
224+
array
225+
-----------------
226+
{1:1,3:2,5:3}/6
227+
(1 row)
228+
229+
SELECT ARRAY[1,0,2,0,3,0]::float8[]::sparsevec;
230+
array
231+
-----------------
232+
{1:1,3:2,5:3}/6
233+
(1 row)
234+
235+
SELECT ARRAY[1,0,2,0,3,0]::numeric[]::sparsevec;
236+
array
237+
-----------------
238+
{1:1,3:2,5:3}/6
239+
(1 row)
240+
241+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec;
242+
sparsevec
243+
-----------------
244+
{1:1,3:2,5:3}/6
245+
(1 row)
246+
247+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(6);
248+
sparsevec
249+
-----------------
250+
{1:1,3:2,5:3}/6
251+
(1 row)
252+
253+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(5);
254+
ERROR: expected 5 dimensions, not 6
255+
SELECT '{NULL}'::real[]::sparsevec;
256+
ERROR: array must not contain nulls
257+
SELECT '{NaN}'::real[]::sparsevec;
258+
ERROR: NaN not allowed in sparsevec
259+
SELECT '{Infinity}'::real[]::sparsevec;
260+
ERROR: infinite value not allowed in sparsevec
261+
SELECT '{-Infinity}'::real[]::sparsevec;
262+
ERROR: infinite value not allowed in sparsevec
263+
SELECT '{}'::real[]::sparsevec;
264+
ERROR: sparsevec must have at least 1 dimension
265+
SELECT '{{1}}'::real[]::sparsevec;
266+
ERROR: array must be 1-D
211267
SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
212268
ERROR: vector cannot have more than 16000 dimensions
213269
SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;

test/sql/cast.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,22 @@ SELECT '{}/16001'::sparsevec::halfvec;
5858
SELECT '{1:65520}/1'::sparsevec::halfvec;
5959
SELECT '{1:1e-8}/1'::sparsevec::halfvec;
6060

61+
SELECT ARRAY[1,0,2,0,3,0]::sparsevec;
62+
SELECT ARRAY[1.0,0.0,2.0,0.0,3.0,0.0]::sparsevec;
63+
SELECT ARRAY[1,0,2,0,3,0]::float4[]::sparsevec;
64+
SELECT ARRAY[1,0,2,0,3,0]::float8[]::sparsevec;
65+
SELECT ARRAY[1,0,2,0,3,0]::numeric[]::sparsevec;
66+
67+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec;
68+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(6);
69+
SELECT '{1,0,2,0,3,0}'::real[]::sparsevec(5);
70+
SELECT '{NULL}'::real[]::sparsevec;
71+
SELECT '{NaN}'::real[]::sparsevec;
72+
SELECT '{Infinity}'::real[]::sparsevec;
73+
SELECT '{-Infinity}'::real[]::sparsevec;
74+
SELECT '{}'::real[]::sparsevec;
75+
SELECT '{{1}}'::real[]::sparsevec;
76+
6177
SELECT array_agg(n)::vector FROM generate_series(1, 16001) n;
6278
SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n;
6379

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy