Skip to content

Commit f18a191

Browse files
Merge pull request #72 from InfluxCommunity/67-support-polars-conversion
67 support polars conversion
2 parents 436930d + 05b04fe commit f18a191

File tree

5 files changed

+171
-4
lines changed

5 files changed

+171
-4
lines changed

influxdb_client_3/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from pyarrow.flight import FlightClient, Ticket, FlightCallOptions
88
from influxdb_client_3.read_file import UploadFile
99
import urllib.parse
10+
try:
11+
import polars as pl
12+
polars = True
13+
except ImportError:
14+
polars = False
1015

1116

1217

@@ -216,7 +221,10 @@ def query(self, query, language="sql", mode="all", database=None,**kwargs ):
216221
:param kwargs: FlightClientCallOptions for the query.
217222
:return: The queried data.
218223
"""
224+
if mode == "polars" and polars is False:
225+
raise ImportError("Polars is not installed. Please install it with `pip install polars`.")
219226

227+
220228

221229
if database is None:
222230
database = self._database
@@ -237,9 +245,11 @@ def query(self, query, language="sql", mode="all", database=None,**kwargs ):
237245
mode_func = {
238246
"all": flight_reader.read_all,
239247
"pandas": flight_reader.read_pandas,
248+
"polars": lambda: pl.from_arrow(flight_reader.read_all()),
240249
"chunk": lambda: flight_reader,
241250
"reader": flight_reader.to_reader,
242251
"schema": lambda: flight_reader.schema
252+
243253
}.get(mode, flight_reader.read_all)
244254

245255
return mode_func() if callable(mode_func) else mode_func

influxdb_client_3/write_client/client/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(self, influxdb_client, point_settings=None):
219219
self._point_settings.add_default_tag(key, value)
220220

221221
def _append_default_tag(self, key, val, record):
222-
from write_client import Point
222+
from influxdb_client_3.write_client import Point
223223
if isinstance(record, bytes) or isinstance(record, str):
224224
pass
225225
elif isinstance(record, Point):

influxdb_client_3/write_client/client/write/dataframe_serializer.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,132 @@ def number_of_chunks(self):
280280
return self.number_of_chunks
281281

282282

283+
class PolarsDataframeSerializer:
    """Serialize a Polars DataFrame into InfluxDB line protocol."""

    # Map of supported write precisions to the Polars epoch time unit.
    # ``None`` falls back to nanoseconds, the InfluxDB default.
    _PRECISION_TO_TIME_UNIT = {None: "ns", "ns": "ns", "us": "us", "ms": "ms", "s": "s"}

    def __init__(self, data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, chunk_size: int = None,
                 **kwargs) -> None:
        """
        Init serializer.

        :param data_frame: Polars DataFrame to serialize
        :param point_settings: Default Tags
        :param precision: The precision for the unix timestamps within the body line-protocol.
        :param chunk_size: The size of chunk for serializing into chunks.
        :key data_frame_measurement_name: name of measurement for writing Polars DataFrame
        :key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
        :key data_frame_timestamp_column: name of DataFrame column which contains a timestamp.
        :key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column
        """
        self.data_frame = data_frame
        self.point_settings = point_settings
        self.precision = precision
        self.chunk_size = chunk_size
        self.measurement_name = kwargs.get("data_frame_measurement_name", "measurement")
        self.tag_columns = kwargs.get("data_frame_tag_columns", [])
        self.timestamp_column = kwargs.get("data_frame_timestamp_column", None)
        self.timestamp_timezone = kwargs.get("data_frame_timestamp_timezone", None)

        # Column name -> positional index, used to read values out of row tuples.
        self.column_indices = {name: index for index, name in enumerate(data_frame.columns)}

        #
        # prepare chunks
        #
        if chunk_size is not None:
            self.number_of_chunks = int(math.ceil(len(data_frame) / float(chunk_size)))
            self.chunk_size = chunk_size
        else:
            self.number_of_chunks = None

    def escape_value(self, value):
        """Escape line-protocol special characters in *value* via the shared translation table."""
        return str(value).translate(_ESCAPE_KEY)

    def to_line_protocol(self, row):
        """Serialize a single row (tuple of column values) into one line-protocol line."""
        # Tags: skip None/empty values; escape both key and value.
        tags = ",".join(
            f'{self.escape_value(col)}={self.escape_value(row[self.column_indices[col]])}'
            for col in self.tag_columns
            if row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
        )

        # Append default tags from the point settings, comma-separated from the row tags.
        if self.point_settings.defaultTags:
            default_tags = ",".join(
                f'{self.escape_value(key)}={self.escape_value(value)}'
                for key, value in self.point_settings.defaultTags.items()
            )
            # Ensure there's a comma between existing tags and default tags if both are present
            if tags and default_tags:
                tags += ","
            tags += default_tags

        # Fields: every non-tag, non-timestamp column with a non-empty value.
        # Strings are double-quoted, ints carry the "i" suffix, everything else is emitted raw.
        fields = ",".join(
            f"{col}=\"{row[self.column_indices[col]]}\"" if isinstance(row[self.column_indices[col]], str)
            else f"{col}={row[self.column_indices[col]]}i" if isinstance(row[self.column_indices[col]], int)
            else f"{col}={row[self.column_indices[col]]}"
            for col in self.column_indices
            if col not in self.tag_columns + [self.timestamp_column]
            and row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
        )

        # The timestamp column has already been converted to a unix epoch by serialize().
        timestamp = row[self.column_indices[self.timestamp_column]]
        if tags != "":
            line_protocol = f"{self.measurement_name},{tags} {fields} {timestamp}"
        else:
            line_protocol = f"{self.measurement_name} {fields} {timestamp}"

        return line_protocol

    def serialize(self, chunk_idx: int = None):
        """
        Serialize the DataFrame (or one chunk of it) into a list of line-protocol strings.

        :param chunk_idx: Index of the chunk to serialize; ``None`` serializes the whole frame.
        :raises ValueError: If the configured precision is not one of ``ns``/``us``/``ms``/``s``.
        """
        from ...extras import pl

        df = self.data_frame

        # Convert the timestamp column to a unix epoch in the requested precision.
        try:
            time_unit = self._PRECISION_TO_TIME_UNIT[self.precision]
        except KeyError:
            raise ValueError(f"Unsupported precision: {self.precision}") from None
        df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit=time_unit).alias(self.timestamp_column))

        if chunk_idx is None:
            chunk = df
        else:
            logger.debug("Serialize chunk %s/%s ...", chunk_idx + 1, self.number_of_chunks)
            chunk = df[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size]

        # Apply the row serializer; `apply` yields a one-column frame named "map".
        # NOTE(review): DataFrame.apply is deprecated in newer Polars (renamed map_rows) — confirm
        # the minimum supported polars version before upgrading.
        line_protocol_expr = chunk.apply(self.to_line_protocol, return_dtype=pl.Object)

        return line_protocol_expr['map'].to_list()
283409
def data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, **kwargs):
284410
"""
285411
Serialize DataFrame into LineProtocols.
@@ -295,3 +421,19 @@ def data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_W
295421
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column - ``DataFrame``
296422
""" # noqa: E501
297423
return DataframeSerializer(data_frame, point_settings, precision, **kwargs).serialize()
424+
425+
def polars_data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, **kwargs):
    """
    Serialize a Polars DataFrame into LineProtocols.

    :param data_frame: Polars DataFrame to serialize
    :param point_settings: Default Tags
    :param precision: The precision for the unix timestamps within the body line-protocol.
    :key data_frame_measurement_name: name of measurement for writing Polars DataFrame
    :key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
    :key data_frame_timestamp_column: name of DataFrame column which contains a timestamp. The column is
        converted to a unix epoch with Polars ``Expr.dt.epoch``, so it is expected to be a Datetime
        column — TODO confirm whether string timestamps should also be accepted here.
    :key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column - ``DataFrame``
    """  # noqa: E501
    return PolarsDataframeSerializer(data_frame, point_settings, precision, **kwargs).serialize()

influxdb_client_3/write_client/client/write_api.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from influxdb_client_3.write_client.domain import WritePrecision
2020
from influxdb_client_3.write_client.client._base import _BaseWriteApi, _HAS_DATACLASS
2121
from influxdb_client_3.write_client.client.util.helpers import get_org_query_param
22-
from influxdb_client_3.write_client.client.write.dataframe_serializer import DataframeSerializer
22+
from influxdb_client_3.write_client.client.write.dataframe_serializer import DataframeSerializer, PolarsDataframeSerializer
2323
from influxdb_client_3.write_client.client.write.point import Point, DEFAULT_WRITE_PRECISION
2424
from influxdb_client_3.write_client.client.write.retry import WritesRetry
2525
from influxdb_client_3.write_client.rest import _UTF_8_encoding
@@ -460,14 +460,24 @@ def _write_batching(self, bucket, org, data,
460460
elif isinstance(data, dict):
461461
self._write_batching(bucket, org, Point.from_dict(data, write_precision=precision, **kwargs),
462462
precision, **kwargs)
463+
464+
elif 'polars' in str(type(data)):
465+
serializer = PolarsDataframeSerializer(data, self._point_settings, precision, self._write_options.batch_size,
466+
**kwargs)
467+
for chunk_idx in range(serializer.number_of_chunks):
468+
self._write_batching(bucket, org,
469+
serializer.serialize(chunk_idx),
470+
precision, **kwargs)
463471

464-
elif 'DataFrame' in type(data).__name__:
472+
elif 'pandas' in str(type(data)):
465473
serializer = DataframeSerializer(data, self._point_settings, precision, self._write_options.batch_size,
466474
**kwargs)
467475
for chunk_idx in range(serializer.number_of_chunks):
468476
self._write_batching(bucket, org,
469477
serializer.serialize(chunk_idx),
470478
precision, **kwargs)
479+
480+
471481
elif hasattr(data, "_asdict"):
472482
# noinspection PyProtectedMember
473483
self._write_batching(bucket, org, data._asdict(), precision, **kwargs)

influxdb_client_3/write_client/extras.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,9 @@
1010
except ModuleNotFoundError as err:
1111
raise ImportError(f"`data_frame` requires numpy which couldn't be imported due: {err}")
1212

13-
__all__ = ['pd', 'np']
13+
try:
14+
import polars as pl
15+
except ModuleNotFoundError as err:
16+
raise ImportError(f"`polars_frame` requires polars which couldn't be imported due: {err}")
17+
18+
__all__ = ['pd', 'np', 'pl']

0 commit comments

Comments
 (0)