Skip to content

Commit

Permalink
Merge pull request #168 from lipanski/fix-join-all-v2
Browse files Browse the repository at this point in the history
Replace tokio Mutexes with sync Mutexes
  • Loading branch information
lipanski committed Mar 25, 2023
2 parents 9a07811 + 8c35191 commit 627d938
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 50 deletions.
18 changes: 9 additions & 9 deletions src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::ops::Drop;
use std::path::Path;
use std::string::ToString;
use std::sync::Arc;
use tokio::sync::RwLock;
use std::sync::RwLock;

#[derive(Clone, Debug)]
pub struct InnerMock {
Expand Down Expand Up @@ -441,7 +441,7 @@ impl Mock {
#[track_caller]
pub fn assert(&self) {
let mutex = self.state.clone();
let state = mutex.blocking_read();
let state = mutex.read().unwrap();
if let Some(hits) = state.get_mock_hits(self.inner.id.clone()) {
let matched = self.matched_hits(hits);
let message = if !matched {
Expand All @@ -462,7 +462,7 @@ impl Mock {
///
pub async fn assert_async(&self) {
let mutex = self.state.clone();
let state = mutex.read().await;
let state = mutex.read().unwrap();
if let Some(hits) = state.get_mock_hits(self.inner.id.clone()) {
let matched = self.matched_hits(hits);
let message = if !matched {
Expand All @@ -483,7 +483,7 @@ impl Mock {
///
pub fn matched(&self) -> bool {
let mutex = self.state.clone();
let state = mutex.blocking_read();
let state = mutex.read().unwrap();
let Some(hits) = state.get_mock_hits(self.inner.id.clone()) else {
return false;
};
Expand All @@ -496,7 +496,7 @@ impl Mock {
///
pub async fn matched_async(&self) -> bool {
let mutex = self.state.clone();
let state = mutex.read().await;
let state = mutex.read().unwrap();
let Some(hits) = state.get_mock_hits(self.inner.id.clone()) else {
return false;
};
Expand All @@ -518,7 +518,7 @@ impl Mock {
pub fn create(mut self) -> Mock {
let remote_mock = RemoteMock::new(self.inner.clone());
let state = self.state.clone();
let mut state = state.blocking_write();
let mut state = state.write().unwrap();
state.mocks.push(remote_mock);

self.created = true;
Expand All @@ -532,7 +532,7 @@ impl Mock {
pub async fn create_async(mut self) -> Mock {
let remote_mock = RemoteMock::new(self.inner.clone());
let state = self.state.clone();
let mut state = state.write().await;
let mut state = state.write().unwrap();
state.mocks.push(remote_mock);

self.created = true;
Expand All @@ -545,7 +545,7 @@ impl Mock {
///
pub fn remove(&self) {
let mutex = self.state.clone();
let mut state = mutex.blocking_write();
let mut state = mutex.write().unwrap();
state.remove_mock(self.inner.id.clone());
}

Expand All @@ -554,7 +554,7 @@ impl Mock {
///
pub async fn remove_async(&self) {
let mutex = self.state.clone();
let mut state = mutex.write().await;
let mut state = mutex.write().unwrap();
state.remove_mock(self.inner.id.clone());
}

Expand Down
46 changes: 20 additions & 26 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@ use crate::request::Request;
use crate::response::{Body as ResponseBody, Chunked as ResponseChunked};
use crate::ServerGuard;
use crate::{Error, ErrorKind, Matcher, Mock};
use futures::stream::{self, StreamExt};
use hyper::server::conn::Http;
use hyper::service::service_fn;
use hyper::{Body, Request as HyperRequest, Response, StatusCode};
use std::fmt;
use std::net::SocketAddr;
use std::ops::Drop;
use std::sync::Arc;
use std::sync::{mpsc, Arc, RwLock};
use std::thread;
use tokio::net::TcpListener;
use tokio::runtime;
use tokio::sync::{oneshot, RwLock};
use tokio::task::{spawn_local, LocalSet};

#[derive(Clone, Debug)]
Expand All @@ -27,11 +25,11 @@ impl RemoteMock {
RemoteMock { inner }
}

async fn matches(&self, other: &mut Request) -> bool {
fn matches(&self, other: &mut Request) -> bool {
self.method_matches(other)
&& self.path_matches(other)
&& self.headers_match(other)
&& self.body_matches(other).await
&& self.body_matches(other)
}

fn method_matches(&self, request: &Request) -> bool {
Expand All @@ -49,8 +47,8 @@ impl RemoteMock {
.all(|&(ref field, ref expected)| expected.matches_values(&request.header(field)))
}

async fn body_matches(&self, request: &mut Request) -> bool {
let body = request.read_body().await;
fn body_matches(&self, request: &mut Request) -> bool {
let body = request.body().unwrap();
let safe_body = &String::from_utf8_lossy(body);

self.inner.body.matches_value(safe_body) || self.inner.body.matches_binary_value(body)
Expand Down Expand Up @@ -201,7 +199,7 @@ impl Server {
pub(crate) fn try_new_with_port(port: u16) -> Result<Server, Error> {
let state = Arc::new(RwLock::new(State::new()));
let address = SocketAddr::from(([127, 0, 0, 1], port));
let (address_sender, address_receiver) = oneshot::channel::<String>();
let (address_sender, address_receiver) = mpsc::channel::<String>();
let runtime = runtime::Builder::new_current_thread()
.enable_all()
.build()
Expand All @@ -214,7 +212,7 @@ impl Server {
});

let address = address_receiver
.blocking_recv()
.recv()
.map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;

let server = Server { address, state };
Expand All @@ -228,7 +226,7 @@ impl Server {
pub(crate) async fn try_new_with_port_async(port: u16) -> Result<Server, Error> {
let state = Arc::new(RwLock::new(State::new()));
let address = SocketAddr::from(([127, 0, 0, 1], port));
let (address_sender, address_receiver) = oneshot::channel::<String>();
let (address_sender, address_receiver) = mpsc::channel::<String>();
let runtime = runtime::Builder::new_current_thread()
.enable_all()
.build()
Expand All @@ -241,7 +239,7 @@ impl Server {
});

let address = address_receiver
.await
.recv()
.map_err(|err| Error::new_with_context(ErrorKind::ServerFailure, err))?;

let server = Server { address, state };
Expand All @@ -251,7 +249,7 @@ impl Server {

async fn bind_server(
address: SocketAddr,
address_sender: oneshot::Sender<String>,
address_sender: mpsc::Sender<String>,
state: Arc<RwLock<State>>,
) -> Result<(), Error> {
let listener = TcpListener::bind(address)
Expand Down Expand Up @@ -321,28 +319,26 @@ impl Server {
///
pub fn reset(&mut self) {
let state = self.state.clone();
let mut state = state.blocking_write();
let mut state = state.write().unwrap();
state.mocks.clear();
state.unmatched_requests.clear();
}

///
/// Same as `Server::reset` but async.
/// **DEPRECATED:** Use `Server::reset` instead. The implementation is not async any more.
///
#[deprecated(since = "1.0.1", note = "Use `Server::reset` instead")]
pub async fn reset_async(&mut self) {
let state = self.state.clone();
let mut state = state.write().await;
let mut state = state.write().unwrap();
state.mocks.clear();
state.unmatched_requests.clear();
}
}

impl Drop for Server {
fn drop(&mut self) {
futures::executor::block_on(async {
log::debug!("Server::drop() called for {}", self);
self.reset_async().await;
});
self.reset();
}
}

Expand All @@ -361,13 +357,11 @@ async fn handle_request(
log::debug!("Request received: {}", request.formatted());

let mutex = state.clone();
let mut state = mutex.write().await;

let mut mocks_stream = stream::iter(&mut state.mocks);
let mut state = mutex.write().unwrap();
let mut matching_mocks: Vec<&mut RemoteMock> = vec![];

while let Some(mock) = mocks_stream.next().await {
if mock.matches(&mut request).await {
for mock in state.mocks.iter_mut() {
if mock.matches(&mut request) {
matching_mocks.push(mock);
}
}
Expand All @@ -382,15 +376,15 @@ async fn handle_request(
if let Some(mock) = mock {
log::debug!("Mock found");
mock.inner.hits += 1;
respond_with_mock(request, mock).await
respond_with_mock(request, mock)
} else {
log::debug!("Mock not found");
state.unmatched_requests.push(request);
respond_with_mock_not_found()
}
}

async fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Body>, Error> {
fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Body>, Error> {
let status: StatusCode = mock.inner.response.status;
let mut response = Response::builder().status(status);

Expand Down
32 changes: 17 additions & 15 deletions src/server_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ use crate::Server;
use crate::{Error, ErrorKind};
use lazy_static::lazy_static;
use std::collections::VecDeque;
use std::ops::Drop;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tokio::sync::{Mutex, Semaphore, SemaphorePermit};
use std::ops::{Deref, DerefMut, Drop};
use std::sync::{Arc, Mutex};
use tokio::sync::{Semaphore, SemaphorePermit};

const DEFAULT_POOL_SIZE: usize = 100;

Expand Down Expand Up @@ -46,11 +45,9 @@ impl DerefMut for ServerGuard {

impl Drop for ServerGuard {
fn drop(&mut self) {
futures::executor::block_on(async {
if let Some(server) = self.server.take() {
SERVER_POOL.recycle_async(server).await;
}
});
if let Some(server) = self.server.take() {
SERVER_POOL.recycle(server);
}
}
}

Expand Down Expand Up @@ -81,11 +78,16 @@ impl ServerPool {
.await
.map_err(|err| Error::new_with_context(ErrorKind::Deadlock, err))?;

let server = if self.created < self.max_size {
Some(Server::try_new_with_port_async(0).await?)
} else {
None
};

let state_mutex = self.state.clone();
let mut state = state_mutex.lock().await;
let mut state = state_mutex.lock().unwrap();

if self.created < self.max_size {
let server = Server::try_new_with_port_async(0).await?;
if let Some(server) = server {
state.push_back(server);
}

Expand All @@ -96,10 +98,10 @@ impl ServerPool {
}
}

async fn recycle_async(&'static self, mut server: Server) {
server.reset_async().await;
fn recycle(&self, mut server: Server) {
server.reset();
let state_mutex = self.state.clone();
let mut state = state_mutex.lock().await;
let mut state = state_mutex.lock().unwrap();
state.push_back(server);
}
}
13 changes: 13 additions & 0 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1993,3 +1993,16 @@ async fn test_match_body_asnyc() {

assert_eq!(200, response.status());
}

#[tokio::test]
async fn test_join_all_async() {
let futures = (0..10).map(|_| async {
let mut s = Server::new_async().await;
let m = s.mock("POST", "/").create_async().await;

reqwest::Client::new().post(s.url()).send().await.unwrap();
m.assert_async().await;
});

let _results = futures::future::join_all(futures).await;
}

0 comments on commit 627d938

Please sign in to comment.