Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions flink-python/pyflink/table/tests/test_pandas_udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
################################################################################
import uuid

from pyflink.common import Row
from pyflink.table.expressions import col, call, lit, row_interval
from pyflink.table.types import DataTypes
from pyflink.table.udf import udaf, udf, AggregateFunction
Expand All @@ -39,6 +40,17 @@ def setUpClass(cls):
func_type="pandas"))
cls.t_env.create_temporary_system_function("mean_udaf", mean_udaf)

def test_pandas_udaf_in_sql(self):
sql = f"""
CREATE TEMPORARY FUNCTION pymean AS
'{BatchPandasUDAFITTests.__module__}.mean_str_udaf'
LANGUAGE PYTHON
"""
self.t_env.execute_sql(sql)
self.assert_equals(
list(self.t_env.execute_sql("SELECT pymean(1)").collect()),
[Row(1)])

def test_check_result_type(self):
def pandas_udaf():
pass
Expand Down Expand Up @@ -861,6 +873,11 @@ def mean_udaf(v):
return v.mean()


@udaf(input_types=['FLOAT'], result_type='FLOAT', func_type="pandas")
def mean_str_udaf(v):
return v.mean()


class MaxAdd(AggregateFunction):

def __init__(self):
Expand Down
365 changes: 191 additions & 174 deletions flink-python/pyflink/table/tests/test_udf.py

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions flink-python/pyflink/table/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
################################################################################
import unittest

from pyflink.common import Row
from pyflink.table import DataTypes
from pyflink.table.udf import TableFunction, udtf, ScalarFunction, udf
from pyflink.table.expressions import col
Expand Down Expand Up @@ -71,6 +72,19 @@ def test_table_function_with_sql_query(self):
actual = source_sink_utils.results()
self.assert_equals(actual, ["+I[1, 1, 0]", "+I[2, 2, 0]", "+I[3, 3, 0]", "+I[3, 3, 1]"])

def test_table_function_in_sql(self):
sql = f"""
CREATE TEMPORARY FUNCTION pyfunc AS
'{UserDefinedTableFunctionTests.__module__}.identity'
LANGUAGE PYTHON
"""
self.t_env.execute_sql(sql)
self.assert_equals(
list(self.t_env.execute_sql(
"SELECT v FROM (VALUES (1)) AS T(id), LATERAL TABLE(pyfunc(id)) AS P(v)"
).collect()),
[Row(1)])


class PyFlinkStreamUserDefinedFunctionTests(UserDefinedTableFunctionTests,
PyFlinkStreamTableTestCase):
Expand Down
4 changes: 0 additions & 4 deletions flink-python/pyflink/table/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,6 @@ def _create_judf(self, serialized_func, j_input_types, j_function_kind):
else:
self._accumulator_type = 'ARRAY<{0}>'.format(self._result_type)

if j_input_types is not None:
gateway = get_gateway()
j_input_types = java_utils.to_jarray(
gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types])
if isinstance(self._result_type, DataType):
j_result_type = _to_java_data_type(self._result_type)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,9 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {

if (inputTypesString != null) {
inputTypes =
(DataType[])
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray();
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray(DataType[]::new);
}

if (inputTypes != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,9 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {

if (inputTypesString != null) {
inputTypes =
(DataType[])
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray();
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray(DataType[]::new);
}

if (inputTypes != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,9 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {
TypeInference.Builder builder = TypeInference.newBuilder();
if (inputTypesString != null) {
inputTypes =
(DataType[])
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray();
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray(DataType[]::new);
}

if (inputTypes != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,9 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {

if (inputTypesString != null) {
inputTypes =
(DataType[])
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray();
Arrays.stream(inputTypesString)
.map(typeFactory::createDataType)
.toArray(DataType[]::new);
}

if (inputTypes != null) {
Expand Down