use std::{error, fmt, io, net, time};
use crate::client::Transport;
use crate::{Request, Response};
#[derive(Debug, Clone)]
pub struct TcpTransport {
pub addr: net::SocketAddr,
pub timeout: Option<time::Duration>,
}
impl TcpTransport {
pub fn new(addr: net::SocketAddr) -> TcpTransport {
TcpTransport {
addr,
timeout: None,
}
}
fn request<R>(&self, req: impl serde::Serialize) -> Result<R, Error>
where
R: for<'a> serde::de::Deserialize<'a>,
{
let mut sock = net::TcpStream::connect(self.addr)?;
sock.set_read_timeout(self.timeout)?;
sock.set_write_timeout(self.timeout)?;
serde_json::to_writer(&mut sock, &req)?;
let resp: R = serde_json::Deserializer::from_reader(&mut sock)
.into_iter()
.next()
.ok_or(Error::Timeout)??;
Ok(resp)
}
}
impl Transport for TcpTransport {
fn send_request(&self, req: Request) -> Result<Response, crate::Error> {
Ok(self.request(req)?)
}
fn send_batch(&self, reqs: &[Request]) -> Result<Vec<Response>, crate::Error> {
Ok(self.request(reqs)?)
}
fn fmt_target(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.addr)
}
}
#[derive(Debug)]
pub enum Error {
SocketError(io::Error),
Timeout,
Json(serde_json::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use Error::*;
match *self {
SocketError(ref e) => write!(f, "couldn't connect to host: {}", e),
Timeout => f.write_str("didn't receive response data in time, timed out."),
Json(ref e) => write!(f, "JSON error: {}", e),
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
use self::Error::*;
match *self {
SocketError(ref e) => Some(e),
Timeout => None,
Json(ref e) => Some(e),
}
}
}
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self {
Error::SocketError(e)
}
}
impl From<serde_json::Error> for Error {
fn from(e: serde_json::Error) -> Self {
Error::Json(e)
}
}
impl From<Error> for crate::Error {
fn from(e: Error) -> crate::Error {
match e {
Error::Json(e) => crate::Error::Json(e),
e => crate::Error::Transport(Box::new(e)),
}
}
}
#[cfg(test)]
mod tests {
use std::{
io::{Read, Write},
thread,
};
use super::*;
use crate::Client;
#[test]
fn sanity_check_tcp_transport() {
let addr: net::SocketAddr =
net::SocketAddrV4::new(net::Ipv4Addr::new(127, 0, 0, 1), 0).into();
let server = net::TcpListener::bind(addr).unwrap();
let addr = server.local_addr().unwrap();
let dummy_req = Request {
method: "arandommethod",
params: None,
id: serde_json::Value::Number(4242242.into()),
jsonrpc: Some("2.0"),
};
let dummy_req_ser = serde_json::to_vec(&dummy_req).unwrap();
let dummy_resp = Response {
result: None,
error: None,
id: serde_json::Value::Number(4242242.into()),
jsonrpc: Some("2.0".into()),
};
let dummy_resp_ser = serde_json::to_vec(&dummy_resp).unwrap();
let client_thread = thread::spawn(move || {
let transport = TcpTransport {
addr,
timeout: Some(time::Duration::from_secs(5)),
};
let client = Client::with_transport(transport);
client.send_request(dummy_req.clone()).unwrap()
});
let (mut stream, _) = server.accept().unwrap();
stream.set_read_timeout(Some(time::Duration::from_secs(5))).unwrap();
let mut recv_req = vec![0; dummy_req_ser.len()];
let mut read = 0;
while read < dummy_req_ser.len() {
read += stream.read(&mut recv_req[read..]).unwrap();
}
assert_eq!(recv_req, dummy_req_ser);
stream.write_all(&dummy_resp_ser).unwrap();
stream.flush().unwrap();
let recv_resp = client_thread.join().unwrap();
assert_eq!(serde_json::to_vec(&recv_resp).unwrap(), dummy_resp_ser);
}
}