自定义函数(UDF)
Flink的Table API和SQL提供了多种自定义函数的接口,以抽象类的形式定义。当前UDF主要有以下几类:
- 标量函数(Scalar Functions):将输入的标量值转换成一个新的标量值;
- 表函数(Table Functions):将标量值转换成一个或多个新的行数据,也就是扩展成一个表;
- 聚合函数(Aggregate Functions):将多行数据里的标量值转换成一个新的标量值;
- 表聚合函数(Table Aggregate Functions):将多行数据里的标量值转换成一个或多个新的行数据。
1. 整体调用流程
要想在代码中使用自定义的函数,我们需要首先自定义对应UDF抽象类的实现,并在表环境中注册这个函数,然后就可以在Table API和SQL中调用了。
1.1 注册函数
注册函数时需要调用表环境的createTemporarySystemFunction()
方法,传入注册的函数名以及UDF类的Class对象:
// 注册函数
tableEnv.createTemporarySystemFunction("MyFunction", MyFunction.class);
1.2 使用Table API调用函数
在Table API中,需要使用call()方法来调用自定义函数:
tableEnv.from("MyTable").select(call("MyFunction", $("myField")));
1.3 在SQL中调用函数
在SQL中的调用就与内置系统函数完全一样:
tableEnv.sqlQuery("SELECT MyFunction(myField) FROM MyTable");
2. 标量函数(Scalar Functions)
想要实现自定义的标量函数,我们需要自定义一个类来继承抽象类ScalarFunction,并实现叫作eval()
的求值方法。它必须是公有的(public),而且名字必须是eval。求值方法eval可以重载多次,任何数据类型都可作为求值方法的参数和返回值类型。 比如实现一个自定义的哈希(hash)函数HashFunction,用来求传入对象的哈希值。
public class MyScalarFunctionDemo {
public static void main(String[] args) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStreamSource<WaterSensor> sensorDS = env.fromElements(
new WaterSensor("s1", 1L, 1),
new WaterSensor("s1", 2L, 2),
new WaterSensor("s2", 2L, 2),
new WaterSensor("s3", 3L, 3),
new WaterSensor("s3", 4L, 4)
);
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
Table sensorTable = tableEnv.fromDataStream(sensorDS);
tableEnv.createTemporaryView("sensor", sensorTable);
// 2.注册函数
tableEnv.createTemporaryFunction("HashFunction", HashFunction.class);
// 3.调用 自定义函数
tableEnv.sqlQuery("select id, HashFunction(id) from sensor")
.execute() // 调用了表的execute,就不需要调用env.execute()
.print();
}
// 1.定义 自定义函数的实现类
public static class HashFunction extends ScalarFunction {
// 接受任意类型的输入,返回 INT型输出
public int eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object o) {
return o.hashCode();
}
}
}
运行结果:
3. 表函数(Table Functions)
跟标量函数一样,表函数的输入参数也可以是 0个、1个或多个标量值;不同的是,它可以返回任意多行数据。类似地,要实现自定义的表函数,需要自定义类来继承抽象类TableFunction,内部必须要实现的也是一个名为 eval 的求值方法。与标量函数不同的是,TableFunction类本身是有一个泛型参数T的,这就是表函数返回数据的类型;而eval()方法没有返回类型,内部也没有return语句,是通过调用collect()方法来发送想要输出的行数据的。
public class MyTableFunctionDemo {
public static void main(String[] args) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStreamSource<String> strDS = env.fromElements(
"hello flink",
"hello world hi",
"hello java"
);
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
Table sensorTable = tableEnv.fromDataStream(strDS, $("words"));
tableEnv.createTemporaryView("str", sensorTable);
// 2.注册函数
tableEnv.createTemporaryFunction("SplitFunction", SplitFunction.class);
// 3.调用 自定义函数
// 3.1 交叉联结
tableEnv
// 3.1 交叉联结
// .sqlQuery("select words,word,length from str,lateral table(SplitFunction(words))")
// 3.2 带 on true 条件的 左联结
// .sqlQuery("select words,word,length from str left join lateral table(SplitFunction(words)) on true")
// 重命名侧向表中的字段
.sqlQuery("""
select words,newWord,newLength
from str
left join lateral table(SplitFunction(words)) as T(newWord,newLength) on true
""")
.execute()
.print();
}
// 1.继承 TableFunction<返回的类型>
// 类型标注: Row包含两个字段:word和length
@FunctionHint(output = @DataTypeHint("ROW<word STRING,length INT>"))
public static class SplitFunction extends TableFunction<Row> {
// 返回是 void,用 collect方法输出
public void eval(String str) {
for (String word : str.split(" ")) {
collect(Row.of(word, word.length()));
}
}
}
}
执行结果:
4. 聚合函数(Aggregate Functions)
用户自定义聚合函数(User Defined AGGregate function,UDAGG)会把一行或多行数据(也就是一个表)聚合成一个标量值。这是一个标准的"多对一"的转换。聚合函数的概念我们之前已经接触过多次,如SUM()、MAX()、MIN()、AVG()、COUNT()都是常见的系统内置聚合函数。而如果有些需求无法直接调用系统函数解决,我们就必须自定义聚合函数来实现功能了。自定义聚合函数需要继承抽象类AggregateFunction。AggregateFunction有两个泛型参数<T, ACC>
,T表示聚合输出的结果类型,ACC则表示聚合的中间状态类型。
AggregateFunction的所有方法都必须是 公有的(public),不能是静态的(static),而且名字必须跟上面写的完全一样。createAccumulator()、getValue()、getResultType()以及getAccumulatorType()这几个方法是在抽象类AggregateFunction中定义的,可以override;而其他则都是底层架构约定的方法。
比如从学生的分数表ScoreTable中计算每个学生的加权平均分。
public class MyAggregateFunctionDemo {
public static void main(String[] args) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// 姓名,分数,权重
DataStreamSource<Tuple3<String,Integer, Integer>> scoreWeightDS = env.fromElements(
Tuple3.of("zs",80, 3),
Tuple3.of("zs",90, 4),
Tuple3.of("zs",95, 4),
Tuple3.of("ls",75, 4),
Tuple3.of("ls",65, 4),
Tuple3.of("ls",85, 4)
);
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
Table scoreWeightTable = tableEnv.fromDataStream(scoreWeightDS, $("f0").as("name"),$("f1").as("score"), $("f2").as("weight"));
tableEnv.createTemporaryView("scores", scoreWeightTable);
// 2.注册函数
tableEnv.createTemporaryFunction("WeightedAvg", WeightedAvg.class);
// 3.调用 自定义函数
tableEnv
.sqlQuery("select name,WeightedAvg(score,weight) from scores group by name")
.execute()
.print();
}
public static class WeightedAvg extends AggregateFunction<Double, Tuple2<Integer, Integer>> {
@Override
public Double getValue(Tuple2<Integer, Integer> accumulator) {
return accumulator.f0 * 1D / accumulator.f1;
}
@Override
public Tuple2<Integer, Integer> createAccumulator() {
return Tuple2.of(0, 0);
}
/**
* 累加计算的方法,每来一行数据都会调用一次
* @param acc 累加器类型
* @param score 第一个参数:分数
* @param weight 第二个参数:权重
*/
public void accumulate(Tuple2<Integer, Integer> acc,Integer score,Integer weight){
acc.f0 += score * weight; // 加权总和 = 分数1 * 权重1 + 分数2 * 权重2 +....
acc.f1 += weight; // 权重和 = 权重1 + 权重2 +....
}
}
}
运行结果: 聚合函数的
accumulate()
方法有三个输入参数。第一个是WeightedAvgAccum类型的累加器;另外两个则是函数调用时输入的字段:要计算的值value和对应的权重weight。
5. 表聚合函数(Table Aggregate Functions)
用户自定义表聚合函数(UDTAGG)可以把一行或多行数据(也就是一个表)聚合成另一张表,结果表中可以有多行多列。很明显,这就像表函数和聚合函数的结合体,是一个"多对多"的转换。
自定义表聚合函数需要继承抽象类TableAggregateFunction。TableAggregateFunction的结构和原理与AggregateFunction非常类似,同样有两个泛型参数<T, ACC>
,用一个ACC类型的累加器(accumulator)来存储聚合的中间结果。
比如数据的TOP-2查询:
public class MyTableAggregateFunctionDemo {
public static void main(String[] args) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStreamSource<Integer> numDS = env.fromElements(3, 6, 12, 5, 8, 9, 4);
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
Table numTable = tableEnv.fromDataStream(numDS, $("num"));
// 2.注册函数
tableEnv.createTemporaryFunction("Top2", Top2.class);
// 3.调用 自定义函数: 只能用 Table API
numTable
.flatAggregate(call("Top2", $("num")).as("value", "rank"))
.select( $("value"), $("rank"))
.execute().print();
}
// 1.继承 TableAggregateFunction< 返回类型,累加器类型<加权总和,权重总和> >
// 返回类型 (数值,排名) =》 (12,1) (9,2)
// 累加器类型 (第一大的数,第二大的数) ===》 (12,9)
public static class Top2 extends TableAggregateFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
@Override
public Tuple2<Integer, Integer> createAccumulator() {
return Tuple2.of(0, 0);
}
/**
* 每来一个数据调用一次,比较大小,更新 最大的前两个数到 acc中
* @param acc 累加器
* @param num 过来的数据
*/
public void accumulate(Tuple2<Integer, Integer> acc, Integer num) {
if (num > acc.f0) {
// 新来的变第一,原来的第一变第二
acc.f1 = acc.f0;
acc.f0 = num;
} else if (num > acc.f1) {
// 新来的变第二,原来的第二不要了
acc.f1 = num;
}
}
/**
* 输出结果: (数值,排名)两条最大的
* @param acc 累加器
* @param out 采集器<返回类型>
*/
public void emitValue(Tuple2<Integer, Integer> acc, Collector<Tuple2<Integer, Integer>> out) {
if (acc.f0 != 0) {
out.collect(Tuple2.of(acc.f0, 1));
}
if (acc.f1 != 0) {
out.collect(Tuple2.of(acc.f1, 2));
}
}
}
}
运行结果: 目前SQL中没有直接使用表聚合函数的方式,所以需要使用Table API的方式来调用。