Skip to content

Instantly share code, notes, and snippets.

@AutMaple
Last active August 31, 2023 07:54
Show Gist options
  • Select an option

  • Save AutMaple/c9e04fa30893250eaa5252ecc74bf454 to your computer and use it in GitHub Desktop.

Select an option

Save AutMaple/c9e04fa30893250eaa5252ecc74bf454 to your computer and use it in GitHub Desktop.
[MQTT 作为 Flink 输入源] #flink #mqtt #java
package org.example;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken;
import org.eclipse.paho.client.mqttv3.MqttCallback;
import org.eclipse.paho.client.mqttv3.MqttClient;
import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
import org.eclipse.paho.client.mqttv3.MqttException;
import org.eclipse.paho.client.mqttv3.MqttMessage;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
@Slf4j
@RequiredArgsConstructor
public class ConsumerCallback implements MqttCallback {
private final MqttClient client;
private final MqttConnectOptions options;
private final MqttConfiguration config;
private final BlockingQueue<String> queue;
@Override
public void connectionLost(Throwable throwable) {
for (int i = 0; i < 10; i++) {
try {
TimeUnit.SECONDS.sleep(1);
log.warn("第 " + i + " 次尝试重新连接 Mqtt: " + client.getServerURI());
client.connect(options);
String[] topic = config.getTopic();
int[] qos = new int[topic.length];
Arrays.fill(qos, 1);
client.subscribe(topic, qos);
log.info("MQTT 重新连接成功: " + client.getServerURI());
return;
} catch (InterruptedException | MqttException e) {
log.warn("", e);
}
}
log.error("MQTT 连接断开: " + client.getServerURI(), throwable);
}
@Override
public void messageArrived(String s, MqttMessage mqttMessage) throws Exception {
String msg = new String(mqttMessage.getPayload());
queue.put(msg);
}
@Override
public void deliveryComplete(IMqttDeliveryToken iMqttDeliveryToken) {
}
}
package org.example;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.connector.base.DeliveryGuarantee;
import org.apache.flink.connector.kafka.sink.KafkaRecordSerializationSchema;
import org.apache.flink.connector.kafka.sink.KafkaSink;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;
/**
 * Flink job that counts words in messages received over MQTT and writes the running counts to
 * Kafka.
 *
 * <p>Each MQTT message is split on whitespace, words are keyed and summed, and every updated
 * {@code (word, count)} pair is written to the Kafka topic as its {@code Tuple2#toString()} form.
 */
public class DataStreamJob {
    public static void main(String[] args) throws Exception {
        // NOTE(review): endpoints and credentials are hard-coded; read them from args or the
        // environment before running this anywhere real.
        String clientId = "test";
        String username = "test";
        String password = "test";
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        MqttConfiguration config = new MqttConfiguration("192.168.3.115", 1883, "flink/hello", clientId);
        config.setUsername(username);
        config.setPassword(password);
        String outputTopic = "WordCount";
        String broker = "kafka:9092";
        String jobName = "MQTT Word Count";
        KafkaRecordSerializationSchema<String> recordSerializer = KafkaRecordSerializationSchema.builder()
                .setTopic(outputTopic)
                .setValueSerializationSchema(new SimpleStringSchema())
                .build();
        // NOTE(review): the KafkaSink flushes for AT_LEAST_ONCE on checkpoints, but this job never
        // calls env.enableCheckpointing(...) — confirm checkpointing is enabled via cluster config.
        KafkaSink<String> kafkaSink = KafkaSink.<String>builder()
                .setBootstrapServers(broker)
                .setRecordSerializer(recordSerializer)
                .setDeliveryGuarantee(DeliveryGuarantee.AT_LEAST_ONCE)
                .build();
        DataStreamSource<String> stream = env.addSource(new MqttConsumer(config));
        stream.flatMap(new SplitWord())
                .keyBy(v -> v.f0)
                .sum(1)
                .map(Tuple2::toString)
                .sinkTo(kafkaSink);
        // Execute program, beginning computation.
        env.execute(jobName);
    }

    /** Splits a message on runs of whitespace and emits {@code (word, 1)} for every non-empty word. */
    private static class SplitWord implements FlatMapFunction<String, Tuple2<String, Integer>> {
        @Override
        public void flatMap(String value, Collector<Tuple2<String, Integer>> out) throws Exception {
            for (String word : value.split("\\s+")) {
                // split() yields a leading "" when the line starts with whitespace and [""] for an
                // empty line; skip those instead of counting an empty "word".
                if (!word.isEmpty()) {
                    out.collect(Tuple2.of(word, 1));
                }
            }
        }
    }
}
package org.example;
import java.io.Serializable;
/**
 * Serializable connection settings for an MQTT broker: host and port, optional credentials, the
 * client id, and a comma-separated list of topic filters to subscribe to.
 */
public class MqttConfiguration implements Serializable {
    private static final long serialVersionUID = 1L;

    private String host;
    private Integer port;
    /** Optional; {@code null} until {@link #setUsername} is called. */
    private String username;
    /** Optional; {@code null} until {@link #setPassword} is called. */
    private String password;
    /** One or more topic filters separated by commas, e.g. {@code "a/b, c/#"}. */
    private String topic;
    private String clientId;

    public MqttConfiguration(String host, Integer port, String topic, String clientId) {
        this.host = host;
        this.port = port;
        this.topic = topic;
        this.clientId = clientId;
    }

    public String getHost() {
        return host;
    }

    public void setHost(String host) {
        this.host = host;
    }

    public Integer getPort() {
        return port;
    }

    public void setPort(Integer port) {
        this.port = port;
    }

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    /**
     * Returns the password as a fresh char array, or {@code null} when no password is configured
     * (previously this threw a NullPointerException for anonymous connections).
     */
    public char[] getPassword() {
        return password == null ? null : password.toCharArray();
    }

    public void setPassword(String password) {
        this.password = password;
    }

    /**
     * Splits the comma-separated topic setting into individual topic filters. Whitespace around
     * each entry is trimmed and empty entries are skipped, so {@code " a/b , c/#,,d"} yields
     * {@code ["a/b", "c/#", "d"]}.
     *
     * @return the topic filters, in configuration order
     * @throws IllegalStateException if no topic has been configured
     */
    public String[] getTopic() {
        if (topic == null) {
            throw new IllegalStateException("MQTT topic is not configured");
        }
        return java.util.Arrays.stream(topic.split(","))
                .map(String::trim)
                .filter(t -> !t.isEmpty())
                .toArray(String[]::new);
    }

    public void setTopic(String topic) {
        this.topic = topic;
    }

    public String getClientId() {
        return clientId;
    }

    public void setClientId(String clientId) {
        this.clientId = clientId;
    }

    /** Returns the broker URI in the {@code tcp://host:port} form expected by the Paho client. */
    public String getServerUri() {
        return "tcp://" + host + ":" + port;
    }
}
package org.example;
import lombok.RequiredArgsConstructor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.eclipse.paho.client.mqttv3.MqttClient;
import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
import org.eclipse.paho.client.mqttv3.MqttException;
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
/**
 * Flink source that subscribes to the topics of an {@link MqttConfiguration} and emits every
 * message payload as a {@code String}.
 *
 * <p>Messages are handed from the Paho callback thread to {@link #run} through a bounded queue.
 * The number of emitted messages is kept in operator state via {@link CheckpointedFunction}.
 * MQTT cannot replay from that number, so after a restore it only continues the count. The
 * original design note states that checkpointing is required for results of this unbounded stream
 * to reach the sink. NOTE(review): implementing this interface does not enable checkpointing by
 * itself; that must be turned on for the execution environment.
 *
 * <p>NOTE(review): this is a parallel source, yet every subtask connects with the same client id
 * and subscribes to the same topics. Per the MQTT spec, a broker disconnects an existing client
 * when another connects with the same id. Run this source with parallelism 1, or give each subtask
 * its own id / a shared subscription.
 */
@RequiredArgsConstructor
public class MqttConsumer extends RichParallelSourceFunction<String> implements CheckpointedFunction {
    /** QoS for every subscribed topic; the reconnect logic in ConsumerCallback also uses QoS 1. */
    private static final int SUBSCRIBE_QOS = 1;
    /** How long {@link #run} waits for a message before re-checking {@link #running}. */
    private static final long POLL_TIMEOUT_MS = 100L;

    private final MqttConfiguration mqttConfig;
    /** Hand-off buffer between the Paho callback thread and {@link #run}. */
    private final BlockingQueue<String> queue = new LinkedBlockingQueue<>(10);
    /** Created in {@link #open}; not serializable, and read from the cancelling thread. */
    private transient volatile MqttClient mqttClient;
    /** Cleared by {@link #cancel}, which Flink invokes from a different thread than {@link #run}. */
    private volatile boolean running = true;
    /** Number of messages emitted so far; only modified under the checkpoint lock. */
    private int offset;
    private transient ListState<Integer> offsetState;

    /** Creates the client, registers the callback, connects, and subscribes to all topics. */
    private void connectToMqtt() throws MqttException {
        mqttClient = new MqttClient(mqttConfig.getServerUri(), mqttConfig.getClientId(), new MemoryPersistence());
        MqttConnectOptions option = new MqttConnectOptions();
        option.setCleanSession(false);
        option.setUserName(mqttConfig.getUsername());
        option.setPassword(mqttConfig.getPassword());
        // Register the callback before connecting so that messages the broker delivers as soon as
        // the persistent session (cleanSession=false) resumes are not missed.
        mqttClient.setCallback(new ConsumerCallback(mqttClient, option, mqttConfig, queue));
        mqttClient.connect(option);
        String[] topics = mqttConfig.getTopic();
        int[] qos = new int[topics.length];
        java.util.Arrays.fill(qos, SUBSCRIBE_QOS);
        mqttClient.subscribe(topics, qos);
    }

    /** Connects to the broker before {@link #run} starts consuming. */
    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        connectToMqtt();
    }

    /** Emits queued messages until {@link #cancel} is called. */
    @Override
    public void run(SourceContext<String> ctx) throws Exception {
        while (running) {
            // Timed poll instead of take(): a cancelled source must notice running == false even
            // when no further messages arrive.
            String msg = queue.poll(POLL_TIMEOUT_MS, java.util.concurrent.TimeUnit.MILLISECONDS);
            if (msg == null) {
                continue;
            }
            // Emit and count under the checkpoint lock so a snapshot never sees one without the other.
            synchronized (ctx.getCheckpointLock()) {
                ctx.collect(msg);
                offset++;
            }
        }
    }

    /** Stops {@link #run} and releases the MQTT connection. */
    @Override
    public void cancel() {
        running = false;
        try {
            closeClient();
        } catch (MqttException e) {
            throw new RuntimeException(e);
        }
    }

    /** Releases the MQTT connection when the task ends without being cancelled (e.g. on failure). */
    @Override
    public void close() throws Exception {
        try {
            closeClient();
        } finally {
            super.close();
        }
    }

    /**
     * Disconnects and closes the MQTT client at most once. This is safe when {@link #open} failed
     * before the client was created, and when both {@link #cancel} and {@link #close} run.
     */
    private synchronized void closeClient() throws MqttException {
        MqttClient client = mqttClient;
        if (client == null) {
            return;
        }
        mqttClient = null;
        // Nothing will be emitted any more; dropping the backlog also lets a callback that is
        // blocked in queue.put() return, so the disconnect is not held up by it.
        queue.clear();
        try {
            if (client.isConnected()) {
                client.disconnect();
            }
        } finally {
            client.close();
        }
    }

    /** Replaces the stored state with the current emitted-message count. */
    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        offsetState.clear();
        offsetState.add(offset);
    }

    /** Registers the operator list state and restores the first stored count, if any. */
    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        ListStateDescriptor<Integer> desc = new ListStateDescriptor<>("offset", Types.INT);
        offsetState = context.getOperatorStateStore().getListState(desc);
        Iterable<Integer> iter = offsetState.get();
        if (iter != null && iter.iterator().hasNext()) {
            offset = iter.iterator().next();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment