mirror of https://github.com/kpcyrd/sn0int.git
Refactor MQTT client integration, integrate automatic keep-alive
This commit is contained in:
parent
70b5039d50
commit
e397edb4f4
|
@ -4306,6 +4306,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_urlencoded 0.7.1",
|
||||
"sodiumoxide",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
"url 2.3.1",
|
||||
|
|
|
@ -10,13 +10,17 @@ function run()
|
|||
mqtt_subscribe(sock, '#', 0)
|
||||
|
||||
while true do
|
||||
-- read the next mqtt packet
|
||||
local pkt = mqtt_recv(sock)
|
||||
if last_err() then return end
|
||||
local log = {pkt=pkt}
|
||||
|
||||
local text
|
||||
if pkt then
|
||||
log['text'] = utf8_decode(pkt['body'])
|
||||
-- attempt to utf8 decode the body if there was a pkt
|
||||
text = utf8_decode(pkt['body'])
|
||||
if last_err() then clear_err() end
|
||||
end
|
||||
info(log)
|
||||
|
||||
info({pkt=pkt, text=text})
|
||||
end
|
||||
end
|
||||
|
|
|
@ -50,6 +50,7 @@ bs58 = "0.4"
|
|||
digest = "0.10"
|
||||
blake2 = "0.10"
|
||||
data-encoding = "2.3.3"
|
||||
thiserror = "1.0.38"
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger = "0.9"
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use bytes::Bytes;
|
||||
use blake2::Blake2bVar;
|
||||
use bytes::Bytes;
|
||||
use data_encoding::BASE64;
|
||||
use digest::{Update, VariableOutput};
|
||||
use serde::ser::{Serialize, Serializer};
|
||||
use serde::de::{self, Deserialize, Deserializer};
|
||||
use serde::ser::{Serialize, Serializer};
|
||||
use std::result;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
pub use log::{trace, debug, info, warn, error};
|
||||
pub use failure::{Error, ResultExt, format_err, bail};
|
||||
pub use failure::{bail, format_err, Error, ResultExt};
|
||||
pub use log::{debug, error, info, trace, warn};
|
||||
pub type Result<T> = ::std::result::Result<T, Error>;
|
||||
|
|
|
@ -3,8 +3,8 @@ use hlua_badtouch as hlua;
|
|||
pub mod blobs;
|
||||
pub mod crt;
|
||||
pub mod crypto;
|
||||
mod errors;
|
||||
pub mod engine;
|
||||
mod errors;
|
||||
pub mod geo;
|
||||
pub mod geoip;
|
||||
pub mod gfx;
|
||||
|
|
|
@ -1,20 +1,26 @@
|
|||
use chrootable_https::DnsResolver;
|
||||
use crate::errors::*;
|
||||
use crate::hlua::AnyLuaValue;
|
||||
use mqtt::packet::VariablePacketError;
|
||||
use crate::json::LuaJsonValue;
|
||||
use crate::sockets::{Stream, SocketOptions};
|
||||
use mqtt::{TopicFilter, QualityOfService};
|
||||
use mqtt::control::ConnectReturnCode;
|
||||
use crate::sockets::{SocketOptions, Stream};
|
||||
use chrootable_https::DnsResolver;
|
||||
use mqtt::control::fixed_header::FixedHeaderError;
|
||||
use mqtt::encodable::{Encodable, Decodable};
|
||||
use mqtt::packet::{VariablePacket, ConnectPacket, SubscribePacket, PingreqPacket};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use mqtt::control::ConnectReturnCode;
|
||||
use mqtt::encodable::{Decodable, Encodable};
|
||||
use mqtt::packet::VariablePacketError;
|
||||
use mqtt::packet::{ConnectPacket, PingreqPacket, SubscribePacket, VariablePacket};
|
||||
use mqtt::{QualityOfService, TopicFilter};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::TryFrom;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
use url::Url;
|
||||
|
||||
// a reasonable default for keep-alive
|
||||
// some servers reject 0 as invalid with a very confusing error message
|
||||
const DEFAULT_PING_INTERVAL: u64 = 90;
|
||||
const DEFAULT_KEEP_ALIVE: u16 = 120;
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
pub struct MqttOptions {
|
||||
pub username: Option<String>,
|
||||
|
@ -23,10 +29,12 @@ pub struct MqttOptions {
|
|||
pub proxy: Option<SocketAddr>,
|
||||
#[serde(default)]
|
||||
pub connect_timeout: u64,
|
||||
#[serde(default)]
|
||||
pub read_timeout: u64,
|
||||
pub read_timeout: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub write_timeout: u64,
|
||||
|
||||
pub ping_interval: Option<u64>,
|
||||
pub keep_alive: Option<u16>,
|
||||
}
|
||||
|
||||
impl MqttOptions {
|
||||
|
@ -37,25 +45,44 @@ impl MqttOptions {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MqttRecvError {
|
||||
#[error("Failed to read mqtt packet: {0:#}")]
|
||||
Recv(#[from] VariablePacketError),
|
||||
#[error("Failed to read mqtt packet: connection disconnected")]
|
||||
RecvDisconnect,
|
||||
#[error("Failed to interact with mqtt: {0:#}")]
|
||||
Error(Error),
|
||||
}
|
||||
|
||||
impl From<Error> for MqttRecvError {
|
||||
fn from(err: Error) -> Self {
|
||||
MqttRecvError::Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MqttClient {
|
||||
stream: Stream,
|
||||
last_ping: Instant,
|
||||
ping_interval: Option<u64>,
|
||||
}
|
||||
|
||||
impl MqttClient {
|
||||
pub fn negotiate(stream: Stream, options: &MqttOptions) -> Result<MqttClient> {
|
||||
// default to DEFAULT_PING_INTERVAL, if an explicit value of 0 was set, disable auto-ping
|
||||
let ping_interval = Some(options.ping_interval.unwrap_or(DEFAULT_PING_INTERVAL));
|
||||
ping_interval.filter(|s| *s != 0);
|
||||
|
||||
let mut client = MqttClient {
|
||||
stream,
|
||||
last_ping: Instant::now(),
|
||||
ping_interval,
|
||||
};
|
||||
|
||||
let mut pkt = ConnectPacket::new("sn0int");
|
||||
pkt.set_user_name(options.username.clone());
|
||||
pkt.set_password(options.password.clone());
|
||||
|
||||
/*
|
||||
if let Some(keep_alive) = msg.keep_alive {
|
||||
packet.set_keep_alive(keep_alive);
|
||||
}
|
||||
*/
|
||||
pkt.set_keep_alive(options.keep_alive.unwrap_or(DEFAULT_KEEP_ALIVE));
|
||||
|
||||
client.send(pkt.into())?;
|
||||
let pkt = client.recv()?;
|
||||
|
@ -72,14 +99,19 @@ impl MqttClient {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn connect<R: DnsResolver>(resolver: &R, url: Url, options: &MqttOptions) -> Result<MqttClient> {
|
||||
pub fn connect<R: DnsResolver>(
|
||||
resolver: &R,
|
||||
url: Url,
|
||||
options: &MqttOptions,
|
||||
) -> Result<MqttClient> {
|
||||
let tls = match url.scheme() {
|
||||
"mqtt" => false,
|
||||
"mqtts" => true,
|
||||
_ => bail!("Invalid mqtt protocol"),
|
||||
};
|
||||
|
||||
let host = url.host_str()
|
||||
let host = url
|
||||
.host_str()
|
||||
.ok_or_else(|| format_err!("Missing host in url"))?;
|
||||
|
||||
let port = match (url.port(), tls) {
|
||||
|
@ -88,29 +120,57 @@ impl MqttClient {
|
|||
(None, false) => 1883,
|
||||
};
|
||||
|
||||
// if no read timeout is configured then keep alive won't work
|
||||
let read_timeout = options.read_timeout.unwrap_or(DEFAULT_PING_INTERVAL);
|
||||
|
||||
let stream = Stream::connect_stream(resolver, host, port, &SocketOptions {
|
||||
tls,
|
||||
sni_value: None,
|
||||
disable_tls_verify: false,
|
||||
proxy: options.proxy,
|
||||
let stream = Stream::connect_stream(
|
||||
resolver,
|
||||
host,
|
||||
port,
|
||||
&SocketOptions {
|
||||
tls,
|
||||
sni_value: None,
|
||||
disable_tls_verify: false,
|
||||
proxy: options.proxy,
|
||||
|
||||
connect_timeout: options.connect_timeout,
|
||||
read_timeout: options.read_timeout,
|
||||
write_timeout: options.write_timeout,
|
||||
})?;
|
||||
connect_timeout: options.connect_timeout,
|
||||
read_timeout,
|
||||
write_timeout: options.write_timeout,
|
||||
},
|
||||
)?;
|
||||
|
||||
Self::negotiate(stream, options)
|
||||
}
|
||||
|
||||
fn maintain_ping(&mut self) -> Result<()> {
|
||||
if let Some(ping_interval) = self.ping_interval {
|
||||
if self.last_ping.elapsed() >= Duration::from_secs(ping_interval) {
|
||||
self.ping().context("Failed to ping")?;
|
||||
self.last_ping = Instant::now();
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send(&mut self, pkt: VariablePacket) -> Result<()> {
|
||||
self.maintain_ping()?;
|
||||
debug!("Sending mqtt packet: {:?}", pkt);
|
||||
pkt.encode(&mut self.stream)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv(&mut self) -> std::result::Result<VariablePacket, VariablePacketError> {
|
||||
let pkt = VariablePacket::decode(&mut self.stream)?;
|
||||
fn recv(&mut self) -> std::result::Result<VariablePacket, MqttRecvError> {
|
||||
self.maintain_ping()?;
|
||||
let pkt = VariablePacket::decode(&mut self.stream).map_err(|err| match err {
|
||||
// search for any io error and check if it's ErrorKind::UnexpectedEof
|
||||
VariablePacketError::IoError(err)
|
||||
| VariablePacketError::FixedHeaderError(FixedHeaderError::IoError(err))
|
||||
if err.kind() == io::ErrorKind::UnexpectedEof =>
|
||||
{
|
||||
MqttRecvError::RecvDisconnect
|
||||
}
|
||||
_ => MqttRecvError::Recv(err),
|
||||
})?;
|
||||
debug!("Received mqtt packet: {:?}", pkt);
|
||||
Ok(pkt)
|
||||
}
|
||||
|
@ -139,24 +199,29 @@ impl MqttClient {
|
|||
pub fn recv_pkt(&mut self) -> Result<Option<Pkt>> {
|
||||
match self.recv() {
|
||||
Ok(pkt) => Ok(Some(Pkt::try_from(pkt)?)),
|
||||
Err(VariablePacketError::IoError(err)) if err.kind() == io::ErrorKind::WouldBlock => Ok(None),
|
||||
Err(VariablePacketError::FixedHeaderError(FixedHeaderError::IoError(err))) if err.kind() == io::ErrorKind::WouldBlock => Ok(None),
|
||||
Err(err) => Err(Error::from(err))
|
||||
// search for any io error and check if it's ErrorKind::WouldBlock
|
||||
Err(MqttRecvError::Recv(
|
||||
VariablePacketError::IoError(err)
|
||||
| VariablePacketError::FixedHeaderError(FixedHeaderError::IoError(err)),
|
||||
)) if err.kind() == io::ErrorKind::WouldBlock => Ok(None),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ping(&mut self) -> Result<()> {
|
||||
let pkt = PingreqPacket::new();
|
||||
self.send(pkt.into())
|
||||
let pkt = VariablePacket::PingreqPacket(pkt);
|
||||
pkt.encode(&mut self.stream)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum Pkt {
|
||||
#[serde(rename="publish")]
|
||||
#[serde(rename = "publish")]
|
||||
Publish(Publish),
|
||||
#[serde(rename="pong")]
|
||||
#[serde(rename = "pong")]
|
||||
Pong,
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue