From de3edcf609458edecea123d9dab093e1f2a45803 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 29 Oct 2025 18:02:58 +0000 Subject: [PATCH 01/11] document sql --- docs/api-reference/sql.md | 7 ++++ docs/generating_sql.md | 60 +++++++++++++++++------------- mkdocs.yml | 1 + narwhals/sql.py | 77 +++++++++++++++++++++++++++++++++++++++ tests/sql_test.py | 24 ++++++++++++ 5 files changed, 143 insertions(+), 26 deletions(-) create mode 100644 docs/api-reference/sql.md create mode 100644 narwhals/sql.py create mode 100644 tests/sql_test.py diff --git a/docs/api-reference/sql.md b/docs/api-reference/sql.md new file mode 100644 index 0000000000..2da48e6a1a --- /dev/null +++ b/docs/api-reference/sql.md @@ -0,0 +1,7 @@ +# `narwhals.sql` + +::: narwhals.sql + handler: python + options: + members: + - table diff --git a/docs/generating_sql.md b/docs/generating_sql.md index 6c0198ccce..0f006932e1 100644 --- a/docs/generating_sql.md +++ b/docs/generating_sql.md @@ -5,54 +5,68 @@ For example, what's the SQL equivalent to: ```python exec="1" source="above" session="generating-sql" import narwhals as nw -from narwhals.typing import IntoFrameT +from narwhals.typing import FrameT -def avg_monthly_price(df_native: IntoFrameT) -> IntoFrameT: +def avg_monthly_price(df: FrameT) -> FrameT: return ( - nw.from_native(df_native) - .group_by(nw.col("date").dt.truncate("1mo")) + df.group_by(nw.col("date").dt.truncate("1mo")) .agg(nw.col("price").mean()) .sort("date") - .to_native() ) ``` ? -There are several ways to find out. +Narwhals provides you with a `narwhals.sql` module to do just that! -## Via DuckDB +!!! info + `narwhals.sql` currently requires DuckDB to be installed. + +## `narwhals.sql` You can generate SQL directly from DuckDB. 
```python exec="1" source="above" session="generating-sql" result="sql" -import duckdb - -conn = duckdb.connect() -conn.sql("""CREATE TABLE prices (date DATE, price DOUBLE);""") - -df = nw.from_native(conn.table("prices")) -print(avg_monthly_price(df).sql_query()) +import narwhals as nw +from narwhals.sql import table + +prices = table("prices", {"date": nw.Date, "price": nw.Float64}) + +sql_query = ( + prices.group_by(nw.col("date").dt.truncate("1mo")) + .agg(nw.col("price").mean()) + .sort("date") + .to_native() + .sql_query() +) +print(sql_query) ``` -To make it look a bit prettier, or to then transpile it to other SQL dialects, we can pass it to [SQLGlot](https://github.com/tobymao/sqlglot): +To make it look a bit prettier, or to then transpile it to other SQL dialects, you can pass it to [SQLGlot](https://github.com/tobymao/sqlglot): ```python exec="1" source="above" session="generating-sql" result="sql" import sqlglot -print(sqlglot.transpile(avg_monthly_price(df).sql_query(), pretty=True)[0]) +print(sqlglot.transpile(sql_query, pretty=True)[0]) ``` +You can even pass a [different dialect](https://github.com/tobymao/sqlglot?tab=readme-ov-file#supported-dialects): + +```python exec="1" source="above" session="generating-sql" result="sql" +print(sqlglot.transpile(sql_query, pretty=True, dialect="databricks")[0]) +``` + + ## Via Ibis -We can also use Ibis to generate SQL: +You can also use Ibis or SQLFrame to generate SQL: ```python exec="1" source="above" session="generating-sql" result="sql" import ibis -t = ibis.table({"date": "date", "price": "double"}, name="prices") -print(ibis.to_sql(avg_monthly_price(t))) +df = nw.from_native(ibis.table({"date": "date", "price": "double"}, name="prices")) +print(ibis.to_sql(avg_monthly_price(df).to_native())) ``` ## Via SQLFrame @@ -66,11 +80,5 @@ session = StandaloneSession.builder.getOrCreate() session.catalog.add_table("prices", column_mapping={"date": "date", "price": "float"}) df = 
nw.from_native(session.read.table("prices")) -print(avg_monthly_price(df).sql(dialect="duckdb")) -``` - -Or, to print the SQL code in a different dialect (say, databricks): - -```python exec="1" source="above" session="generating-sql" result="sql" -print(avg_monthly_price(df).sql(dialect="databricks")) +print(avg_monthly_price(df).to_native().sql(dialect="duckdb")) ``` diff --git a/mkdocs.yml b/mkdocs.yml index c02f227f87..8b281675ed 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,6 +69,7 @@ nav: - api-reference/dtypes.md - api-reference/exceptions.md - api-reference/selectors.md + - api-reference/sql.md - api-reference/testing.md - api-reference/typing.md - api-reference/utils.md diff --git a/narwhals/sql.py b/narwhals/sql.py new file mode 100644 index 0000000000..34939e0c97 --- /dev/null +++ b/narwhals/sql.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._duckdb.utils import narwhals_to_native_dtype +from narwhals.translate import from_native +from narwhals.utils import Version + +if TYPE_CHECKING: + from duckdb import DuckDBPyRelation + + from narwhals.dataframe import LazyFrame + from narwhals.typing import IntoSchema + +try: + import duckdb # ignore-banned-import +except ImportError as exc: # pragma: no cover + msg = ( + "`narwhals.sql` requires DuckDB to be installed.\n\n" + "Hint: run `pip install -U narwhals[sql]`" + ) + raise ModuleNotFoundError(msg) from exc + +conn = duckdb.connect() +tz = conn.sql("select value from duckdb_settings() where name = 'TimeZone'").fetchone()[0] + + +def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]: + """Generate standalone LazyFrame which you can use to generate SQL. + + Note that this requires DuckDB to be installed. + + Parameters: + name: Table name. + schema: Table schema. + + Returns: + A LazyFrame. 
+ + Examples: + >>> import narwhals as nw + >>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} + >>> assets = nw.sql.table("assets", schema) + >>> result = assets.with_columns( + ... nw.col("price").rolling_mean(5).over("symbol", order_by="date") + ... ) + >>> print(result.to_native().sql(dialect="duckdb")) + SELECT + "assets"."date" AS "date", + CASE + WHEN COUNT("assets"."price") OVER ( + PARTITION BY "assets"."symbol" + ORDER BY "assets"."date" NULLS FIRST + ROWS BETWEEN 4 PRECEDING AND CURRENT ROW + ) >= 5 + THEN AVG("assets"."price") OVER ( + PARTITION BY "assets"."symbol" + ORDER BY "assets"."date" NULLS FIRST + ROWS BETWEEN 4 PRECEDING AND CURRENT ROW + ) + END AS "price", + "assets"."symbol" AS "symbol" + FROM "assets" AS "assets" + """ + column_mapping = { + col: narwhals_to_native_dtype(dtype, Version.MAIN, tz) + for col, dtype in schema.items() + } + dtypes = ", ".join(f"{col} {dtype}" for col, dtype in column_mapping.items()) + conn.sql(f""" + CREATE TABLE "{name}" + ({dtypes}); + """) + return from_native(conn.table(name)) + + +__all__ = ["table"] diff --git a/tests/sql_test.py b/tests/sql_test.py new file mode 100644 index 0000000000..4181b6cf96 --- /dev/null +++ b/tests/sql_test.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import pytest + +import narwhals as nw + + +def test_sql() -> None: + pytest.importorskip("duckdb") + from narwhals.sql import table + + schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} + assets = table("assets", schema) + result = ( + assets.with_columns( + returns=(nw.col("price") / nw.col("price").shift(1)).over( + "symbol", order_by="date" + ) + ) + .to_native() + .sql_query() + ) + expected = """SELECT date, price, symbol, (price / lag(price, 1) OVER (PARTITION BY symbol ORDER BY date ASC NULLS FIRST)) AS "returns" FROM main.assets""" + assert result == expected From 703fcb5a05ab9f53b9c450710162d3d5c117a662 Mon Sep 17 00:00:00 2001 From: Marco Gorelli 
<33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 29 Oct 2025 18:06:04 +0000 Subject: [PATCH 02/11] doctest --- narwhals/sql.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/narwhals/sql.py b/narwhals/sql.py index 34939e0c97..31f21856d0 100644 --- a/narwhals/sql.py +++ b/narwhals/sql.py @@ -39,28 +39,12 @@ def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]: Examples: >>> import narwhals as nw + >>> from narwhals.sql import table >>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} - >>> assets = nw.sql.table("assets", schema) - >>> result = assets.with_columns( - ... nw.col("price").rolling_mean(5).over("symbol", order_by="date") - ... ) - >>> print(result.to_native().sql(dialect="duckdb")) - SELECT - "assets"."date" AS "date", - CASE - WHEN COUNT("assets"."price") OVER ( - PARTITION BY "assets"."symbol" - ORDER BY "assets"."date" NULLS FIRST - ROWS BETWEEN 4 PRECEDING AND CURRENT ROW - ) >= 5 - THEN AVG("assets"."price") OVER ( - PARTITION BY "assets"."symbol" - ORDER BY "assets"."date" NULLS FIRST - ROWS BETWEEN 4 PRECEDING AND CURRENT ROW - ) - END AS "price", - "assets"."symbol" AS "symbol" - FROM "assets" AS "assets" + >>> assets = table("assets", schema) + >>> result = assets.with_columns(price_2=nw.col("price") * 2) + >>> print(result.to_native().sql_query()) + SELECT date, price, symbol, (price * 2) AS price_2 FROM main.assets """ column_mapping = { col: narwhals_to_native_dtype(dtype, Version.MAIN, tz) From f045a7ca9d579d2835c017e5df1ec31f2075eee8 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:10:37 +0000 Subject: [PATCH 03/11] fix test and typing --- narwhals/sql.py | 14 +++++++------- tests/sql_test.py | 3 +++ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/narwhals/sql.py b/narwhals/sql.py index 31f21856d0..97594933b0 100644 --- a/narwhals/sql.py +++ b/narwhals/sql.py @@ 
-21,8 +21,8 @@ ) raise ModuleNotFoundError(msg) from exc -conn = duckdb.connect() -tz = conn.sql("select value from duckdb_settings() where name = 'TimeZone'").fetchone()[0] +CONN = duckdb.connect() +TZ = CONN.sql("select value from duckdb_settings() where name = 'TimeZone'").fetchone()[0] # type: ignore[index] def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]: @@ -42,20 +42,20 @@ def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]: >>> from narwhals.sql import table >>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} >>> assets = table("assets", schema) - >>> result = assets.with_columns(price_2=nw.col("price") * 2) + >>> result = assets.filter(nw.col("price") > 100) >>> print(result.to_native().sql_query()) - SELECT date, price, symbol, (price * 2) AS price_2 FROM main.assets + SELECT * FROM main.assets WHERE (price > 100) """ column_mapping = { - col: narwhals_to_native_dtype(dtype, Version.MAIN, tz) + col: narwhals_to_native_dtype(dtype, Version.MAIN, TZ) for col, dtype in schema.items() } dtypes = ", ".join(f"{col} {dtype}" for col, dtype in column_mapping.items()) - conn.sql(f""" + CONN.sql(f""" CREATE TABLE "{name}" ({dtypes}); """) - return from_native(conn.table(name)) + return from_native(CONN.table(name)) __all__ = ["table"] diff --git a/tests/sql_test.py b/tests/sql_test.py index 4181b6cf96..7b3778a84b 100644 --- a/tests/sql_test.py +++ b/tests/sql_test.py @@ -3,10 +3,13 @@ import pytest import narwhals as nw +from tests.utils import DUCKDB_VERSION def test_sql() -> None: pytest.importorskip("duckdb") + if DUCKDB_VERSION < (1, 3): + pytest.skip() from narwhals.sql import table schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} From 785c64e20c1e33ffd5850b9aebaea50f40be3a13 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:27:23 +0000 Subject: [PATCH 04/11] typing in test --- tests/sql_test.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sql_test.py b/tests/sql_test.py index 7b3778a84b..0d9ffe25c0 100644 --- a/tests/sql_test.py +++ b/tests/sql_test.py @@ -12,7 +12,7 @@ def test_sql() -> None: pytest.skip() from narwhals.sql import table - schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} + schema = {"date": nw.Date(), "price": nw.Int64(), "symbol": nw.String()} assets = table("assets", schema) result = ( assets.with_columns( From 5225fe8a71d7e610b1cf3e679a04741489e58120 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:56:17 +0000 Subject: [PATCH 05/11] type completeness --- narwhals/sql.py | 10 ++++++---- pyproject.toml | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/narwhals/sql.py b/narwhals/sql.py index 97594933b0..9df4537f07 100644 --- a/narwhals/sql.py +++ b/narwhals/sql.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from narwhals._duckdb.utils import narwhals_to_native_dtype +from narwhals._duckdb.utils import DeferredTimeZone, narwhals_to_native_dtype from narwhals.translate import from_native from narwhals.utils import Version @@ -14,15 +14,17 @@ try: import duckdb # ignore-banned-import -except ImportError as exc: # pragma: no cover +except ImportError as _exc: # pragma: no cover msg = ( "`narwhals.sql` requires DuckDB to be installed.\n\n" "Hint: run `pip install -U narwhals[sql]`" ) - raise ModuleNotFoundError(msg) from exc + raise ModuleNotFoundError(msg) from _exc CONN = duckdb.connect() -TZ = CONN.sql("select value from duckdb_settings() where name = 'TimeZone'").fetchone()[0] # type: ignore[index] +TZ = DeferredTimeZone( + CONN.sql("select value from duckdb_settings() where name = 'TimeZone'") +) def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]: diff --git a/pyproject.toml b/pyproject.toml index fd6fc2a5e1..ed5692307b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dask = 
["dask[dataframe]>=2024.8"] duckdb = ["duckdb>=1.1"] ibis = ["ibis-framework>=6.0.0", "rich", "packaging", "pyarrow_hotfix"] sqlframe = ["sqlframe>=3.22.0,!=3.39.3"] +sql = ["duckdb>=1.1"] [dependency-groups] core = [ From 8deb0a0e51de2c86cf77474c72a0796bb01e42b8 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:36:14 +0000 Subject: [PATCH 06/11] wip --- docs/generating_sql.md | 15 ++++++-------- narwhals/_duckdb/group_by.py | 13 ++++++++---- narwhals/sql.py | 40 ++++++++++++++++++++++++++++-------- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/docs/generating_sql.md b/docs/generating_sql.md index 0f006932e1..bdf64ccd98 100644 --- a/docs/generating_sql.md +++ b/docs/generating_sql.md @@ -33,28 +33,25 @@ from narwhals.sql import table prices = table("prices", {"date": nw.Date, "price": nw.Float64}) -sql_query = ( +result = ( prices.group_by(nw.col("date").dt.truncate("1mo")) .agg(nw.col("price").mean()) .sort("date") - .to_native() - .sql_query() ) -print(sql_query) +print(result.to_sql()) ``` -To make it look a bit prettier, or to then transpile it to other SQL dialects, you can pass it to [SQLGlot](https://github.com/tobymao/sqlglot): +To make it look a bit prettier, or to then transpile it to other SQL dialects, you can pass `pretty=True`, but +note that this currently requires [sqlglot](https://github.com/tobymao/sqlglot) to be installed. 
```python exec="1" source="above" session="generating-sql" result="sql" -import sqlglot - -print(sqlglot.transpile(sql_query, pretty=True)[0]) +print(result.to_sql(pretty=True)) ``` You can even pass a [different dialect](https://github.com/tobymao/sqlglot?tab=readme-ov-file#supported-dialects): ```python exec="1" source="above" session="generating-sql" result="sql" -print(sqlglot.transpile(sql_query, pretty=True, dialect="databricks")[0]) +print(result.to_sql(pretty=True, dialect="databricks")) ``` diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py index f2a36e896a..fb0e9267f6 100644 --- a/narwhals/_duckdb/group_by.py +++ b/narwhals/_duckdb/group_by.py @@ -3,6 +3,7 @@ from itertools import chain from typing import TYPE_CHECKING +from narwhals._duckdb.utils import join_column_names from narwhals._sql.group_by import SQLGroupBy if TYPE_CHECKING: @@ -27,7 +28,11 @@ def __init__( self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame def agg(self, *exprs: DuckDBExpr) -> DuckDBLazyFrame: - agg_columns = tuple(chain(self._keys, self._evaluate_exprs(exprs))) - return self.compliant._with_native( - self.compliant.native.aggregate(agg_columns) # type: ignore[arg-type] - ).rename(dict(zip(self._keys, self._output_key_names))) + agg_columns = tuple(self._evaluate_exprs(exprs)) + result = self.compliant.native.aggregate( + tuple(chain(self._keys, agg_columns)), join_column_names(*self._keys) + ) + + return self.compliant._with_native(result).rename( + dict(zip(self._keys, self._output_key_names)) + ) diff --git a/narwhals/sql.py b/narwhals/sql.py index 9df4537f07..44f28eb6a5 100644 --- a/narwhals/sql.py +++ b/narwhals/sql.py @@ -1,15 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from narwhals._duckdb.utils import DeferredTimeZone, narwhals_to_native_dtype +from narwhals.dataframe import LazyFrame from narwhals.translate import from_native from 
narwhals.utils import Version if TYPE_CHECKING: - from duckdb import DuckDBPyRelation - - from narwhals.dataframe import LazyFrame + from narwhals._compliant.typing import CompliantLazyFrameAny from narwhals.typing import IntoSchema try: @@ -27,7 +26,31 @@ ) -def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]: +class SQLTable(LazyFrame["DuckDBPyRelation"]): + def __init__( + self, df: CompliantLazyFrameAny, level: Literal["full", "interchange", "lazy"] + ) -> None: + super().__init__(df, level=level) + + def to_sql(self, *, pretty: bool = False, dialect: str = "duckdb") -> str: + sql_query = self.to_native().sql_query() + if not pretty and dialect == "duckdb": + return sql_query + try: + import sqlglot + except ImportError as _exc: # pragma: no cover + msg = ( + "`SQLTable.to_sql` with `pretty=True` or `dialect!='duckdb'` " + "requires `sqlglot` to be installed.\n\n" + "Hint: run `pip install -U narwhals[sql]`" + ) + raise ModuleNotFoundError(msg) from _exc + return sqlglot.transpile( + sql_query, read="duckdb", identity=False, write=dialect, pretty=pretty + )[0] + + +def table(name: str, schema: IntoSchema) -> SQLTable: """Generate standalone LazyFrame which you can use to generate SQL. Note that this requires DuckDB to be installed. 
@@ -45,19 +68,20 @@ def table(name: str, schema: IntoSchema) -> LazyFrame[DuckDBPyRelation]: >>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} >>> assets = table("assets", schema) >>> result = assets.filter(nw.col("price") > 100) - >>> print(result.to_native().sql_query()) + >>> print(result.to_sql()) SELECT * FROM main.assets WHERE (price > 100) """ column_mapping = { col: narwhals_to_native_dtype(dtype, Version.MAIN, TZ) for col, dtype in schema.items() } - dtypes = ", ".join(f"{col} {dtype}" for col, dtype in column_mapping.items()) + dtypes = ", ".join(f'"{col}" {dtype}' for col, dtype in column_mapping.items()) CONN.sql(f""" CREATE TABLE "{name}" ({dtypes}); """) - return from_native(CONN.table(name)) + lf = from_native(CONN.table(name)) + return SQLTable(lf._compliant_frame, level=lf._level) __all__ = ["table"] From 7384addeada01bbd2602f8ad1b333067d4faaeb9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:15:26 +0000 Subject: [PATCH 07/11] wip --- narwhals/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/sql.py b/narwhals/sql.py index 44f28eb6a5..6b329a21da 100644 --- a/narwhals/sql.py +++ b/narwhals/sql.py @@ -26,7 +26,7 @@ ) -class SQLTable(LazyFrame["DuckDBPyRelation"]): +class SQLTable(LazyFrame[duckdb.DuckDBPyRelation]): def __init__( self, df: CompliantLazyFrameAny, level: Literal["full", "interchange", "lazy"] ) -> None: From 836424a65560140dd9d68e130ef076f767c40718 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 1 Nov 2025 08:42:29 +0000 Subject: [PATCH 08/11] update docs --- docs/api-reference/sql.md | 8 ++++++ docs/generating_sql.md | 11 ++------- narwhals/sql.py | 51 +++++++++++++++++++++++++++------------ pyproject.toml | 2 +- 4 files changed, 46 insertions(+), 26 deletions(-) diff --git a/docs/api-reference/sql.md b/docs/api-reference/sql.md index 
2da48e6a1a..c74aaa5a48 100644 --- a/docs/api-reference/sql.md +++ b/docs/api-reference/sql.md @@ -5,3 +5,11 @@ options: members: - table + +::: narwhals.sql.SQLTable + handler: python + options: + members: + - to_sql + show_source: false + show_bases: false diff --git a/docs/generating_sql.md b/docs/generating_sql.md index bdf64ccd98..69826ff970 100644 --- a/docs/generating_sql.md +++ b/docs/generating_sql.md @@ -41,20 +41,13 @@ result = ( print(result.to_sql()) ``` -To make it look a bit prettier, or to then transpile it to other SQL dialects, you can pass `pretty=True`, but -note that this currently requires [sqlglot](https://github.com/tobymao/sqlglot) to be installed. +To make it look a bit prettier, you can pass `pretty=True`, but +note that this currently requires [sqlparse](https://github.com/andialbrecht/sqlparse) to be installed. ```python exec="1" source="above" session="generating-sql" result="sql" print(result.to_sql(pretty=True)) ``` -You can even pass a [different dialect](https://github.com/tobymao/sqlglot?tab=readme-ov-file#supported-dialects): - -```python exec="1" source="above" session="generating-sql" result="sql" -print(result.to_sql(pretty=True, dialect="databricks")) -``` - - ## Via Ibis You can also use Ibis or SQLFrame to generate SQL: diff --git a/narwhals/sql.py b/narwhals/sql.py index 6b329a21da..c8fe6020cc 100644 --- a/narwhals/sql.py +++ b/narwhals/sql.py @@ -27,27 +27,42 @@ class SQLTable(LazyFrame[duckdb.DuckDBPyRelation]): + """A LazyFrame with an additional `to_sql` method.""" + def __init__( self, df: CompliantLazyFrameAny, level: Literal["full", "interchange", "lazy"] ) -> None: super().__init__(df, level=level) - def to_sql(self, *, pretty: bool = False, dialect: str = "duckdb") -> str: + def to_sql(self, *, pretty: bool = False) -> str: + """Convert to SQL query. + + Arguments: + pretty: Whether to pretty-print SQL query. If `True`, requires `sqlparse` + to be installed. 
+ + Examples: + >>> import narwhals as nw + >>> from narwhals.sql import table + >>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} + >>> assets = table("assets", schema) + >>> result = assets.filter(nw.col("price") > 100) + >>> print(result.to_sql()) + SELECT * FROM main.assets WHERE (price > 100) + """ sql_query = self.to_native().sql_query() - if not pretty and dialect == "duckdb": + if not pretty: return sql_query try: - import sqlglot + import sqlparse except ImportError as _exc: # pragma: no cover msg = ( - "`SQLTable.to_sql` with `pretty=True` or `dialect!='duckdb'` " - "requires `sqlglot` to be installed.\n\n" + "`SQLTable.to_sql` with `pretty=True` " + "requires `sqlparse` to be installed.\n\n" "Hint: run `pip install -U narwhals[sql]`" ) raise ModuleNotFoundError(msg) from _exc - return sqlglot.transpile( - sql_query, read="duckdb", identity=False, write=dialect, pretty=pretty - )[0] + return sqlparse.format(sql_query, reindent=True, keyword_case="upper") def table(name: str, schema: IntoSchema) -> SQLTable: @@ -59,17 +74,21 @@ def table(name: str, schema: IntoSchema) -> SQLTable: name: Table name. schema: Table schema. - Returns: - A LazyFrame. 
- Examples: >>> import narwhals as nw >>> from narwhals.sql import table - >>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String} - >>> assets = table("assets", schema) - >>> result = assets.filter(nw.col("price") > 100) - >>> print(result.to_sql()) - SELECT * FROM main.assets WHERE (price > 100) + >>> schema = {"date": nw.Date, "price": nw.List(nw.Int64), "symbol": nw.String} + >>> table("t", schema) + ┌────────────────────────────┐ + | Narwhals LazyFrame | + |----------------------------| + |┌──────┬─────────┬─────────┐| + |│ date │ price │ symbol │| + |│ date │ int64[] │ varchar │| + |├──────┴─────────┴─────────┤| + |│ 0 rows │| + |└──────────────────────────┘| + └────────────────────────────┘ """ column_mapping = { col: narwhals_to_native_dtype(dtype, Version.MAIN, TZ) diff --git a/pyproject.toml b/pyproject.toml index c38442b199..dee8c679f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dask = ["dask[dataframe]>=2024.8"] duckdb = ["duckdb>=1.1"] ibis = ["ibis-framework>=6.0.0", "rich", "packaging", "pyarrow_hotfix"] sqlframe = ["sqlframe>=3.22.0,!=3.39.3"] -sql = ["duckdb>=1.1"] +sql = ["duckdb>=1.1", "sqlparse"] [dependency-groups] core = [ From 651dbe31ecfb8eb8c19dffd05a18f6bbdf566199 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 1 Nov 2025 09:38:45 +0000 Subject: [PATCH 09/11] update --- narwhals/sql.py | 2 +- pyproject.toml | 4 +++- tests/sql_test.py | 20 ++++++++++++-------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/narwhals/sql.py b/narwhals/sql.py index c8fe6020cc..d881b0d500 100644 --- a/narwhals/sql.py +++ b/narwhals/sql.py @@ -62,7 +62,7 @@ def to_sql(self, *, pretty: bool = False) -> str: "Hint: run `pip install -U narwhals[sql]`" ) raise ModuleNotFoundError(msg) from _exc - return sqlparse.format(sql_query, reindent=True, keyword_case="upper") + return sqlparse.format(sql_query, reindent=True, keyword_case="upper") # type: 
ignore[no-any-return] def table(name: str, schema: IntoSchema) -> SQLTable: diff --git a/pyproject.toml b/pyproject.toml index dee8c679f3..4b6f34e200 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ sql = ["duckdb>=1.1", "sqlparse"] [dependency-groups] core = [ - "narwhals[duckdb,pandas,polars,pyarrow,sqlframe]" + "narwhals[duckdb,pandas,polars,pyarrow,sqlframe,sql]" ] tests = [ "covdefaults", @@ -86,6 +86,7 @@ docs = [ "jinja2", "duckdb", "narwhals[ibis]", + "narwhals[sql]", "markdown-exec[ansi]", "mkdocs", "mkdocs-autorefs", @@ -342,6 +343,7 @@ module = [ "numpy.*", "pyspark.*", "sklearn.*", + "sqlparse.*", ] # TODO: remove follow_imports follow_imports = "skip" diff --git a/tests/sql_test.py b/tests/sql_test.py index 0d9ffe25c0..05520b193a 100644 --- a/tests/sql_test.py +++ b/tests/sql_test.py @@ -14,14 +14,18 @@ def test_sql() -> None: schema = {"date": nw.Date(), "price": nw.Int64(), "symbol": nw.String()} assets = table("assets", schema) - result = ( - assets.with_columns( - returns=(nw.col("price") / nw.col("price").shift(1)).over( - "symbol", order_by="date" - ) + result = assets.with_columns( + returns=(nw.col("price") / nw.col("price").shift(1)).over( + "symbol", order_by="date" ) - .to_native() - .sql_query() ) expected = """SELECT date, price, symbol, (price / lag(price, 1) OVER (PARTITION BY symbol ORDER BY date ASC NULLS FIRST)) AS "returns" FROM main.assets""" - assert result == expected + assert result.to_sql() == expected + expected = ( + "SELECT date, price,\n" + " symbol,\n" + " (price / lag(price, 1) OVER (PARTITION BY symbol\n" + ' ORDER BY date ASC NULLS FIRST)) AS "returns"\n' + "FROM main.assets" + ) + assert result.to_sql(pretty=True) == expected From a12842be6a38bffbfc640da5e600df4ea7c38a1e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 1 Nov 2025 09:51:18 +0000 Subject: [PATCH 10/11] skip if no sqlparse --- tests/sql_test.py | 1 + 1 file changed, 1 insertion(+) 
diff --git a/tests/sql_test.py b/tests/sql_test.py index 05520b193a..1f9db2728d 100644 --- a/tests/sql_test.py +++ b/tests/sql_test.py @@ -8,6 +8,7 @@ def test_sql() -> None: pytest.importorskip("duckdb") + pytest.importorskip("sqlparse") if DUCKDB_VERSION < (1, 3): pytest.skip() from narwhals.sql import table From 07c4447d617813336825ef2acb59482b8c941f99 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 1 Nov 2025 09:53:05 +0000 Subject: [PATCH 11/11] add sqlglot note --- docs/generating_sql.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/generating_sql.md b/docs/generating_sql.md index 69826ff970..aedf0f8365 100644 --- a/docs/generating_sql.md +++ b/docs/generating_sql.md @@ -48,6 +48,10 @@ note that this currently requires [sqlparse](https://github.com/andialbrecht/sql print(result.to_sql(pretty=True)) ``` +Note that the generated SQL follows DuckDB's dialect. To translate it to other dialects, +you may want to look into [sqlglot](https://github.com/tobymao/sqlglot), or use one of the +solutions below (which also use sqlglot). + ## Via Ibis You can also use Ibis or SQLFrame to generate SQL: