1
1
from typing import Any
2
- from pypika .functions import Function
3
- from typing import Dict
4
- from pypika import JSON , Array
2
+ from pypika .functions import Function , Cast
3
+ from typing import Dict , List
4
+ from pypika import JSON , Array , Field
5
+ import json
5
6
6
7
7
8
class Embed (Function ):
@@ -20,3 +21,29 @@ def __init__(
20
21
class CosineDistance (Function ):
21
22
def __init__ (self , lhs : Array , rhs : Array , alias : str = "cosine" ) -> None :
22
23
super (CosineDistance , self ).__init__ ("cosine_distance" , lhs , rhs , alias = alias )
24
+
25
+
26
+ class Transform (Function ):
27
+ def __init__ (
28
+ self ,
29
+ task : str | Dict [str , Any ],
30
+ inputs : List ,
31
+ args : Dict [str , Any ] = {},
32
+ alias : str = "transform" ,
33
+ ) -> None :
34
+ super (Transform , self ).__init__ (
35
+ "pgml.transform" , task = task , inputs = inputs , args = args , alias = alias
36
+ )
37
+ self .task = task
38
+ self .inputs = inputs
39
+ self .args = args
40
+
41
+ def get_function_sql (self , ** kwargs : Any ) -> str :
42
+ if isinstance (self .task , str ):
43
+ return "pgml.transform(task => '{}', inputs => ARRAY{}, args => '{}'::JSONB)" .format (
44
+ self .task , self .inputs , self .args
45
+ )
46
+ elif isinstance (self .task , dict ):
47
+ return "pgml.transform(task => '{}'::JSONB, inputs => ARRAY{}, args => '{}'::JSONB)" .format (
48
+ json .dumps (self .task ), self .inputs , self .args
49
+ )
0 commit comments