From be7052edac29bfc1f5b1f44b3188df50908ed533 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 08:45:58 -0500 Subject: [PATCH 01/14] working postgres --- pgml-extension/Dockerfile.local | 85 ++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 16 deletions(-) diff --git a/pgml-extension/Dockerfile.local b/pgml-extension/Dockerfile.local index 3df89c787..4b9859af6 100644 --- a/pgml-extension/Dockerfile.local +++ b/pgml-extension/Dockerfile.local @@ -1,5 +1,9 @@ FROM ubuntu:jammy -MAINTAINER team@postgresml.com +LABEL maintainer="team@postgresml.com" + +ARG PG_MAJOR_VER +ENV PG_MAJOR_VER=${PG_MAJOR_VER} + RUN apt-get update ARG DEBIAN_FRONTEND=noninteractive ENV TZ=Etc/UTC @@ -23,27 +27,76 @@ RUN apt-get update && apt-fast install -y \ libpq-dev \ libclang-dev \ wget \ - postgresql-plpython3-14 \ - postgresql-14 \ - postgresql-server-dev-14 + postgresql-plpython3-$PG_MAJOR_VER \ + postgresql-$PG_MAJOR_VER \ + postgresql-server-dev-$PG_MAJOR_VER + + RUN add-apt-repository ppa:deadsnakes/ppa --yes RUN apt update && apt-fast install -y \ python3.10 \ python3-pip \ libpython3.10-dev \ python3.10-dev -RUN useradd postgresml -m -s /bin/bash -G sudo -RUN echo 'postgresml ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers -USER postgresml -RUN curl https://sh.rustup.rs -sSf | sh -s -- -y -RUN $HOME/.cargo/bin/cargo install cargo-pgrx --version "0.8.2" --locked -RUN $HOME/.cargo/bin/cargo pgrx init -RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg >/dev/null -RUN sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' -RUN sudo apt update -RUN sudo apt-get install -y postgresql-15 postgresql-13 postgresql-12 postgresql-11 -RUN sudo apt install -y postgresql-server-dev-15 postgresql-server-dev-15 postgresql-server-dev-12 postgresql-server-dev-11 + + +RUN echo 'postgres ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers + +# COPY --chown=postgres:postgres ./pgml-extension /app +# WORKDIR /app + +COPY requirements.txt /app/requirements.txt + + +RUN pip3 install -r /app/requirements.txt + + +# Running pgrx and tests require a non-root user WORKDIR /app -RUN pip3 install -r requirements.txt +# Running pgrx and tests require a non-root user +RUN useradd --create-home --shell /bin/bash rust + +USER rust +ENV USER=rust + +# Install cargo and Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +ENV PATH="/home/rust/.cargo/bin:${PATH}" + +# RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg >/dev/null +# RUN sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' + + +# Install pgrx +RUN cargo install cargo-pgrx --version "0.8.2" --locked +# RUN $HOME/.cargo/bin/cargo install pgrx --version "0.8.2" --locked + + + +COPY --chown=postgres:postgres ./ /app + +USER postgres + +RUN sudo cp /app/docker/postgresql.conf /etc/postgresql/$PG_MAJOR_VER/main/postgresql.conf +RUN sudo cp /app/docker/pg_hba.conf /etc/postgresql/$PG_MAJOR_VER/main/pg_hba.conf + + +# RUN sudo chown -R rust:rust /usr/share/postgresql/$PG_MAJOR_VER/extension +RUN sudo chown -R postgres:postgres /usr/share/postgresql/$PG_MAJOR_VER/extension + +# commenting this three make things work +# USER rust + +RUN cargo pgrx init --pg$PG_MAJOR_VER=$(which pg_config) +RUN cargo pgrx install --pg-config $(which pg_config) + +EXPOSE 5432 + +USER postgres + + +# ENTRYPOINT ["/bin/bash", "/app/docker/entrypoint.sh"] +CMD ["/usr/lib/postgresql/14/bin/postgres", "-c", "config_file=/etc/postgresql/14/main/postgresql.conf"] + From 2d949fbc6a41a882bb6bac4ebc3e193c524b84f2 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 11:54:09 -0500 Subject: [PATCH 02/14] Fixing Dockerfile.local for local deployment --- .../Dockerfile.local => Dockerfile.local | 45 +++++++++---------- docker-compose.local.yml | 18 ++++++++ 2 files changed, 40 insertions(+), 23 deletions(-) rename pgml-extension/Dockerfile.local => Dockerfile.local (55%) create mode 100644 docker-compose.local.yml diff --git a/pgml-extension/Dockerfile.local b/Dockerfile.local similarity index 55% rename from pgml-extension/Dockerfile.local rename to Dockerfile.local index 4b9859af6..60e461612 100644 --- a/pgml-extension/Dockerfile.local +++ b/Dockerfile.local @@ -1,3 +1,4 @@ + FROM ubuntu:jammy LABEL maintainer="team@postgresml.com" @@ -29,8 +30,8 @@ RUN apt-get update && apt-fast install -y \ wget \ postgresql-plpython3-$PG_MAJOR_VER \ postgresql-$PG_MAJOR_VER \ - postgresql-server-dev-$PG_MAJOR_VER - + postgresql-server-dev-$PG_MAJOR_VER \ + git RUN add-apt-repository ppa:deadsnakes/ppa --yes RUN apt update && apt-fast install -y \ @@ -42,10 +43,8 @@ RUN apt update && apt-fast install -y \ RUN echo 'postgres ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers -# COPY --chown=postgres:postgres ./pgml-extension /app -# WORKDIR /app -COPY requirements.txt /app/requirements.txt +COPY ./pgml-extension/requirements.txt /app/requirements.txt RUN pip3 install -r /app/requirements.txt @@ -54,49 +53,49 @@ RUN pip3 install -r /app/requirements.txt # Running pgrx and tests require a non-root user WORKDIR /app +RUN chmod a+rwx `$(which pg_config) --pkglibdir` \ + `$(which pg_config) --sharedir`/extension \ + /var/run/postgresql/ -# Running pgrx and tests require a non-root user -RUN useradd --create-home --shell /bin/bash rust +RUN useradd postgresml -m -s /bin/bash -G sudo +RUN echo 'postgresml ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers + +USER postgresml -USER rust -ENV USER=rust # Install cargo and Rust RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -ENV PATH="/home/rust/.cargo/bin:${PATH}" - -# RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor | sudo tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg >/dev/null -# RUN sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' +ENV PATH="/home/postgresml/.cargo/bin:${PATH}" # Install pgrx RUN cargo install cargo-pgrx --version "0.8.2" --locked -# RUN $HOME/.cargo/bin/cargo install pgrx --version "0.8.2" --locked +COPY --chown=postgresml:postgresml ./ /app +RUN sudo chown -R postgresml:postgresml /app +RUN git submodule update --init --recursive -COPY --chown=postgres:postgres ./ /app -USER postgres -RUN sudo cp /app/docker/postgresql.conf /etc/postgresql/$PG_MAJOR_VER/main/postgresql.conf -RUN sudo cp /app/docker/pg_hba.conf /etc/postgresql/$PG_MAJOR_VER/main/pg_hba.conf + +RUN sudo cp /app/pgml-extension/docker/postgresql.conf /etc/postgresql/$PG_MAJOR_VER/main/postgresql.conf +RUN sudo cp /app/pgml-extension/docker/pg_hba.conf /etc/postgresql/$PG_MAJOR_VER/main/pg_hba.conf # RUN sudo chown -R rust:rust /usr/share/postgresql/$PG_MAJOR_VER/extension -RUN sudo chown -R postgres:postgres /usr/share/postgresql/$PG_MAJOR_VER/extension +RUN sudo chown -R postgresml:postgresml /usr/share/postgresql/$PG_MAJOR_VER/ +RUN sudo chown -R postgresml:postgresml /usr/share/postgresql/$PG_MAJOR_VER/extension # commenting this three make things work # USER rust -RUN cargo pgrx init --pg$PG_MAJOR_VER=$(which pg_config) -RUN cargo pgrx install --pg-config $(which pg_config) +RUN cd /app/pgml-extension && cargo pgrx init --pg$PG_MAJOR_VER=$(which pg_config) +RUN cd /app/pgml-extension && cargo pgrx install --pg-config $(which pg_config) EXPOSE 5432 USER postgres - -# ENTRYPOINT ["/bin/bash", "/app/docker/entrypoint.sh"] CMD ["/usr/lib/postgresql/14/bin/postgres", "-c", "config_file=/etc/postgresql/14/main/postgresql.conf"] diff --git a/docker-compose.local.yml b/docker-compose.local.yml new file mode 100644 index 000000000..dca2caebd --- /dev/null +++ b/docker-compose.local.yml @@ -0,0 +1,18 @@ +version: "3" +# Run by doing docker compose -f docker-compose.local.yml build --build-arg PG_MAJOR_VER=14 +# docker compose -f docker-compose.local.yml up +services: + postgres: + healthcheck: + test: [ "CMD-SHELL", "pg_isready" ] + interval: 1s + timeout: 5s + retries: 100 + build: + context: . + dockerfile: Dockerfile.local + ports: + - "5433:5432" + command: + - sleep + - infinity From 1ebfd0160d7d3228e11a0d4a9fbb6c4e92b44521 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 12:03:06 -0500 Subject: [PATCH 03/14] removing command from compose --- docker-compose.local.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/docker-compose.local.yml b/docker-compose.local.yml index dca2caebd..67bd9f631 100644 --- a/docker-compose.local.yml +++ b/docker-compose.local.yml @@ -13,6 +13,3 @@ services: dockerfile: Dockerfile.local ports: - "5433:5432" - command: - - sleep - - infinity From b20ed8d33b0848d8893a40b53b7348fd9b9b2e54 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 12:11:51 -0500 Subject: [PATCH 04/14] adding inputs as an array for embeddings --- pgml-extension/src/api.rs | 5 +++ pgml-extension/src/bindings/transformers.py | 18 +++++++--- pgml-extension/src/bindings/transformers.rs | 38 +++++++++++++++++---- 3 files changed, 50 insertions(+), 11 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 914952e91..847c63f89 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -568,6 +568,11 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> crate::bindings::transformers::embed(transformer, text, &kwargs.0) } +#[pg_extern(immutable, parallel_safe)] +pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec { + crate::bindings::transformers::embed(transformer, &inputs, &kwargs.0) +} + #[pg_extern(immutable, parallel_safe)] pub fn chunk( splitter: &str, diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index f6d367f84..8c61791dc 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -110,13 +110,23 @@ def transform(task, args, inputs): return json.dumps(pipe(inputs, **args), cls=NumpyJSONEncoder) -def embed(transformer, text, kwargs): +def embed(transformer, inputs, kwargs): kwargs = json.loads(kwargs) ensure_device(kwargs) instructor = transformer.startswith("hkunlp/instructor") + if instructor: klass = INSTRUCTOR - text = [[kwargs.pop("instruction"), text]] + if isinstance(inputs, str): + inputs = [[kwargs.pop("instruction"), inputs]] + + else: + texts_with_instructions = [] + instruction = kwargs.pop("instruction") + for text in inputs: + texts_with_instructions.append([instruction, text]) + + inputs = texts_with_instructions else: klass = SentenceTransformer @@ -124,8 +134,8 @@ def embed(transformer, text, kwargs): __cache_sentence_transformer_by_name[transformer] = klass(transformer) model = __cache_sentence_transformer_by_name[transformer] - result = model.encode(text, **kwargs) - if instructor: + result = model.encode(inputs, **kwargs) + if instructor and len(result) == 1: result = result[0] return result diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 65c24bcd6..d14bee0bf 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -35,14 +35,13 @@ pub fn transform( let results = Python::with_gil(|py| -> String { let transform: Py = PY_MODULE.getattr(py, "transform").unwrap().into(); - let result = transform - .call1( + let result = transform.call1( + py, + PyTuple::new( py, - PyTuple::new( - py, - &[task.into_py(py), args.into_py(py), inputs.into_py(py)], - ), - ); + &[task.into_py(py), args.into_py(py), inputs.into_py(py)], + ), + ); let result = match result { Err(e) => { @@ -81,6 +80,31 @@ pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec, kwargs: &serde_json::Value) -> Vec> { + crate::bindings::venv::activate(); + + let kwargs = serde_json::to_string(kwargs).unwrap(); + let inputs = serde_json::to_string(inputs).unwrap(); + Python::with_gil(|py| -> Vec { + let embed: Py = PY_MODULE.getattr(py, "embed").unwrap().into(); + embed + .call1( + py, + PyTuple::new( + py, + &[ + transformer.to_string().into_py(py), + inputs.into_py(py), + kwargs.into_py(py), + ], + ), + ) + .unwrap() + .extract(py) + .unwrap() + }) +} + pub fn tune( task: &Task, dataset: TextDataset, From e0738dcfb4482b373bbeaf08372596596e31d14d Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 12:18:41 -0500 Subject: [PATCH 05/14] fixing function overloading issue --- pgml-extension/src/api.rs | 8 ++++++-- pgml-extension/src/bindings/transformers.rs | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 847c63f89..9b5b954a3 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -569,8 +569,12 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> } #[pg_extern(immutable, parallel_safe)] -pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec { - crate::bindings::transformers::embed(transformer, &inputs, &kwargs.0) +pub fn embed_array( + transformer: &str, + inputs: Vec<&str>, + kwargs: default!(JsonB, "'{}'"), +) -> Vec { + crate::bindings::transformers::embed_array(transformer, &inputs, &kwargs.0) } #[pg_extern(immutable, parallel_safe)] diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index d14bee0bf..603e3b489 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -80,7 +80,11 @@ pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec, kwargs: &serde_json::Value) -> Vec> { +pub fn embed_array( + transformer: &str, + inputs: &Vec<&str>, + kwargs: &serde_json::Value, +) -> Vec> { crate::bindings::venv::activate(); let kwargs = serde_json::to_string(kwargs).unwrap(); From f614404426039d52810563fa78892bd26018f7ab Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 12:22:32 -0500 Subject: [PATCH 06/14] adding vec --- pgml-extension/src/api.rs | 2 +- pgml-extension/src/bindings/transformers.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 9b5b954a3..57571d60d 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -573,7 +573,7 @@ pub fn embed_array( transformer: &str, inputs: Vec<&str>, kwargs: default!(JsonB, "'{}'"), -) -> Vec { +) -> Vec> { crate::bindings::transformers::embed_array(transformer, &inputs, &kwargs.0) } diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 603e3b489..48e760035 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -89,7 +89,7 @@ pub fn embed_array( let kwargs = serde_json::to_string(kwargs).unwrap(); let inputs = serde_json::to_string(inputs).unwrap(); - Python::with_gil(|py| -> Vec { + Python::with_gil(|py| -> Vec> { let embed: Py = PY_MODULE.getattr(py, "embed").unwrap().into(); embed .call1( From 9425b63e6df9e54182d793027b12dc5d3fb7e9df Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 12:31:13 -0500 Subject: [PATCH 07/14] handling inputs --- pgml-extension/src/api.rs | 2 +- pgml-extension/src/bindings/transformers.py | 4 ++++ pgml-extension/src/bindings/transformers.rs | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 57571d60d..8fb9c2f0a 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -571,7 +571,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> #[pg_extern(immutable, parallel_safe)] pub fn embed_array( transformer: &str, - inputs: Vec<&str>, + inputs: default!(Vec, "ARRAY[]::TEXT[]"), kwargs: default!(JsonB, "'{}'"), ) -> Vec> { crate::bindings::transformers::embed_array(transformer, &inputs, &kwargs.0) diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 8c61791dc..e10f6387b 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -111,6 +111,10 @@ def transform(task, args, inputs): def embed(transformer, inputs, kwargs): + + if not isinstance(inputs, str): + inputs = json.dumps(inputs) + kwargs = json.loads(kwargs) ensure_device(kwargs) instructor = transformer.startswith("hkunlp/instructor") diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 48e760035..8b94119fe 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -82,7 +82,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec, + inputs: &Vec, kwargs: &serde_json::Value, ) -> Vec> { crate::bindings::venv::activate(); From 039da67bacb9da61e5b89ea1e281f51cf4aa8142 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 12:40:08 -0500 Subject: [PATCH 08/14] handling inputs loads instead of dumps --- pgml-extension/src/bindings/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index e10f6387b..723dde5cf 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -113,7 +113,7 @@ def transform(task, args, inputs): def embed(transformer, inputs, kwargs): if not isinstance(inputs, str): - inputs = json.dumps(inputs) + inputs = json.loads(inputs) kwargs = json.loads(kwargs) ensure_device(kwargs) From 03abdcd8001c77173f58020cd5dcc37b84fed299 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 14:53:59 -0500 Subject: [PATCH 09/14] fixing instance of json --- pgml-extension/src/bindings/transformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 723dde5cf..9c5d25856 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -112,8 +112,10 @@ def transform(task, args, inputs): def embed(transformer, inputs, kwargs): - if not isinstance(inputs, str): + try: inputs = json.loads(inputs) + except json.decoder.JSONDecodeError: + pass kwargs = json.loads(kwargs) ensure_device(kwargs) From 6ffabc97f77e213feded28551957953a515026fe Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 15:16:56 -0500 Subject: [PATCH 10/14] adding name so same funciton name --- pgml-extension/src/api.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 8fb9c2f0a..b5ab033b7 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -568,7 +568,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> crate::bindings::transformers::embed(transformer, text, &kwargs.0) } -#[pg_extern(immutable, parallel_safe)] +#[pg_extern(immutable, parallel_safe, name = "embed")] pub fn embed_array( transformer: &str, inputs: default!(Vec, "ARRAY[]::TEXT[]"), From 0b0e666159a90cc7022bb25c764ec3658a2174e4 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 15:24:55 -0500 Subject: [PATCH 11/14] Changing inner func name to embed_batch --- pgml-extension/src/api.rs | 4 ++-- pgml-extension/src/bindings/transformers.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index b5ab033b7..49383d22b 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -569,12 +569,12 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> } #[pg_extern(immutable, parallel_safe, name = "embed")] -pub fn embed_array( +pub fn embed_batch( transformer: &str, inputs: default!(Vec, "ARRAY[]::TEXT[]"), kwargs: default!(JsonB, "'{}'"), ) -> Vec> { - crate::bindings::transformers::embed_array(transformer, &inputs, &kwargs.0) + crate::bindings::transformers::embed_batch(transformer, &inputs, &kwargs.0) } #[pg_extern(immutable, parallel_safe)] diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 8b94119fe..60ec88171 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -80,7 +80,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec, kwargs: &serde_json::Value, From 4b2e14d615d7608a9d39207418598cfd1d0a37c7 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 16:01:48 -0500 Subject: [PATCH 12/14] adding fixes from comments --- pgml-extension/src/api.rs | 9 ++++-- pgml-extension/src/bindings/transformers.py | 27 ++++++----------- pgml-extension/src/bindings/transformers.rs | 32 ++------------------- 3 files changed, 16 insertions(+), 52 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 49383d22b..4fa328dd8 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -563,9 +563,12 @@ fn load_dataset( TableIterator::new(vec![(name, rows)].into_iter()) } -#[pg_extern(immutable, parallel_safe)] +#[pg_extern(immutable, parallel_safe, name = "embed")] pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec { - crate::bindings::transformers::embed(transformer, text, &kwargs.0) + embed_batch(transformer, Vec::from([text]), &kwargs.0) + .first() + .unwrap() + .to_vec() } #[pg_extern(immutable, parallel_safe, name = "embed")] @@ -574,7 +577,7 @@ pub fn embed_batch( inputs: default!(Vec, "ARRAY[]::TEXT[]"), kwargs: default!(JsonB, "'{}'"), ) -> Vec> { - crate::bindings::transformers::embed_batch(transformer, &inputs, &kwargs.0) + crate::bindings::transformers::embed(transformer, &inputs, &kwargs.0) } #[pg_extern(immutable, parallel_safe)] diff --git a/pgml-extension/src/bindings/transformers.py b/pgml-extension/src/bindings/transformers.py index 9c5d25856..419fc3467 100644 --- a/pgml-extension/src/bindings/transformers.py +++ b/pgml-extension/src/bindings/transformers.py @@ -112,27 +112,20 @@ def transform(task, args, inputs): def embed(transformer, inputs, kwargs): - try: - inputs = json.loads(inputs) - except json.decoder.JSONDecodeError: - pass - + inputs = json.loads(inputs) kwargs = json.loads(kwargs) ensure_device(kwargs) instructor = transformer.startswith("hkunlp/instructor") if instructor: klass = INSTRUCTOR - if isinstance(inputs, str): - inputs = [[kwargs.pop("instruction"), inputs]] + + texts_with_instructions = [] + instruction = kwargs.pop("instruction") + for text in inputs: + texts_with_instructions.append([instruction, text]) - else: - texts_with_instructions = [] - instruction = kwargs.pop("instruction") - for text in inputs: - texts_with_instructions.append([instruction, text]) - - inputs = texts_with_instructions + inputs = texts_with_instructions else: klass = SentenceTransformer @@ -140,11 +133,7 @@ def embed(transformer, inputs, kwargs): __cache_sentence_transformer_by_name[transformer] = klass(transformer) model = __cache_sentence_transformer_by_name[transformer] - result = model.encode(inputs, **kwargs) - if instructor and len(result) == 1: - result = result[0] - - return result + return model.encode(inputs, **kwargs) def load_dataset(name, subset, limit: None, kwargs: "{}"): diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index 60ec88171..a213d4d89 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -56,39 +56,11 @@ pub fn transform( serde_json::from_str(&results).unwrap() } -pub fn embed(transformer: &str, text: &str, kwargs: &serde_json::Value) -> Vec { +pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -> Vec { crate::bindings::venv::activate(); let kwargs = serde_json::to_string(kwargs).unwrap(); - Python::with_gil(|py| -> Vec { - let embed: Py = PY_MODULE.getattr(py, "embed").unwrap().into(); - embed - .call1( - py, - PyTuple::new( - py, - &[ - transformer.to_string().into_py(py), - text.to_string().into_py(py), - kwargs.into_py(py), - ], - ), - ) - .unwrap() - .extract(py) - .unwrap() - }) -} - -pub fn embed_batch( - transformer: &str, - inputs: &Vec, - kwargs: &serde_json::Value, -) -> Vec> { - crate::bindings::venv::activate(); - - let kwargs = serde_json::to_string(kwargs).unwrap(); - let inputs = serde_json::to_string(inputs).unwrap(); + let inputs = serde_json::to_string(&inputs).unwrap(); Python::with_gil(|py| -> Vec> { let embed: Py = PY_MODULE.getattr(py, "embed").unwrap().into(); embed From 78bda85ebdaa604adf1c3d22ba13c3481d457e52 Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 16:04:27 -0500 Subject: [PATCH 13/14] adding Vec> --- pgml-extension/src/bindings/transformers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers.rs b/pgml-extension/src/bindings/transformers.rs index a213d4d89..714c0c690 100644 --- a/pgml-extension/src/bindings/transformers.rs +++ b/pgml-extension/src/bindings/transformers.rs @@ -56,7 +56,7 @@ pub fn transform( serde_json::from_str(&results).unwrap() } -pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -> Vec { +pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -> Vec> { crate::bindings::venv::activate(); let kwargs = serde_json::to_string(kwargs).unwrap(); From 9111a49548b11c2d2f18aca8d91c794adae6ab6b Mon Sep 17 00:00:00 2001 From: jsaied99 Date: Mon, 5 Jun 2023 16:22:17 -0500 Subject: [PATCH 14/14] fixing compilation errors --- pgml-extension/src/api.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 4fa328dd8..a4fe4f976 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -565,7 +565,7 @@ fn load_dataset( #[pg_extern(immutable, parallel_safe, name = "embed")] pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> Vec { - embed_batch(transformer, Vec::from([text]), &kwargs.0) + embed_batch(transformer, Vec::from([text]), kwargs) .first() .unwrap() .to_vec() @@ -574,7 +574,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> #[pg_extern(immutable, parallel_safe, name = "embed")] pub fn embed_batch( transformer: &str, - inputs: default!(Vec, "ARRAY[]::TEXT[]"), + inputs: Vec<&str>, kwargs: default!(JsonB, "'{}'"), ) -> Vec> { crate::bindings::transformers::embed(transformer, &inputs, &kwargs.0) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy