diff --git a/README.md b/README.md index c9f2531..5f7ecee 100644 --- a/README.md +++ b/README.md @@ -76,3 +76,47 @@ class Status(str, enum.Enum): OPEN = "op!en" CLOSED = "clo@sed" ``` + +### Override Column Types + +Option: `overrides` + +You can override the SQL to Python type mapping for specific columns using the `overrides` option. This is useful for columns with JSON data or other custom types. + +Example configuration: + +```yaml +options: + package: authors + emit_pydantic_models: true + overrides: + - column: "some_table.payload" + py_import: "my_lib.models" + py_type: "Payload" +``` + +This will: +1. Override the column `payload` in `some_table` to use the type `Payload` +2. Add an import for `my_lib.models` to the models file + +Example output: + +```python +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.28.0 + +import datetime +import pydantic +from typing import Any + +import my_lib.models + + +class SomeTable(pydantic.BaseModel): + id: int + created_at: datetime.datetime + payload: my_lib.models.Payload +``` + +This is similar to the [overrides functionality in the Go version of sqlc](https://docs.sqlc.dev/en/stable/howto/overrides.html#overriding-types). diff --git a/internal/config.go b/internal/config.go index 1a8a565..e78112c 100644 --- a/internal/config.go +++ b/internal/config.go @@ -1,13 +1,20 @@ package python +type OverrideColumn struct { + Column string `json:"column"` + PyType string `json:"py_type"` + PyImport string `json:"py_import"` +} + type Config struct { - EmitExactTableNames bool `json:"emit_exact_table_names"` - EmitSyncQuerier bool `json:"emit_sync_querier"` - EmitAsyncQuerier bool `json:"emit_async_querier"` - Package string `json:"package"` - Out string `json:"out"` - EmitPydanticModels bool `json:"emit_pydantic_models"` - EmitStrEnum bool `json:"emit_str_enum"` - QueryParameterLimit *int32 `json:"query_parameter_limit"` - InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + EmitExactTableNames bool `json:"emit_exact_table_names"` + EmitSyncQuerier bool `json:"emit_sync_querier"` + EmitAsyncQuerier bool `json:"emit_async_querier"` + Package string `json:"package"` + Out string `json:"out"` + EmitPydanticModels bool `json:"emit_pydantic_models"` + EmitStrEnum bool `json:"emit_str_enum"` + QueryParameterLimit *int32 `json:"query_parameter_limit"` + InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + Overrides []OverrideColumn `json:"overrides"` } diff --git a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml index beae200..62ec488 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml +++ b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml index 04e3feb..56fe8bf 100644 --- a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml +++ b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/emit_type_overrides/db/models.py b/internal/endtoend/testdata/emit_type_overrides/db/models.py new file mode 100644 index 0000000..1decb3d --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/db/models.py @@ -0,0 +1,11 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.28.0 +import pydantic + +import my_lib.models + + +class Book(pydantic.BaseModel): + id: int + payload: my_lib.models.Payload diff --git a/internal/endtoend/testdata/emit_type_overrides/db/query.py b/internal/endtoend/testdata/emit_type_overrides/db/query.py new file mode 100644 index 0000000..0486a35 --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/db/query.py @@ -0,0 +1,92 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.28.0 +# source: query.sql +from typing import AsyncIterator, Iterator, Optional + +import my_lib.models +import sqlalchemy +import sqlalchemy.ext.asyncio + +from db import models + + +CREATE_BOOK = """-- name: create_book \\:one +INSERT INTO books (payload) +VALUES (:p1) +RETURNING id, payload +""" + + +GET_BOOK = """-- name: get_book \\:one +SELECT id, payload FROM books +WHERE id = :p1 LIMIT 1 +""" + + +LIST_BOOKS = """-- name: list_books \\:many +SELECT id, payload FROM books +ORDER BY id +""" + + +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn + + def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]: + row = self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload}).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + def get_book(self, *, id: int) -> Optional[models.Book]: + row = self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id}).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + def list_books(self) -> Iterator[models.Book]: + result = self._conn.execute(sqlalchemy.text(LIST_BOOKS)) + for row in result: + yield models.Book( + id=row[0], + payload=row[1], + ) + + +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload})).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + async def get_book(self, *, id: int) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id})).first() + if row is None: + return None + return models.Book( + id=row[0], + payload=row[1], + ) + + async def list_books(self) -> AsyncIterator[models.Book]: + result = await self._conn.stream(sqlalchemy.text(LIST_BOOKS)) + async for row in result: + yield models.Book( + id=row[0], + payload=row[1], + ) diff --git a/internal/endtoend/testdata/emit_type_overrides/my_lib/__init__.py b/internal/endtoend/testdata/emit_type_overrides/my_lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/internal/endtoend/testdata/emit_type_overrides/my_lib/models.py b/internal/endtoend/testdata/emit_type_overrides/my_lib/models.py new file mode 100644 index 0000000..1f1a052 --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/my_lib/models.py @@ -0,0 +1,7 @@ +from datetime import date + +from pydantic import BaseModel + +class Payload(BaseModel): + name: str + release_date: date diff --git a/internal/endtoend/testdata/emit_type_overrides/query.sql b/internal/endtoend/testdata/emit_type_overrides/query.sql new file mode 100644 index 0000000..ab1a3c1 --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/query.sql @@ -0,0 +1,12 @@ +-- name: GetBook :one +SELECT * FROM books +WHERE id = $1 LIMIT 1; + +-- name: ListBooks :many +SELECT * FROM books +ORDER BY id; + +-- name: CreateBook :one +INSERT INTO books (payload) +VALUES (sqlc.arg(payload)) +RETURNING *; diff --git a/internal/endtoend/testdata/emit_type_overrides/schema.sql b/internal/endtoend/testdata/emit_type_overrides/schema.sql new file mode 100644 index 0000000..51997ea --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/schema.sql @@ -0,0 +1,4 @@ +CREATE TABLE books ( + id SERIAL PRIMARY KEY, + payload JSONB NOT NULL +); diff --git a/internal/endtoend/testdata/emit_type_overrides/sqlc.yaml b/internal/endtoend/testdata/emit_type_overrides/sqlc.yaml new file mode 100644 index 0000000..e70c41a --- /dev/null +++ b/internal/endtoend/testdata/emit_type_overrides/sqlc.yaml @@ -0,0 +1,22 @@ +version: "2" +plugins: + - name: py + wasm: + url: file://../../../../bin/sqlc-gen-python.wasm + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" +sql: + - schema: schema.sql + queries: query.sql + engine: postgresql + codegen: + - plugin: py + out: db + options: + package: db + emit_pydantic_models: true + emit_sync_querier: true + emit_async_querier: true + overrides: + - column: "books.payload" + py_import: "my_lib.models" + py_type: "Payload" diff --git a/internal/endtoend/testdata/exec_result/sqlc.yaml b/internal/endtoend/testdata/exec_result/sqlc.yaml index ddffc83..e7fe6ff 100644 --- a/internal/endtoend/testdata/exec_result/sqlc.yaml +++ b/internal/endtoend/testdata/exec_result/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_rows/sqlc.yaml b/internal/endtoend/testdata/exec_rows/sqlc.yaml index ddffc83..e7fe6ff 100644 --- a/internal/endtoend/testdata/exec_rows/sqlc.yaml +++ b/internal/endtoend/testdata/exec_rows/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml index efbb150..030d33e 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml +++ b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml index 336bca7..018e2db 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml index c20cd57..91a7c07 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml index 6e2cdeb..56644ee 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml index c432e4f..2b8d205 100644 --- a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938" sql: - schema: schema.sql queries: query.sql diff --git a/internal/gen.go b/internal/gen.go index 6e50fae..9cd35b3 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -181,6 +181,40 @@ func (q Query) ArgDictNode() *pyast.Node { } func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { + // Parse the configuration + var conf Config + if len(req.PluginOptions) > 0 { + if err := json.Unmarshal(req.PluginOptions, &conf); err != nil { + log.Printf("failed to parse plugin options: %s", err) + } + } + + // Check for overrides + if len(conf.Overrides) > 0 && col.Table != nil { + tableName := col.Table.Name + if col.Table.Schema != "" && col.Table.Schema != req.Catalog.DefaultSchema { + tableName = col.Table.Schema + "." + tableName + } + + // Look for a matching override + for _, override := range conf.Overrides { + overrideKey := tableName + "." + col.Name + if override.Column == overrideKey { + // Found a match, use the override + typeStr := override.PyType + if override.PyImport != "" && !strings.Contains(typeStr, ".") { + typeStr = override.PyImport + "." + override.PyType + } + return pyType{ + InnerType: typeStr, + IsArray: col.IsArray, + IsNull: !col.NotNull, + } + } + } + } + + // No override found, use the standard type mapping typ := pyInnerType(req, col) return pyType{ InnerType: typ, diff --git a/internal/imports.go b/internal/imports.go index b88c58c..454eefd 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -97,6 +97,20 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS pkg := make(map[string]importSpec) + // Add custom imports from overrides + for _, override := range i.C.Overrides { + if override.PyImport != "" { + // Check if it's a standard module or a package import + if strings.Contains(override.PyImport, ".") { + // It's a package import + pkg[override.PyImport] = importSpec{Module: override.PyImport} + } else { + // It's a standard import + std[override.PyImport] = importSpec{Module: override.PyImport} + } + } + } + return std, pkg } @@ -167,6 +181,20 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map } } + // Add custom imports from overrides for query files + for _, override := range i.C.Overrides { + if override.PyImport != "" { + // Check if it's a standard module or a package import + if strings.Contains(override.PyImport, ".") { + // It's a package import + pkg[override.PyImport] = importSpec{Module: override.PyImport} + } else { + // It's a standard import + std[override.PyImport] = importSpec{Module: override.PyImport} + } + } + } + return std, pkg } pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy