Refactor MQTT client integration, integrate automatic keep-alive

This commit is contained in:
kpcyrd 2023-01-28 17:58:11 +01:00
parent 70b5039d50
commit e397edb4f4
7 changed files with 114 additions and 43 deletions

1
Cargo.lock generated
View File

@ -4306,6 +4306,7 @@ dependencies = [
"serde_json",
"serde_urlencoded 0.7.1",
"sodiumoxide",
"thiserror",
"tokio",
"tungstenite",
"url 2.3.1",

View File

@ -10,13 +10,17 @@ function run()
mqtt_subscribe(sock, '#', 0)
while true do
-- read the next mqtt packet
local pkt = mqtt_recv(sock)
if last_err() then return end
local log = {pkt=pkt}
local text
if pkt then
log['text'] = utf8_decode(pkt['body'])
-- attempt to utf8 decode the body if there was a pkt
text = utf8_decode(pkt['body'])
if last_err() then clear_err() end
end
info(log)
info({pkt=pkt, text=text})
end
end

View File

@ -50,6 +50,7 @@ bs58 = "0.4"
digest = "0.10"
blake2 = "0.10"
data-encoding = "2.3.3"
thiserror = "1.0.38"
[dev-dependencies]
env_logger = "0.9"

View File

@ -1,9 +1,9 @@
use bytes::Bytes;
use blake2::Blake2bVar;
use bytes::Bytes;
use data_encoding::BASE64;
use digest::{Update, VariableOutput};
use serde::ser::{Serialize, Serializer};
use serde::de::{self, Deserialize, Deserializer};
use serde::ser::{Serialize, Serializer};
use std::result;
#[derive(Debug, Clone, PartialEq)]

View File

@ -1,3 +1,3 @@
pub use log::{trace, debug, info, warn, error};
pub use failure::{Error, ResultExt, format_err, bail};
pub use failure::{bail, format_err, Error, ResultExt};
pub use log::{debug, error, info, trace, warn};
pub type Result<T> = ::std::result::Result<T, Error>;

View File

@ -3,8 +3,8 @@ use hlua_badtouch as hlua;
pub mod blobs;
pub mod crt;
pub mod crypto;
mod errors;
pub mod engine;
mod errors;
pub mod geo;
pub mod geoip;
pub mod gfx;

View File

@ -1,20 +1,26 @@
use chrootable_https::DnsResolver;
use crate::errors::*;
use crate::hlua::AnyLuaValue;
use mqtt::packet::VariablePacketError;
use crate::json::LuaJsonValue;
use crate::sockets::{Stream, SocketOptions};
use mqtt::{TopicFilter, QualityOfService};
use mqtt::control::ConnectReturnCode;
use crate::sockets::{SocketOptions, Stream};
use chrootable_https::DnsResolver;
use mqtt::control::fixed_header::FixedHeaderError;
use mqtt::encodable::{Encodable, Decodable};
use mqtt::packet::{VariablePacket, ConnectPacket, SubscribePacket, PingreqPacket};
use serde::{Serialize, Deserialize};
use mqtt::control::ConnectReturnCode;
use mqtt::encodable::{Decodable, Encodable};
use mqtt::packet::VariablePacketError;
use mqtt::packet::{ConnectPacket, PingreqPacket, SubscribePacket, VariablePacket};
use mqtt::{QualityOfService, TopicFilter};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::io;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use url::Url;
// a reasonable default for keep-alive
// some servers reject 0 as invalid with a very confusing error message
const DEFAULT_PING_INTERVAL: u64 = 90;
const DEFAULT_KEEP_ALIVE: u16 = 120;
#[derive(Debug, Default, Deserialize)]
pub struct MqttOptions {
pub username: Option<String>,
@ -23,10 +29,12 @@ pub struct MqttOptions {
pub proxy: Option<SocketAddr>,
#[serde(default)]
pub connect_timeout: u64,
#[serde(default)]
pub read_timeout: u64,
pub read_timeout: Option<u64>,
#[serde(default)]
pub write_timeout: u64,
pub ping_interval: Option<u64>,
pub keep_alive: Option<u16>,
}
impl MqttOptions {
@ -37,25 +45,44 @@ impl MqttOptions {
}
}
#[derive(Debug, thiserror::Error)]
pub enum MqttRecvError {
#[error("Failed to read mqtt packet: {0:#}")]
Recv(#[from] VariablePacketError),
#[error("Failed to read mqtt packet: connection disconnected")]
RecvDisconnect,
#[error("Failed to interact with mqtt: {0:#}")]
Error(Error),
}
impl From<Error> for MqttRecvError {
fn from(err: Error) -> Self {
MqttRecvError::Error(err)
}
}
pub struct MqttClient {
stream: Stream,
last_ping: Instant,
ping_interval: Option<u64>,
}
impl MqttClient {
pub fn negotiate(stream: Stream, options: &MqttOptions) -> Result<MqttClient> {
// default to DEFAULT_PING_INTERVAL, if an explicit value of 0 was set, disable auto-ping
let ping_interval = Some(options.ping_interval.unwrap_or(DEFAULT_PING_INTERVAL));
ping_interval.filter(|s| *s != 0);
let mut client = MqttClient {
stream,
last_ping: Instant::now(),
ping_interval,
};
let mut pkt = ConnectPacket::new("sn0int");
pkt.set_user_name(options.username.clone());
pkt.set_password(options.password.clone());
/*
if let Some(keep_alive) = msg.keep_alive {
packet.set_keep_alive(keep_alive);
}
*/
pkt.set_keep_alive(options.keep_alive.unwrap_or(DEFAULT_KEEP_ALIVE));
client.send(pkt.into())?;
let pkt = client.recv()?;
@ -72,14 +99,19 @@ impl MqttClient {
}
}
pub fn connect<R: DnsResolver>(resolver: &R, url: Url, options: &MqttOptions) -> Result<MqttClient> {
pub fn connect<R: DnsResolver>(
resolver: &R,
url: Url,
options: &MqttOptions,
) -> Result<MqttClient> {
let tls = match url.scheme() {
"mqtt" => false,
"mqtts" => true,
_ => bail!("Invalid mqtt protocol"),
};
let host = url.host_str()
let host = url
.host_str()
.ok_or_else(|| format_err!("Missing host in url"))?;
let port = match (url.port(), tls) {
@ -88,29 +120,57 @@ impl MqttClient {
(None, false) => 1883,
};
// if no read timeout is configured then keep alive won't work
let read_timeout = options.read_timeout.unwrap_or(DEFAULT_PING_INTERVAL);
let stream = Stream::connect_stream(resolver, host, port, &SocketOptions {
tls,
sni_value: None,
disable_tls_verify: false,
proxy: options.proxy,
let stream = Stream::connect_stream(
resolver,
host,
port,
&SocketOptions {
tls,
sni_value: None,
disable_tls_verify: false,
proxy: options.proxy,
connect_timeout: options.connect_timeout,
read_timeout: options.read_timeout,
write_timeout: options.write_timeout,
})?;
connect_timeout: options.connect_timeout,
read_timeout,
write_timeout: options.write_timeout,
},
)?;
Self::negotiate(stream, options)
}
fn maintain_ping(&mut self) -> Result<()> {
if let Some(ping_interval) = self.ping_interval {
if self.last_ping.elapsed() >= Duration::from_secs(ping_interval) {
self.ping().context("Failed to ping")?;
self.last_ping = Instant::now();
}
}
Ok(())
}
fn send(&mut self, pkt: VariablePacket) -> Result<()> {
self.maintain_ping()?;
debug!("Sending mqtt packet: {:?}", pkt);
pkt.encode(&mut self.stream)?;
Ok(())
}
fn recv(&mut self) -> std::result::Result<VariablePacket, VariablePacketError> {
let pkt = VariablePacket::decode(&mut self.stream)?;
fn recv(&mut self) -> std::result::Result<VariablePacket, MqttRecvError> {
self.maintain_ping()?;
let pkt = VariablePacket::decode(&mut self.stream).map_err(|err| match err {
// search for any io error and check if it's ErrorKind::UnexpectedEof
VariablePacketError::IoError(err)
| VariablePacketError::FixedHeaderError(FixedHeaderError::IoError(err))
if err.kind() == io::ErrorKind::UnexpectedEof =>
{
MqttRecvError::RecvDisconnect
}
_ => MqttRecvError::Recv(err),
})?;
debug!("Received mqtt packet: {:?}", pkt);
Ok(pkt)
}
@ -139,24 +199,29 @@ impl MqttClient {
pub fn recv_pkt(&mut self) -> Result<Option<Pkt>> {
match self.recv() {
Ok(pkt) => Ok(Some(Pkt::try_from(pkt)?)),
Err(VariablePacketError::IoError(err)) if err.kind() == io::ErrorKind::WouldBlock => Ok(None),
Err(VariablePacketError::FixedHeaderError(FixedHeaderError::IoError(err))) if err.kind() == io::ErrorKind::WouldBlock => Ok(None),
Err(err) => Err(Error::from(err))
// search for any io error and check if it's ErrorKind::WouldBlock
Err(MqttRecvError::Recv(
VariablePacketError::IoError(err)
| VariablePacketError::FixedHeaderError(FixedHeaderError::IoError(err)),
)) if err.kind() == io::ErrorKind::WouldBlock => Ok(None),
Err(err) => Err(err.into()),
}
}
pub fn ping(&mut self) -> Result<()> {
let pkt = PingreqPacket::new();
self.send(pkt.into())
let pkt = VariablePacket::PingreqPacket(pkt);
pkt.encode(&mut self.stream)?;
Ok(())
}
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Pkt {
#[serde(rename="publish")]
#[serde(rename = "publish")]
Publish(Publish),
#[serde(rename="pong")]
#[serde(rename = "pong")]
Pong,
}