Adds pipeline model caching in the transform function. #593


Merged: 2 commits into postgresml:master on Apr 17, 2023

Conversation

@f-prime (Contributor) commented on Apr 16, 2023

This PR adds the ability to run a query that looks like this:

SELECT pgml.transform(
        '{"model": "roberta-large-mnli"}'::JSONB, 
        inputs => ARRAY[
            'I love how amazingly simple ML has become!', 
            'I hate doing mundane and thankless tasks. ☹️'
        ],
        cache => TRUE
    ) AS positivity;

roberta-large-mnli will be cached in memory to prevent transformers.pipeline() from being called more than once for the same model.

By default, cache is FALSE.
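
For context, here is a minimal sketch of the behavior described above, assuming (as the hunk below suggests) a plpython-style transform() that receives task, args, and inputs as JSON strings; the function and cache names are illustrative, not the PR's exact code:

    import json
    import transformers

    # Illustrative module-level cache, keyed by model name.
    __cache_transformer_by_model = {}

    def transform(task, args, inputs, cache=False):
        task = json.loads(task)
        args = json.loads(args)
        inputs = json.loads(inputs)

        if cache:
            model = task.get("model")
            if model not in __cache_transformer_by_model:
                __cache_transformer_by_model[model] = transformers.pipeline(**task)
            pipe = __cache_transformer_by_model[model]
        else:
            pipe = transformers.pipeline(**task)

        return pipe(inputs, **args)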

The review discussion below is attached to this hunk in the transform function, where the pipeline is constructed:

    task = json.loads(task)
    args = json.loads(args)
    inputs = json.loads(inputs)

    pipe = transformers.pipeline(**task)
    model = task.get("model")

A contributor commented:

I think there may be different pipelines with the same model, to handle different tasks. We may need to use the full task param for caching. Something like:

    if cache:
        key = ','.join([str(key) + ':' + str(value) for (key, value) in sorted(task.items())])
        if key not in __cache_transformer_by_task:
            __cache_transformer_by_task[key] = transformers.pipeline(**task)
        pipe = __cache_transformer_by_task[key]
    else:
        pipe = transformers.pipeline(**task)
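
As a self-contained variant of this suggestion (the cache dict is spelled out, and json.dumps with sort_keys=True is an assumed alternative way to build a deterministic key; neither detail is taken from the review itself):

    import json
    import transformers

    # Module-level cache, keyed by the full task dict rather than just the model.
    __cache_transformer_by_task = {}

    def get_pipeline(task, cache=False):
        if not cache:
            return transformers.pipeline(**task)
        # Serializing the sorted task dict gives a deterministic key and,
        # unlike str(value), distinguishes nested values and value types.
        key = json.dumps(task, sort_keys=True)
        if key not in __cache_transformer_by_task:
            __cache_transformer_by_task[key] = transformers.pipeline(**task)
        return __cache_transformer_by_task[key]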

@f-prime (Author) replied:

I see, good point. I'll think some more about the parameters of .pipeline() and push an update.

@f-prime (Author) commented on Apr 17, 2023:

Okay, I pretty much copied your code verbatim. I also like the idea there that you have to pass the cache parameter to actually USE the cached model. My initial version always used the cached pipeline when it was available, whether or not the cache flag was passed in.
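
A hypothetical side-by-side of the two policies being discussed, with illustrative names:

    import transformers

    _cache = {}

    # Merged behavior: the cache flag gates both storing and reading.
    def get_pipe_opt_in(task, key, cache):
        if cache:
            if key not in _cache:
                _cache[key] = transformers.pipeline(**task)
            return _cache[key]
        return transformers.pipeline(**task)

    # Initial behavior: a cached pipeline was reused whenever present,
    # regardless of the cache flag.
    def get_pipe_always_read(task, key, cache):
        if key in _cache:
            return _cache[key]
        pipe = transformers.pipeline(**task)
        if cache:
            _cache[key] = pipe
        return pipe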

@montanalow (Contributor) commented:

This looks good to me as an improvement over the status quo, so I'm going to merge it. We're being slightly inconsistent: every other API caches by default, including pgml.embed(), which may use similarly large models. I'm not really content with that approach either, since the only way to clear those caches is to drop the connection and kill the backend process. It's good to have these examples, though; we'll want to design a consistent interface around them for more feature-rich cache management.
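
For readers unfamiliar with the workaround described above: because the cache lives in the backend process's memory, clearing it today means ending that session, e.g. with the standard Postgres facility below (the pid is a placeholder you would look up in pg_stat_activity):

    -- Terminate a specific backend, and with it any in-process model cache.
    SELECT pg_terminate_backend(12345);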

@montanalow merged commit 41de0aa into postgresml:master on Apr 17, 2023