From e350968b8ef47c84b70ab727f54790442ae8d958 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sun, 9 Feb 2025 23:31:32 -0800 Subject: [PATCH 1/2] Changed vector type to return Vector class instead of NumPy array [skip ci] --- pgvector/django/vector.py | 29 ++++++---------------- pgvector/vector.py | 8 +++--- tests/test_asyncpg.py | 24 +++++++++--------- tests/test_django.py | 21 ++++++++-------- tests/test_peewee.py | 13 +++++----- tests/test_psycopg.py | 38 ++++++++++++++-------------- tests/test_psycopg2.py | 22 ++++++++--------- tests/test_sqlalchemy.py | 52 +++++++++++++++++++-------------------- tests/test_sqlmodel.py | 16 ++++++------ 9 files changed, 99 insertions(+), 124 deletions(-) diff --git a/pgvector/django/vector.py b/pgvector/django/vector.py index 861cfde..09173fa 100644 --- a/pgvector/django/vector.py +++ b/pgvector/django/vector.py @@ -1,6 +1,5 @@ from django import forms from django.db.models import Field -import numpy as np from .. import Vector @@ -28,9 +27,12 @@ def from_db_value(self, value, expression, connection): return Vector._from_db(value) def to_python(self, value): - if isinstance(value, list): - return np.array(value, dtype=np.float32) - return Vector._from_db(value) + if value is None or isinstance(value, Vector): + return value + elif isinstance(value, str): + return Vector._from_db(value) + else: + return Vector(value) def get_prep_value(self, value): return Vector._to_db(value) @@ -38,35 +40,20 @@ def get_prep_value(self, value): def value_to_string(self, obj): return self.get_prep_value(self.value_from_object(obj)) - def validate(self, value, model_instance): - if isinstance(value, np.ndarray): - value = value.tolist() - super().validate(value, model_instance) - - def run_validators(self, value): - if isinstance(value, np.ndarray): - value = value.tolist() - super().run_validators(value) - def formfield(self, **kwargs): return super().formfield(form_class=VectorFormField, **kwargs) class VectorWidget(forms.TextInput): def format_value(self, value): - if isinstance(value, np.ndarray): - value = value.tolist() + if isinstance(value, Vector): + value = value.to_list() return super().format_value(value) class VectorFormField(forms.CharField): widget = VectorWidget - def has_changed(self, initial, data): - if isinstance(initial, np.ndarray): - initial = initial.tolist() - return super().has_changed(initial, data) - def to_python(self, value): if isinstance(value, str) and value == '': return None diff --git a/pgvector/vector.py b/pgvector/vector.py index ebbcafd..7b0304a 100644 --- a/pgvector/vector.py +++ b/pgvector/vector.py @@ -70,14 +70,14 @@ def _to_db_binary(cls, value): @classmethod def _from_db(cls, value): - if value is None or isinstance(value, np.ndarray): + if value is None or isinstance(value, cls): return value - return cls.from_text(value).to_numpy().astype(np.float32) + return cls.from_text(value) @classmethod def _from_db_binary(cls, value): - if value is None or isinstance(value, np.ndarray): + if value is None or isinstance(value, cls): return value - return cls.from_binary(value).to_numpy().astype(np.float32) + return cls.from_binary(value) diff --git a/tests/test_asyncpg.py b/tests/test_asyncpg.py index 3c36048..e9aa836 100644 --- a/tests/test_asyncpg.py +++ b/tests/test_asyncpg.py @@ -1,6 +1,6 @@ import asyncpg import numpy as np -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.asyncpg import register_vector import pytest @@ -15,12 +15,11 @@ async def test_vector(self): await register_vector(conn) - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embedding'], embedding) - assert res[0]['embedding'].dtype == np.float32 + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None # ensures binary format is correct @@ -38,11 +37,11 @@ async def test_halfvec(self): await register_vector(conn) - embedding = [1.5, 2, 3] + embedding = HalfVector([1.5, 2, 3]) await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert res[0]['embedding'].to_list() == [1.5, 2, 3] + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None # ensures binary format is correct @@ -87,7 +86,7 @@ async def test_sparsevec(self): await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert res[0]['embedding'].to_list() == [1.5, 2, 3] + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None # ensures binary format is correct @@ -105,12 +104,12 @@ async def test_vector_array(self): await register_vector(conn) - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings[0], embeddings[1]) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embeddings'][0], embeddings[0]) - assert np.array_equal(res[0]['embeddings'][1], embeddings[1]) + assert res[0]['embeddings'][0] == embeddings[0] + assert res[0]['embeddings'][1] == embeddings[1] await conn.close() @@ -126,10 +125,9 @@ async def init(conn): await conn.execute('DROP TABLE IF EXISTS asyncpg_items') await conn.execute('CREATE TABLE asyncpg_items (id bigserial PRIMARY KEY, embedding vector(3))') - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), (NULL)", embedding) res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id") - assert np.array_equal(res[0]['embedding'], embedding) - assert res[0]['embedding'].dtype == np.float32 + assert res[0]['embedding'] == embedding assert res[1]['embedding'] is None diff --git a/tests/test_django.py b/tests/test_django.py index 65082a3..7182e3a 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -12,7 +12,7 @@ import numpy as np import os import pgvector.django -from pgvector import HalfVector, SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance from unittest import mock @@ -165,12 +165,11 @@ def setup_method(self): def test_vector(self): Item(id=1, embedding=[1, 2, 3]).save() item = Item.objects.get(pk=1) - assert np.array_equal(item.embedding, np.array([1, 2, 3])) - assert item.embedding.dtype == np.float32 + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() - distance = L2Distance('embedding', [1, 1, 1]) + distance = L2Distance('embedding', Vector([1, 1, 1])) items = Item.objects.annotate(distance=distance).order_by(distance) assert [v.id for v in items] == [1, 3, 2] assert [v.distance for v in items] == [0, 1, sqrt(3)] @@ -293,7 +292,7 @@ def test_vector_avg(self): Item(embedding=[1, 2, 3]).save() Item(embedding=[4, 5, 6]).save() avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg'] - assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) + assert avg == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] @@ -301,7 +300,7 @@ def test_vector_sum(self): Item(embedding=[1, 2, 3]).save() Item(embedding=[4, 5, 6]).save() sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] - assert np.array_equal(sum, np.array([5, 7, 9])) + assert sum == Vector([5, 7, 9]) def test_halfvec_avg(self): avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] @@ -309,7 +308,7 @@ def test_halfvec_avg(self): Item(half_embedding=[1, 2, 3]).save() Item(half_embedding=[4, 5, 6]).save() avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] - assert avg.to_list() == [2.5, 3.5, 4.5] + assert avg == HalfVector([2.5, 3.5, 4.5]) def test_halfvec_sum(self): sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum'] @@ -317,7 +316,7 @@ def test_halfvec_sum(self): Item(half_embedding=[1, 2, 3]).save() Item(half_embedding=[4, 5, 6]).save() sum = Item.objects.aggregate(Sum('half_embedding'))['half_embedding__sum'] - assert sum.to_list() == [5, 7, 9] + assert sum == HalfVector([5, 7, 9]) def test_serialization(self): create_items() @@ -347,7 +346,7 @@ def test_vector_form_save(self): assert form.has_changed() assert form.is_valid() assert form.save() - assert [4, 5, 6] == Item.objects.get(pk=1).embedding.tolist() + assert [4, 5, 6] == Item.objects.get(pk=1).embedding.to_list() def test_vector_form_save_missing(self): Item(id=1).save() @@ -465,8 +464,8 @@ def test_vector_array(self): # this fails if the driver does not cast arrays item = Item.objects.get(pk=1) - assert item.embeddings[0].tolist() == [1, 2, 3] - assert item.embeddings[1].tolist() == [4, 5, 6] + assert item.embeddings[0].to_list() == [1, 2, 3] + assert item.embeddings[1].to_list() == [4, 5, 6] def test_double_array(self): Item(id=1, double_embedding=[1, 1, 1]).save() diff --git a/tests/test_peewee.py b/tests/test_peewee.py index 670d880..42b7787 100644 --- a/tests/test_peewee.py +++ b/tests/test_peewee.py @@ -1,7 +1,7 @@ from math import sqrt import numpy as np from peewee import Model, PostgresqlDatabase, fn -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField db = PostgresqlDatabase('pgvector_python_test') @@ -43,8 +43,7 @@ def setup_method(self): def test_vector(self): Item.create(id=1, embedding=[1, 2, 3]) item = Item.get_by_id(1) - assert np.array_equal(item.embedding, np.array([1, 2, 3])) - assert item.embedding.dtype == np.float32 + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() @@ -170,7 +169,7 @@ def test_vector_avg(self): Item.create(embedding=[1, 2, 3]) Item.create(embedding=[4, 5, 6]) avg = Item.select(fn.avg(Item.embedding).coerce(True)).scalar() - assert np.array_equal(avg, np.array([2.5, 3.5, 4.5])) + assert avg == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar() @@ -178,7 +177,7 @@ def test_vector_sum(self): Item.create(embedding=[1, 2, 3]) Item.create(embedding=[4, 5, 6]) sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar() - assert np.array_equal(sum, np.array([5, 7, 9])) + assert sum == Vector([5, 7, 9]) def test_halfvec_avg(self): avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar() @@ -186,7 +185,7 @@ def test_halfvec_avg(self): Item.create(half_embedding=[1, 2, 3]) Item.create(half_embedding=[4, 5, 6]) avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar() - assert avg.to_list() == [2.5, 3.5, 4.5] + assert avg == HalfVector([2.5, 3.5, 4.5]) def test_halfvec_sum(self): sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar() @@ -194,7 +193,7 @@ def test_halfvec_sum(self): Item.create(half_embedding=[1, 2, 3]) Item.create(half_embedding=[4, 5, 6]) sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar() - assert sum.to_list() == [5, 7, 9] + assert sum == HalfVector([5, 7, 9]) def test_get_or_create(self): Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]}) diff --git a/tests/test_psycopg.py b/tests/test_psycopg.py index 6a9d0b7..af658b2 100644 --- a/tests/test_psycopg.py +++ b/tests/test_psycopg.py @@ -23,19 +23,18 @@ def test_vector(self): conn.execute('INSERT INTO psycopg_items (embedding) VALUES (%s), (NULL)', (embedding,)) res = conn.execute('SELECT embedding FROM psycopg_items ORDER BY id').fetchall() - assert np.array_equal(res[0][0], embedding) - assert res[0][0].dtype == np.float32 + assert res[0][0] == Vector(embedding) assert res[1][0] is None def test_vector_binary_format(self): embedding = np.array([1.5, 2, 3]) res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0] - assert np.array_equal(res, embedding) + assert res == Vector(embedding) def test_vector_text_format(self): embedding = np.array([1.5, 2, 3]) res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, embedding) + assert res == Vector(embedding) def test_vector_binary_format_correct(self): embedding = np.array([1.5, 2, 3]) @@ -46,23 +45,23 @@ def test_vector_text_format_non_contiguous(self): embedding = np.flipud(np.array([1.5, 2, 3])) assert not embedding.data.contiguous res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, np.array([3, 2, 1.5])) + assert res == Vector([3, 2, 1.5]) def test_vector_binary_format_non_contiguous(self): embedding = np.flipud(np.array([1.5, 2, 3])) assert not embedding.data.contiguous res = conn.execute('SELECT %b::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, np.array([3, 2, 1.5])) + assert res == Vector([3, 2, 1.5]) def test_vector_class_binary_format(self): embedding = Vector([1.5, 2, 3]) res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0] - assert np.array_equal(res, np.array([1.5, 2, 3])) + assert res == embedding def test_vector_class_text_format(self): embedding = Vector([1.5, 2, 3]) res = conn.execute('SELECT %t::vector', (embedding,)).fetchone()[0] - assert np.array_equal(res, np.array([1.5, 2, 3])) + assert res == embedding def test_halfvec(self): embedding = HalfVector([1.5, 2, 3]) @@ -156,7 +155,7 @@ def test_text_copy_to(self): assert row[1] == "[1.5,2,3]" def test_binary_copy_to(self): - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) half_embedding = HalfVector([1.5, 2, 3]) conn.execute('INSERT INTO psycopg_items (embedding, half_embedding) VALUES (%s, %s)', (embedding, half_embedding)) cur = conn.cursor() @@ -166,23 +165,23 @@ def test_binary_copy_to(self): assert HalfVector.from_binary(row[1]).to_list() == [1.5, 2, 3] def test_binary_copy_to_set_types(self): - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) half_embedding = HalfVector([1.5, 2, 3]) conn.execute('INSERT INTO psycopg_items (embedding, half_embedding) VALUES (%s, %s)', (embedding, half_embedding)) cur = conn.cursor() with cur.copy("COPY psycopg_items (embedding, half_embedding) TO STDOUT WITH (FORMAT BINARY)") as copy: copy.set_types(['vector', 'halfvec']) for row in copy.rows(): - assert np.array_equal(row[0], embedding) - assert row[1].to_list() == [1.5, 2, 3] + assert row[0] == embedding + assert row[1] == half_embedding def test_vector_array(self): - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] conn.execute('INSERT INTO psycopg_items (embeddings) VALUES (%s)', (embeddings,)) res = conn.execute('SELECT embeddings FROM psycopg_items ORDER BY id').fetchone() - assert np.array_equal(res[0][0], embeddings[0]) - assert np.array_equal(res[0][1], embeddings[1]) + assert res[0][0] == embeddings[0] + assert res[0][1] == embeddings[1] def test_pool(self): def configure(conn): @@ -192,7 +191,7 @@ def configure(conn): with pool.connection() as conn: res = conn.execute("SELECT '[1,2,3]'::vector").fetchone() - assert np.array_equal(res[0], np.array([1, 2, 3])) + assert res[0] == Vector([1, 2, 3]) pool.close() @@ -206,14 +205,13 @@ async def test_async(self): await register_vector_async(conn) - embedding = np.array([1.5, 2, 3]) + embedding = Vector([1.5, 2, 3]) await conn.execute('INSERT INTO psycopg_items (embedding) VALUES (%s), (NULL)', (embedding,)) async with conn.cursor() as cur: await cur.execute('SELECT * FROM psycopg_items ORDER BY id') res = await cur.fetchall() - assert np.array_equal(res[0][1], embedding) - assert res[0][1].dtype == np.float32 + assert res[0][1] == embedding assert res[1][1] is None @pytest.mark.asyncio @@ -228,6 +226,6 @@ async def configure(conn): async with conn.cursor() as cur: await cur.execute("SELECT '[1,2,3]'::vector") res = await cur.fetchone() - assert np.array_equal(res[0], np.array([1, 2, 3])) + assert res[0] == Vector([1, 2, 3]) await pool.close() diff --git a/tests/test_psycopg2.py b/tests/test_psycopg2.py index 1994c87..8ef9911 100644 --- a/tests/test_psycopg2.py +++ b/tests/test_psycopg2.py @@ -26,8 +26,7 @@ def test_vector(self): cur.execute('SELECT embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert np.array_equal(res[0][0], embedding) - assert res[0][0].dtype == np.float32 + assert res[0][0] == Vector(embedding) assert res[1][0] is None def test_vector_class(self): @@ -36,17 +35,16 @@ def test_vector_class(self): cur.execute('SELECT embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert np.array_equal(res[0][0], embedding.to_numpy()) - assert res[0][0].dtype == np.float32 + assert res[0][0] == embedding assert res[1][0] is None def test_halfvec(self): - embedding = [1.5, 2, 3] + embedding = HalfVector([1.5, 2, 3]) cur.execute('INSERT INTO psycopg2_items (half_embedding) VALUES (%s), (NULL)', (embedding,)) cur.execute('SELECT half_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0].to_list() == [1.5, 2, 3] + assert res[0][0] == embedding assert res[1][0] is None def test_bit(self): @@ -55,7 +53,7 @@ def test_bit(self): cur.execute('SELECT binary_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0] == '101' + assert res[0][0] == embedding assert res[1][0] is None def test_sparsevec(self): @@ -64,17 +62,17 @@ def test_sparsevec(self): cur.execute('SELECT sparse_embedding FROM psycopg2_items ORDER BY id') res = cur.fetchall() - assert res[0][0].to_list() == [1.5, 2, 3] + assert res[0][0] == embedding assert res[1][0] is None def test_vector_array(self): - embeddings = [np.array([1.5, 2, 3]), np.array([4.5, 5, 6])] + embeddings = [Vector([1.5, 2, 3]), Vector([4.5, 5, 6])] cur.execute('INSERT INTO psycopg2_items (embeddings) VALUES (%s::vector[])', (embeddings,)) cur.execute('SELECT embeddings FROM psycopg2_items ORDER BY id') res = cur.fetchone() - assert np.array_equal(res[0][0], embeddings[0]) - assert np.array_equal(res[0][1], embeddings[1]) + assert res[0][0] == embeddings[0] + assert res[0][1] == embeddings[1] def test_halfvec_array(self): embeddings = [HalfVector([1.5, 2, 3]), HalfVector([4.5, 5, 6])] @@ -122,7 +120,7 @@ def test_pool(self): cur = conn.cursor() cur.execute("SELECT '[1,2,3]'::vector") res = cur.fetchone() - assert np.array_equal(res[0], np.array([1, 2, 3])) + assert res[0] == Vector([1, 2, 3]) finally: pool.putconn(conn) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index 052edd7..f7203f1 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -1,7 +1,7 @@ import asyncpg import numpy as np import os -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum import pytest from sqlalchemy import create_engine, event, insert, inspect, select, text, MetaData, Table, Column, Index, Integer, ARRAY @@ -190,10 +190,8 @@ def test_orm(self, engine): assert items[0].id % 3 == 1 assert items[1].id % 3 == 2 assert items[2].id % 3 == 0 - assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) - assert items[0].embedding.dtype == np.float32 - assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) - assert items[1].embedding.dtype == np.float32 + assert items[0].embedding == Vector([1.5, 2, 3]) + assert items[1].embedding == Vector([4, 5, 6]) assert items[2].embedding is None def test_vector(self, engine): @@ -201,7 +199,7 @@ def test_vector(self, engine): session.add(Item(id=1, embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.embedding.tolist() == [1, 2, 3] + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self, engine): create_items() @@ -256,7 +254,7 @@ def test_halfvec(self, engine): session.add(Item(id=1, half_embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.half_embedding.to_list() == [1, 2, 3] + assert item.half_embedding == HalfVector([1, 2, 3]) def test_halfvec_l2_distance(self, engine): create_items() @@ -348,7 +346,7 @@ def test_sparsevec(self, engine): session.add(Item(id=1, sparse_embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.sparse_embedding.to_list() == [1, 2, 3] + assert item.sparse_embedding == SparseVector([1, 2, 3]) def test_sparsevec_l2_distance(self, engine): create_items() @@ -429,7 +427,7 @@ def test_avg(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.query(avg(Item.embedding)).first()[0] - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_avg_orm(self, engine): with Session(engine) as session: @@ -438,7 +436,7 @@ def test_avg_orm(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.scalars(select(avg(Item.embedding))).first() - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_sum(self, engine): with Session(engine) as session: @@ -447,7 +445,7 @@ def test_sum(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.query(sum(Item.embedding)).first()[0] - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_sum_orm(self, engine): with Session(engine) as session: @@ -456,7 +454,7 @@ def test_sum_orm(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.scalars(select(sum(Item.embedding))).first() - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_bad_dimensions(self, engine): item = Item(embedding=[1, 2]) @@ -509,7 +507,7 @@ def test_automap(self, engine): with Session(engine) as session: session.execute(insert(AutoItem), [{'embedding': np.array([1, 2, 3])}]) item = session.query(AutoItem).first() - assert item.embedding.tolist() == [1, 2, 3] + assert item.embedding == Vector([1, 2, 3]) def test_half_precision(self, engine): create_items() @@ -541,8 +539,8 @@ def test_vector_array(self, engine): # this fails if the driver does not cast arrays item = session.get(Item, 1) - assert item.embeddings[0].tolist() == [1, 2, 3] - assert item.embeddings[1].tolist() == [4, 5, 6] + assert item.embeddings[0] == Vector([1, 2, 3]) + assert item.embeddings[1] == Vector([4, 5, 6]) def test_halfvec_array(self, engine): with Session(engine) as session: @@ -551,8 +549,8 @@ def test_halfvec_array(self, engine): # this fails if the driver does not cast arrays item = session.get(Item, 1) - assert item.half_embeddings[0].to_list() == [1, 2, 3] - assert item.half_embeddings[1].to_list() == [4, 5, 6] + assert item.half_embeddings[0] == HalfVector([1, 2, 3]) + assert item.half_embeddings[1] == HalfVector([4, 5, 6]) @pytest.mark.parametrize('engine', async_engines) @@ -566,10 +564,10 @@ async def test_vector(self, engine): async with async_session() as session: async with session.begin(): - embedding = np.array([1, 2, 3]) + embedding = Vector([1, 2, 3]) session.add(Item(id=1, embedding=embedding)) item = await session.get(Item, 1) - assert np.array_equal(item.embedding, embedding) + assert item.embedding == embedding await engine.dispose() @@ -579,10 +577,10 @@ async def test_halfvec(self, engine): async with async_session() as session: async with session.begin(): - embedding = [1, 2, 3] + embedding = HalfVector([1, 2, 3]) session.add(Item(id=1, half_embedding=embedding)) item = await session.get(Item, 1) - assert item.half_embedding.to_list() == embedding + assert item.half_embedding == embedding await engine.dispose() @@ -605,10 +603,10 @@ async def test_sparsevec(self, engine): async with async_session() as session: async with session.begin(): - embedding = [1, 2, 3] + embedding = SparseVector([1, 2, 3]) session.add(Item(id=1, sparse_embedding=embedding)) item = await session.get(Item, 1) - assert item.sparse_embedding.to_list() == embedding + assert item.sparse_embedding == embedding await engine.dispose() @@ -621,7 +619,7 @@ async def test_avg(self, engine): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = await session.scalars(select(avg(Item.embedding))) - assert res.first().tolist() == [2.5, 3.5, 4.5] + assert res.first() == Vector([2.5, 3.5, 4.5]) await engine.dispose() @@ -637,9 +635,9 @@ async def test_vector_array(self, engine): async with async_session() as session: async with session.begin(): - session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])])) + session.add(Item(id=1, embeddings=[Vector([1, 2, 3]), Vector([4, 5, 6])])) item = await session.get(Item, 1) - assert item.embeddings[0].tolist() == [1, 2, 3] - assert item.embeddings[1].tolist() == [4, 5, 6] + assert item.embeddings[0] == Vector([1, 2, 3]) + assert item.embeddings[1] == Vector([4, 5, 6]) await engine.dispose() diff --git a/tests/test_sqlmodel.py b/tests/test_sqlmodel.py index b0e8ccd..8f3e268 100644 --- a/tests/test_sqlmodel.py +++ b/tests/test_sqlmodel.py @@ -1,5 +1,5 @@ import numpy as np -from pgvector import SparseVector +from pgvector import Vector, HalfVector, SparseVector from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, avg, sum import pytest from sqlalchemy.exc import StatementError @@ -65,10 +65,8 @@ def test_orm(self): assert items[0].id == 1 assert items[1].id == 2 assert items[2].id == 3 - assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) - assert items[0].embedding.dtype == np.float32 - assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) - assert items[1].embedding.dtype == np.float32 + assert items[0].embedding == Vector([1.5, 2, 3]) + assert items[1].embedding == Vector([4, 5, 6]) assert items[2].embedding is None def test_vector(self): @@ -76,7 +74,7 @@ def test_vector(self): session.add(Item(id=1, embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.embedding.tolist() == [1, 2, 3] + assert item.embedding == Vector([1, 2, 3]) def test_vector_l2_distance(self): create_items() @@ -107,7 +105,7 @@ def test_halfvec(self): session.add(Item(id=1, half_embedding=[1, 2, 3])) session.commit() item = session.get(Item, 1) - assert item.half_embedding.to_list() == [1, 2, 3] + assert item.half_embedding == HalfVector([1, 2, 3]) def test_halfvec_l2_distance(self): create_items() @@ -202,7 +200,7 @@ def test_vector_avg(self): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.exec(select(avg(Item.embedding))).first() - assert np.array_equal(res, np.array([2.5, 3.5, 4.5])) + assert res == Vector([2.5, 3.5, 4.5]) def test_vector_sum(self): with Session(engine) as session: @@ -211,7 +209,7 @@ def test_vector_sum(self): session.add(Item(embedding=[1, 2, 3])) session.add(Item(embedding=[4, 5, 6])) res = session.exec(select(sum(Item.embedding))).first() - assert np.array_equal(res, np.array([5, 7, 9])) + assert res == Vector([5, 7, 9]) def test_halfvec_avg(self): with Session(engine) as session: From 3f44e8aa397de40ba64c54da166fee7462eb439f Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Mon, 10 Feb 2025 19:35:44 -0800 Subject: [PATCH 2/2] Use consistent style [skip ci] --- tests/test_django.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_django.py b/tests/test_django.py index ae42d5b..eff5a98 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -12,7 +12,7 @@ import numpy as np import os import pgvector.django -from pgvector import Vector, HalfVector, SparseVector +from pgvector import HalfVector, SparseVector, Vector from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance from unittest import mock pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy