Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/api-reference/sql.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# `narwhals.sql`

::: narwhals.sql
handler: python
options:
members:
- table

::: narwhals.sql.SQLTable
handler: python
options:
members:
- to_sql
show_source: false
show_bases: false
54 changes: 28 additions & 26 deletions docs/generating_sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,62 @@ For example, what's the SQL equivalent to:

```python exec="1" source="above" session="generating-sql"
import narwhals as nw
from narwhals.typing import IntoFrameT
from narwhals.typing import FrameT


def avg_monthly_price(df_native: IntoFrameT) -> IntoFrameT:
def avg_monthly_price(df: FrameT) -> FrameT:
return (
nw.from_native(df_native)
.group_by(nw.col("date").dt.truncate("1mo"))
df.group_by(nw.col("date").dt.truncate("1mo"))
.agg(nw.col("price").mean())
.sort("date")
.to_native()
)
```

?

There are several ways to find out.
Narwhals provides you with a `narwhals.sql` module to do just that!

## Via DuckDB
!!! info
`narwhals.sql` currently requires DuckDB to be installed.

## `narwhals.sql`

You can generate SQL directly from DuckDB.

```python exec="1" source="above" session="generating-sql" result="sql"
import duckdb
import narwhals as nw
from narwhals.sql import table

conn = duckdb.connect()
conn.sql("""CREATE TABLE prices (date DATE, price DOUBLE);""")
prices = table("prices", {"date": nw.Date, "price": nw.Float64})

df = nw.from_native(conn.table("prices"))
print(avg_monthly_price(df).sql_query())
result = (
prices.group_by(nw.col("date").dt.truncate("1mo"))
.agg(nw.col("price").mean())
.sort("date")
)
print(result.to_sql())
```

To make it look a bit prettier, or to then transpile it to other SQL dialects, we can pass it to [SQLGlot](https://github.com/tobymao/sqlglot):
To make it look a bit prettier, you can pass `pretty=True`, but
note that this currently requires [sqlparse](https://github.com/andialbrecht/sqlparse) to be installed.

```python exec="1" source="above" session="generating-sql" result="sql"
import sqlglot

print(sqlglot.transpile(avg_monthly_price(df).sql_query(), pretty=True)[0])
print(result.to_sql(pretty=True))
```

Note that the generated SQL follows DuckDB's dialect. To translate it to other dialects,
you may want to look into [sqlglot](https://github.com/tobymao/sqlglot), or use one of the
solutions below (which also use sqlglot).

## Via Ibis

We can also use Ibis to generate SQL:
You can also use Ibis or SQLFrame to generate SQL:

```python exec="1" source="above" session="generating-sql" result="sql"
import ibis

t = ibis.table({"date": "date", "price": "double"}, name="prices")
print(ibis.to_sql(avg_monthly_price(t)))
df = nw.from_native(ibis.table({"date": "date", "price": "double"}, name="prices"))
print(ibis.to_sql(avg_monthly_price(df).to_native()))
```

## Via SQLFrame
Expand All @@ -66,11 +74,5 @@ session = StandaloneSession.builder.getOrCreate()
session.catalog.add_table("prices", column_mapping={"date": "date", "price": "float"})
df = nw.from_native(session.read.table("prices"))

print(avg_monthly_price(df).sql(dialect="duckdb"))
```

Or, to print the SQL code in a different dialect (say, databricks):

```python exec="1" source="above" session="generating-sql" result="sql"
print(avg_monthly_price(df).sql(dialect="databricks"))
print(avg_monthly_price(df).to_native().sql(dialect="duckdb"))
```
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ nav:
- api-reference/dtypes.md
- api-reference/exceptions.md
- api-reference/selectors.md
- api-reference/sql.md
- api-reference/testing.md
- api-reference/typing.md
- api-reference/utils.md
Expand Down
106 changes: 106 additions & 0 deletions narwhals/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from narwhals._duckdb.utils import DeferredTimeZone, narwhals_to_native_dtype
from narwhals.dataframe import LazyFrame
from narwhals.translate import from_native
from narwhals.utils import Version

if TYPE_CHECKING:
from narwhals._compliant.typing import CompliantLazyFrameAny
from narwhals.typing import IntoSchema

try:
import duckdb # ignore-banned-import
except ImportError as _exc: # pragma: no cover
msg = (
"`narwhals.sql` requires DuckDB to be installed.\n\n"
"Hint: run `pip install -U narwhals[sql]`"
)
raise ModuleNotFoundError(msg) from _exc

CONN = duckdb.connect()
TZ = DeferredTimeZone(
CONN.sql("select value from duckdb_settings() where name = 'TimeZone'")
)
Comment on lines +14 to +26
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think (hope?) that in the future we could make a narwhals-sqlglot or narwhals-substrait plugin and use that here. But for now, I think using DuckDB for this is quite nice



class SQLTable(LazyFrame[duckdb.DuckDBPyRelation]):
"""A LazyFrame with an additional `to_sql` method."""

def __init__(
self, df: CompliantLazyFrameAny, level: Literal["full", "interchange", "lazy"]
) -> None:
super().__init__(df, level=level)

def to_sql(self, *, pretty: bool = False) -> str:
"""Convert to SQL query.

Arguments:
pretty: Whether to pretty-print SQL query. If `True`, requires `sqlparse`
to be installed.

Examples:
>>> import narwhals as nw
>>> from narwhals.sql import table
>>> schema = {"date": nw.Date, "price": nw.Int64, "symbol": nw.String}
>>> assets = table("assets", schema)
>>> result = assets.filter(nw.col("price") > 100)
>>> print(result.to_sql())
SELECT * FROM main.assets WHERE (price > 100)
"""
sql_query = self.to_native().sql_query()
if not pretty:
return sql_query
try:
import sqlparse
except ImportError as _exc: # pragma: no cover
msg = (
"`SQLTable.to_sql` with `pretty=True`"
"requires `sqlparse` to be installed.\n\n"
"Hint: run `pip install -U narwhals[sql]`"
)
raise ModuleNotFoundError(msg) from _exc
return sqlparse.format(sql_query, reindent=True, keyword_case="upper") # type: ignore[no-any-return]


def table(name: str, schema: IntoSchema) -> SQLTable:
"""Generate standalone LazyFrame which you can use to generate SQL.

Note that this requires DuckDB to be installed.

Parameters:
name: Table name.
schema: Table schema.

Examples:
>>> import narwhals as nw
>>> from narwhals.sql import table
>>> schema = {"date": nw.Date, "price": nw.List(nw.Int64), "symbol": nw.String}
>>> table("t", schema)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
| Narwhals LazyFrame |
|----------------------------|
|β”Œβ”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”|
|β”‚ date β”‚ price β”‚ symbol β”‚|
|β”‚ date β”‚ int64[] β”‚ varchar β”‚|
|β”œβ”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€|
|β”‚ 0 rows β”‚|
|β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
column_mapping = {
col: narwhals_to_native_dtype(dtype, Version.MAIN, TZ)
for col, dtype in schema.items()
}
dtypes = ", ".join(f'"{col}" {dtype}' for col, dtype in column_mapping.items())
CONN.sql(f"""
CREATE TABLE "{name}"
({dtypes});
""")
lf = from_native(CONN.table(name))
return SQLTable(lf._compliant_frame, level=lf._level)


__all__ = ["table"]
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ dask = ["dask[dataframe]>=2024.8"]
duckdb = ["duckdb>=1.1"]
ibis = ["ibis-framework>=6.0.0", "rich", "packaging", "pyarrow_hotfix"]
sqlframe = ["sqlframe>=3.22.0,!=3.39.3"]
sql = ["duckdb>=1.1", "sqlparse"]

[dependency-groups]
core = [
"narwhals[duckdb,pandas,polars,pyarrow,sqlframe]"
"narwhals[duckdb,pandas,polars,pyarrow,sqlframe,sql]"
]
tests = [
"covdefaults",
Expand Down Expand Up @@ -85,6 +86,7 @@ docs = [
"jinja2",
"duckdb",
"narwhals[ibis]",
"narwhals[sql]",
"markdown-exec[ansi]",
"mkdocs",
"mkdocs-autorefs",
Expand Down Expand Up @@ -341,6 +343,7 @@ module = [
"numpy.*",
"pyspark.*",
"sklearn.*",
"sqlparse.*",
]
# TODO: remove follow_imports
follow_imports = "skip"
Expand Down
32 changes: 32 additions & 0 deletions tests/sql_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

import pytest

import narwhals as nw
from tests.utils import DUCKDB_VERSION


def test_sql() -> None:
pytest.importorskip("duckdb")
pytest.importorskip("sqlparse")
if DUCKDB_VERSION < (1, 3):
pytest.skip()
from narwhals.sql import table

schema = {"date": nw.Date(), "price": nw.Int64(), "symbol": nw.String()}
assets = table("assets", schema)
result = assets.with_columns(
returns=(nw.col("price") / nw.col("price").shift(1)).over(
"symbol", order_by="date"
)
)
expected = """SELECT date, price, symbol, (price / lag(price, 1) OVER (PARTITION BY symbol ORDER BY date ASC NULLS FIRST)) AS "returns" FROM main.assets"""
assert result.to_sql() == expected
expected = (
"SELECT date, price,\n"
" symbol,\n"
" (price / lag(price, 1) OVER (PARTITION BY symbol\n"
' ORDER BY date ASC NULLS FIRST)) AS "returns"\n'
"FROM main.assets"
)
assert result.to_sql(pretty=True) == expected
Loading