Skip to content
This repository has been archived by the owner on Sep 4, 2024. It is now read-only.

Commit

Permalink
Reuse HTTP connection
Browse files Browse the repository at this point in the history
  • Loading branch information
raphjaph committed Nov 17, 2022
1 parent 0679d44 commit e867374
Showing 1 changed file with 75 additions and 41 deletions.
116 changes: 75 additions & 41 deletions src/simple_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

#[cfg(feature = "proxy")]
use socks::Socks5Stream;
use std::io::{BufRead, BufReader, Write};
#[cfg(not(feature = "proxy"))]
use std::io::{BufRead, BufReader, Read, Write};
use std::net::TcpStream;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::{error, fmt, io, net, thread};

Expand Down Expand Up @@ -38,6 +38,7 @@ pub struct SimpleHttpTransport {
proxy_addr: net::SocketAddr,
#[cfg(feature = "proxy")]
proxy_auth: Option<(String, String)>,
sock: Arc<Mutex<Option<TcpStream>>>,
}

impl Default for SimpleHttpTransport {
Expand All @@ -57,6 +58,7 @@ impl Default for SimpleHttpTransport {
),
#[cfg(feature = "proxy")]
proxy_auth: None,
sock: Arc::new(Mutex::new(None)),
}
}
}
Expand All @@ -73,29 +75,58 @@ impl SimpleHttpTransport {
}

fn request<R>(&self, req: impl serde::Serialize) -> Result<R, Error>
where
R: for<'a> serde::de::Deserialize<'a>,
{
// `try_request` should not panic, so the mutex shouldn't be poisoned
// and unwrapping should be safe
let mut sock = self.sock.lock().expect("poisoned mutex");
match self.try_request(req, &mut sock) {
Ok(response) => Ok(response),
Err(err) => {
*sock = None;
Err(err)
}
}
}

fn try_request<R>(
&self,
req: impl serde::Serialize,
sock: &mut Option<TcpStream>,
) -> Result<R, Error>
where
R: for<'a> serde::de::Deserialize<'a>,
{
// Open connection
let request_deadline = Instant::now() + self.timeout;
#[cfg(feature = "proxy")]
let mut sock = if let Some((username, password)) = &self.proxy_auth {
Socks5Stream::connect_with_password(
self.proxy_addr,
self.addr,
username.as_str(),
password.as_str(),
)?
.into_inner()
} else {
Socks5Stream::connect(self.proxy_addr, self.addr)?.into_inner()
};

#[cfg(not(feature = "proxy"))]
let mut sock = TcpStream::connect_timeout(&self.addr, self.timeout)?;
if sock.is_none() {
*sock = Some({
#[cfg(feature = "proxy")]
{
if let Some((username, password)) = &self.proxy_auth {
Socks5Stream::connect_with_password(
self.proxy_addr,
self.addr,
username.as_str(),
password.as_str(),
)?
.into_inner()
} else {
Socks5Stream::connect(self.proxy_addr, self.addr)?.into_inner()
}
}

sock.set_read_timeout(Some(self.timeout))?;
sock.set_write_timeout(Some(self.timeout))?;
#[cfg(not(feature = "proxy"))]
{
let stream = TcpStream::connect_timeout(&self.addr, self.timeout)?;
stream.set_read_timeout(Some(self.timeout))?;
stream.set_write_timeout(Some(self.timeout))?;
stream
}
})
};
let sock = sock.as_mut().unwrap();

// Serialize the body first so we can set the Content-Length header.
let body = serde_json::to_vec(&req)?;
Expand All @@ -105,7 +136,6 @@ impl SimpleHttpTransport {
sock.write_all(self.path.as_bytes())?;
sock.write_all(b" HTTP/1.1\r\n")?;
// Write headers
sock.write_all(b"Connection: Close\r\n")?;
sock.write_all(b"Content-Type: application/json\r\n")?;
sock.write_all(b"Content-Length: ")?;
sock.write_all(body.len().to_string().as_bytes())?;
Expand Down Expand Up @@ -133,18 +163,39 @@ impl SimpleHttpTransport {
Err(_) => return Err(Error::HttpParseError),
};

// Skip response header fields
while get_line(&mut reader, request_deadline)? != "\r\n" {}
// Parse response header fields
let mut content_length = None;
loop {
let line = get_line(&mut reader, request_deadline)?;

if line == "\r\n" {
break;
}

const CONTENT_LENGTH: &str = "content-length: ";
if line.to_lowercase().starts_with(CONTENT_LENGTH) {
content_length = Some(
line[CONTENT_LENGTH.len()..]
.trim()
.parse::<usize>()
.map_err(|_| Error::HttpParseError)?,
);
}
}

if response_code == 401 {
// There is no body in a 401 response, so don't try to read it
return Err(Error::HttpErrorCode(response_code));
}

let content_length = content_length.ok_or(Error::HttpParseError)?;

let mut buffer = vec![0; content_length];

// Even if it's != 200, we parse the response as we may get a JSONRPC error instead
// of the less meaningful HTTP error code.
let resp_body = get_lines(&mut reader)?;
match serde_json::from_str(&resp_body) {
reader.read_exact(&mut buffer)?;
match serde_json::from_slice(&buffer) {
Ok(s) => Ok(s),
Err(e) => {
if response_code != 200 {
Expand Down Expand Up @@ -261,23 +312,6 @@ fn get_line<R: BufRead>(reader: &mut R, deadline: Instant) -> Result<String, Err
Err(Error::Timeout)
}

/// Read all lines from a buffered reader.
fn get_lines<R: BufRead>(reader: &mut R) -> Result<String, Error> {
let mut body: String = String::new();

for line in reader.lines() {
match line {
Ok(l) => body.push_str(&l),
// io error occurred, abort
Err(e) => return Err(Error::SocketError(e)),
}
}
// remove whitespace
body.retain(|c| !c.is_whitespace());

Ok(body)
}

/// Do some very basic manual URL parsing because the uri/url crates
/// all have unicode-normalization as a dependency and that's broken.
fn check_url(url: &str) -> Result<(SocketAddr, String), Error> {
Expand Down

0 comments on commit e867374

Please sign in to comment.