Skip to content

Commit b4c8772

Browse files
authored
fix: Polars optional import (#95)
1 parent 0431b75 commit b4c8772

File tree

10 files changed

+289
-160
lines changed

10 files changed

+289
-160
lines changed

influxdb_client_3/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import urllib.parse
2-
32
import pyarrow as pa
43
import importlib.util
54

influxdb_client_3/write_client/client/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
import os
88
from typing import Iterable
99

10-
from influxdb_client_3.write_client.client.write.dataframe_serializer import DataframeSerializer, \
11-
PolarsDataframeSerializer
10+
from influxdb_client_3.write_client.client.write.dataframe_serializer import DataframeSerializer
1211
from influxdb_client_3.write_client.configuration import Configuration
1312
from influxdb_client_3.write_client.rest import _UTF_8_encoding
1413
from influxdb_client_3.write_client.service.write_service import WriteService
@@ -249,6 +248,7 @@ def _serialize(self, record, write_precision, payload, **kwargs):
249248
self._serialize(Point.from_dict(record, write_precision=write_precision, **kwargs),
250249
write_precision, payload, **kwargs)
251250
elif 'polars' in str(type(record)):
251+
from influxdb_client_3.write_client.client.write.dataframe_serializer import PolarsDataframeSerializer
252252
serializer = PolarsDataframeSerializer(record, self._point_settings, write_precision, **kwargs)
253253
self._serialize(serializer.serialize(), write_precision, payload, **kwargs)
254254

influxdb_client_3/write_client/client/write/dataframe_serializer.py

Lines changed: 0 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -284,137 +284,6 @@ def number_of_chunks(self):
284284
return self.number_of_chunks
285285

286286

287-
class PolarsDataframeSerializer:
288-
"""Serialize DataFrame into LineProtocols."""
289-
290-
def __init__(self, data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, chunk_size: int = None,
291-
**kwargs) -> None:
292-
"""
293-
Init serializer.
294-
295-
:param data_frame: Polars DataFrame to serialize
296-
:param point_settings: Default Tags
297-
:param precision: The precision for the unix timestamps within the body line-protocol.
298-
:param chunk_size: The size of chunk for serializing into chunks.
299-
:key data_frame_measurement_name: name of measurement for writing Polars DataFrame
300-
:key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
301-
:key data_frame_timestamp_column: name of DataFrame column which contains a timestamp.
302-
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column
303-
"""
304-
305-
self.data_frame = data_frame
306-
self.point_settings = point_settings
307-
self.precision = precision
308-
self.chunk_size = chunk_size
309-
self.measurement_name = kwargs.get("data_frame_measurement_name", "measurement")
310-
self.tag_columns = kwargs.get("data_frame_tag_columns", [])
311-
self.timestamp_column = kwargs.get("data_frame_timestamp_column", None)
312-
self.timestamp_timezone = kwargs.get("data_frame_timestamp_timezone", None)
313-
314-
self.column_indices = {name: index for index, name in enumerate(data_frame.columns)}
315-
316-
if self.timestamp_column is None or self.timestamp_column not in self.column_indices:
317-
raise ValueError(
318-
f"Timestamp column {self.timestamp_column} not found in DataFrame. Please define a valid timestamp "
319-
f"column.")
320-
321-
#
322-
# prepare chunks
323-
#
324-
if chunk_size is not None:
325-
self.number_of_chunks = int(math.ceil(len(data_frame) / float(chunk_size)))
326-
self.chunk_size = chunk_size
327-
else:
328-
self.number_of_chunks = None
329-
330-
def escape_key(self, value):
331-
return str(value).translate(_ESCAPE_KEY)
332-
333-
def escape_value(self, value):
334-
return str(value).translate(_ESCAPE_STRING)
335-
336-
def to_line_protocol(self, row):
337-
# Filter out None or empty values for tags
338-
tags = ""
339-
340-
tags = ",".join(
341-
f'{self.escape_key(col)}={self.escape_key(row[self.column_indices[col]])}'
342-
for col in self.tag_columns
343-
if row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
344-
)
345-
346-
if self.point_settings.defaultTags:
347-
default_tags = ",".join(
348-
f'{self.escape_key(key)}={self.escape_key(value)}'
349-
for key, value in self.point_settings.defaultTags.items()
350-
)
351-
# Ensure there's a comma between existing tags and default tags if both are present
352-
if tags and default_tags:
353-
tags += ","
354-
tags += default_tags
355-
356-
# add escape symbols for special characters to tags
357-
358-
fields = ",".join(
359-
f"{col}=\"{self.escape_value(row[self.column_indices[col]])}\"" if isinstance(row[self.column_indices[col]],
360-
str)
361-
else f"{col}={str(row[self.column_indices[col]]).lower()}" if isinstance(row[self.column_indices[col]],
362-
bool) # Check for bool first
363-
else f"{col}={row[self.column_indices[col]]}i" if isinstance(row[self.column_indices[col]], int)
364-
else f"{col}={row[self.column_indices[col]]}"
365-
for col in self.column_indices
366-
if col not in self.tag_columns + [self.timestamp_column] and
367-
row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
368-
)
369-
370-
# Access the Unix timestamp
371-
timestamp = row[self.column_indices[self.timestamp_column]]
372-
if tags != "":
373-
line_protocol = f"{self.measurement_name},{tags} {fields} {timestamp}"
374-
else:
375-
line_protocol = f"{self.measurement_name} {fields} {timestamp}"
376-
377-
return line_protocol
378-
379-
def serialize(self, chunk_idx: int = None):
380-
from ...extras import pl
381-
382-
df = self.data_frame
383-
384-
# Check if the timestamp column is already an integer
385-
if df[self.timestamp_column].dtype in [pl.Int32, pl.Int64]:
386-
# The timestamp column is already an integer, assuming it's in Unix format
387-
pass
388-
else:
389-
# Convert timestamp to Unix timestamp based on specified precision
390-
if self.precision in [None, 'ns']:
391-
df = df.with_columns(
392-
pl.col(self.timestamp_column).dt.epoch(time_unit="ns").alias(self.timestamp_column))
393-
elif self.precision == 'us':
394-
df = df.with_columns(
395-
pl.col(self.timestamp_column).dt.epoch(time_unit="us").alias(self.timestamp_column))
396-
elif self.precision == 'ms':
397-
df = df.with_columns(
398-
pl.col(self.timestamp_column).dt.epoch(time_unit="ms").alias(self.timestamp_column))
399-
elif self.precision == 's':
400-
df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit="s").alias(self.timestamp_column))
401-
else:
402-
raise ValueError(f"Unsupported precision: {self.precision}")
403-
404-
if chunk_idx is None:
405-
chunk = df
406-
else:
407-
logger.debug("Serialize chunk %s/%s ...", chunk_idx + 1, self.number_of_chunks)
408-
chunk = df[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size]
409-
410-
# Apply the UDF to each row
411-
line_protocol_expr = chunk.apply(self.to_line_protocol, return_dtype=pl.Object)
412-
413-
lp = line_protocol_expr['map'].to_list()
414-
415-
return lp
416-
417-
418287
def data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, **kwargs):
419288
"""
420289
Serialize DataFrame into LineProtocols.
@@ -430,20 +299,3 @@ def data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_W
430299
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column - ``DataFrame``
431300
""" # noqa: E501
432301
return DataframeSerializer(data_frame, point_settings, precision, **kwargs).serialize()
433-
434-
435-
def polars_data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, **kwargs):
436-
"""
437-
Serialize DataFrame into LineProtocols.
438-
439-
:param data_frame: Pandas DataFrame to serialize
440-
:param point_settings: Default Tags
441-
:param precision: The precision for the unix timestamps within the body line-protocol.
442-
:key data_frame_measurement_name: name of measurement for writing Pandas DataFrame
443-
:key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
444-
:key data_frame_timestamp_column: name of DataFrame column which contains a timestamp. The column can be defined as a :class:`~str` value
445-
formatted as `2018-10-26`, `2018-10-26 12:00`, `2018-10-26 12:00:00-05:00`
446-
or other formats and types supported by `pandas.to_datetime <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_datetime.html#pandas.to_datetime>`_ - ``DataFrame``
447-
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column - ``DataFrame``
448-
""" # noqa: E501
449-
return PolarsDataframeSerializer(data_frame, point_settings, precision, **kwargs).serialize()
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""
2+
Functions for serialize Polars DataFrame.
3+
4+
Much of the code here is inspired by that in the aioinflux packet found here: https://github.com/gusutabopb/aioinflux
5+
"""
6+
7+
import logging
8+
import math
9+
10+
from influxdb_client_3.write_client.client.write.point import _ESCAPE_KEY, _ESCAPE_STRING, DEFAULT_WRITE_PRECISION
11+
12+
logger = logging.getLogger('influxdb_client.client.write.polars_dataframe_serializer')
13+
14+
15+
class PolarsDataframeSerializer:
16+
"""Serialize DataFrame into LineProtocols."""
17+
18+
def __init__(self, data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, chunk_size: int = None,
19+
**kwargs) -> None:
20+
"""
21+
Init serializer.
22+
23+
:param data_frame: Polars DataFrame to serialize
24+
:param point_settings: Default Tags
25+
:param precision: The precision for the unix timestamps within the body line-protocol.
26+
:param chunk_size: The size of chunk for serializing into chunks.
27+
:key data_frame_measurement_name: name of measurement for writing Polars DataFrame
28+
:key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
29+
:key data_frame_timestamp_column: name of DataFrame column which contains a timestamp.
30+
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column
31+
"""
32+
33+
self.data_frame = data_frame
34+
self.point_settings = point_settings
35+
self.precision = precision
36+
self.chunk_size = chunk_size
37+
self.measurement_name = kwargs.get("data_frame_measurement_name", "measurement")
38+
self.tag_columns = kwargs.get("data_frame_tag_columns", [])
39+
self.timestamp_column = kwargs.get("data_frame_timestamp_column", None)
40+
self.timestamp_timezone = kwargs.get("data_frame_timestamp_timezone", None)
41+
42+
self.column_indices = {name: index for index, name in enumerate(data_frame.columns)}
43+
44+
if self.timestamp_column is None or self.timestamp_column not in self.column_indices:
45+
raise ValueError(
46+
f"Timestamp column {self.timestamp_column} not found in DataFrame. Please define a valid timestamp "
47+
f"column.")
48+
49+
#
50+
# prepare chunks
51+
#
52+
if chunk_size is not None:
53+
self.number_of_chunks = int(math.ceil(len(data_frame) / float(chunk_size)))
54+
self.chunk_size = chunk_size
55+
else:
56+
self.number_of_chunks = None
57+
58+
def escape_key(self, value):
59+
return str(value).translate(_ESCAPE_KEY)
60+
61+
def escape_value(self, value):
62+
return str(value).translate(_ESCAPE_STRING)
63+
64+
def to_line_protocol(self, row):
65+
# Filter out None or empty values for tags
66+
tags = ""
67+
68+
tags = ",".join(
69+
f'{self.escape_key(col)}={self.escape_key(row[self.column_indices[col]])}'
70+
for col in self.tag_columns
71+
if row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
72+
)
73+
74+
if self.point_settings.defaultTags:
75+
default_tags = ",".join(
76+
f'{self.escape_key(key)}={self.escape_key(value)}'
77+
for key, value in self.point_settings.defaultTags.items()
78+
)
79+
# Ensure there's a comma between existing tags and default tags if both are present
80+
if tags and default_tags:
81+
tags += ","
82+
tags += default_tags
83+
84+
# add escape symbols for special characters to tags
85+
86+
fields = ",".join(
87+
f"{col}=\"{self.escape_value(row[self.column_indices[col]])}\"" if isinstance(row[self.column_indices[col]],
88+
str)
89+
else f"{col}={str(row[self.column_indices[col]]).lower()}" if isinstance(row[self.column_indices[col]],
90+
bool) # Check for bool first
91+
else f"{col}={row[self.column_indices[col]]}i" if isinstance(row[self.column_indices[col]], int)
92+
else f"{col}={row[self.column_indices[col]]}"
93+
for col in self.column_indices
94+
if col not in self.tag_columns + [self.timestamp_column] and
95+
row[self.column_indices[col]] is not None and row[self.column_indices[col]] != ""
96+
)
97+
98+
# Access the Unix timestamp
99+
timestamp = row[self.column_indices[self.timestamp_column]]
100+
if tags != "":
101+
line_protocol = f"{self.measurement_name},{tags} {fields} {timestamp}"
102+
else:
103+
line_protocol = f"{self.measurement_name} {fields} {timestamp}"
104+
105+
return line_protocol
106+
107+
def serialize(self, chunk_idx: int = None):
108+
import polars as pl
109+
110+
df = self.data_frame
111+
112+
# Check if the timestamp column is already an integer
113+
if df[self.timestamp_column].dtype in [pl.Int32, pl.Int64]:
114+
# The timestamp column is already an integer, assuming it's in Unix format
115+
pass
116+
else:
117+
# Convert timestamp to Unix timestamp based on specified precision
118+
if self.precision in [None, 'ns']:
119+
df = df.with_columns(
120+
pl.col(self.timestamp_column).dt.epoch(time_unit="ns").alias(self.timestamp_column))
121+
elif self.precision == 'us':
122+
df = df.with_columns(
123+
pl.col(self.timestamp_column).dt.epoch(time_unit="us").alias(self.timestamp_column))
124+
elif self.precision == 'ms':
125+
df = df.with_columns(
126+
pl.col(self.timestamp_column).dt.epoch(time_unit="ms").alias(self.timestamp_column))
127+
elif self.precision == 's':
128+
df = df.with_columns(pl.col(self.timestamp_column).dt.epoch(time_unit="s").alias(self.timestamp_column))
129+
else:
130+
raise ValueError(f"Unsupported precision: {self.precision}")
131+
132+
if chunk_idx is None:
133+
chunk = df
134+
else:
135+
logger.debug("Serialize chunk %s/%s ...", chunk_idx + 1, self.number_of_chunks)
136+
chunk = df[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size]
137+
138+
# Apply the UDF to each row
139+
line_protocol_expr = chunk.apply(self.to_line_protocol, return_dtype=pl.Object)
140+
141+
lp = line_protocol_expr['map'].to_list()
142+
143+
return lp
144+
145+
146+
def polars_data_frame_to_list_of_points(data_frame, point_settings, precision=DEFAULT_WRITE_PRECISION, **kwargs):
147+
"""
148+
Serialize DataFrame into LineProtocols.
149+
150+
:param data_frame: Pandas DataFrame to serialize
151+
:param point_settings: Default Tags
152+
:param precision: The precision for the unix timestamps within the body line-protocol.
153+
:key data_frame_measurement_name: name of measurement for writing Pandas DataFrame
154+
:key data_frame_tag_columns: list of DataFrame columns which are tags, rest columns will be fields
155+
:key data_frame_timestamp_column: name of DataFrame column which contains a timestamp. The column can be defined as a :class:`~str` value
156+
formatted as `2018-10-26`, `2018-10-26 12:00`, `2018-10-26 12:00:00-05:00`
157+
or other formats and types supported by `pandas.to_datetime <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.to_datetime.html#pandas.to_datetime>`_ - ``DataFrame``
158+
:key data_frame_timestamp_timezone: name of the timezone which is used for timestamp column - ``DataFrame``
159+
""" # noqa: E501
160+
return PolarsDataframeSerializer(data_frame, point_settings, precision, **kwargs).serialize()

influxdb_client_3/write_client/client/write_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
from influxdb_client_3.write_client.domain import WritePrecision
2020
from influxdb_client_3.write_client.client._base import _BaseWriteApi, _HAS_DATACLASS
2121
from influxdb_client_3.write_client.client.util.helpers import get_org_query_param
22-
from influxdb_client_3.write_client.client.write.dataframe_serializer import (DataframeSerializer,
23-
PolarsDataframeSerializer)
22+
from influxdb_client_3.write_client.client.write.dataframe_serializer import DataframeSerializer
2423
from influxdb_client_3.write_client.client.write.point import Point, DEFAULT_WRITE_PRECISION
2524
from influxdb_client_3.write_client.client.write.retry import WritesRetry
2625
from influxdb_client_3.write_client.rest import _UTF_8_encoding
@@ -462,6 +461,7 @@ def _write_batching(self, bucket, org, data,
462461
precision, **kwargs)
463462

464463
elif 'polars' in str(type(data)):
464+
from influxdb_client_3.write_client.client.write.dataframe_serializer import PolarsDataframeSerializer
465465
serializer = PolarsDataframeSerializer(data,
466466
self._point_settings, precision,
467467
self._write_options.batch_size, **kwargs)

influxdb_client_3/write_client/extras.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,4 @@
1010
except ModuleNotFoundError as err:
1111
raise ImportError(f"`data_frame` requires numpy which couldn't be imported due: {err}")
1212

13-
try:
14-
import polars as pl
15-
except ModuleNotFoundError as err:
16-
raise ImportError(f"`polars_frame` requires polars which couldn't be imported due: {err}")
17-
18-
__all__ = ['pd', 'np', 'pl']
13+
__all__ = ['pd', 'np']

tests/data/iot.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
name,building,temperature,time
2+
iot-devices,5a,72.3,2022-10-01T12:01:00Z
3+
iot-devices,5a,72.1,2022-10-02T12:01:00Z
4+
iot-devices,5a,72.2,2022-10-03T12:01:00Z

0 commit comments

Comments
 (0)