Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the session and pool. #10

Merged
merged 11 commits into from
Dec 16, 2021
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@ jobs:
run: |
cargo test
timeout-minutes: 4
- name: example
run: |
cargo run --example basic_op
timeout-minutes: 4
2 changes: 1 addition & 1 deletion fbthrift-transport/tests/transport_tokio_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ mod transport_tokio_io_tests {
task::JoinHandle,
};

use nebula_fbthrift_transport::AsyncTransport;
use fbthrift_transport_response_handler::ResponseHandler;
use nebula_fbthrift_transport::AsyncTransport;

#[derive(Clone)]
pub struct FooResponseHandler;
Expand Down
4 changes: 4 additions & 0 deletions nebula_rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ tokio = { version = "1.8.2", features = ["full"] }
fbthrift = { version = "0.0.2" }
fbthrift-transport = { path = "../fbthrift-transport", package = "nebula-fbthrift-transport" , features = ["tokio_io"], version = "0.0.2" }
bytes = { version = "0.5" }
futures = { version = "0.3.16" }

[build-dependencies]

[dev-dependencies]

[[example]]
name = "basic_op"
31 changes: 31 additions & 0 deletions nebula_rust/examples/basic_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

use nebula_rust::graph_client;

#[tokio::main]
async fn main() {
let mut conf = graph_client::pool_config::PoolConfig::new();
conf.min_connection_pool_size(2)
.max_connection_pool_size(10)
.address("localhost:9669".to_string());

let pool = graph_client::connection_pool::ConnectionPool::new(&conf).await;
let session = pool.get_session("root", "nebula", true).await.unwrap();

let resp = session.execute("YIELD 1").await.unwrap();
assert!(resp.error_code == common::types::ErrorCode::SUCCEEDED);

println!("{:?}", resp.data.as_ref().unwrap());
println!(
"The result of query `YIELD 1' is {}.",
if let common::types::Value::iVal(v) = resp.data.unwrap().rows[0].values[0] {
v
} else {
panic!()
}
);
}
41 changes: 32 additions & 9 deletions nebula_rust/src/graph_client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,54 @@ use tokio::net::TcpStream;

use crate::graph_client::transport_response_handler;

/// The simple abstraction of a connection to nebula graph server
#[derive(Default)]
pub struct Connection {
client: client::GraphServiceImpl<
BinaryProtocol,
AsyncTransport<TcpStream, transport_response_handler::GraphTransportResponseHandler>,
// The option is used to construct a null connection
// which is used to give back the connection to pool from session
// So we could assume it's alway not null
client: Option<
client::GraphServiceImpl<
BinaryProtocol,
AsyncTransport<TcpStream, transport_response_handler::GraphTransportResponseHandler>,
>,
>,
}

impl Connection {
/// Create connection with the specified [host:port]
pub async fn new(host: &str, port: i32) -> Result<Connection> {
let addr = format!("{}:{}", host, port);
let stream = TcpStream::connect(addr).await?;
/// Create connection with the specified [host:port] address
pub async fn new_from_address(address: &str) -> Result<Connection> {
let stream = TcpStream::connect(address).await?;
let transport = AsyncTransport::new(
stream,
AsyncTransportConfiguration::new(
transport_response_handler::GraphTransportResponseHandler,
),
);
Ok(Connection {
client: client::GraphServiceImpl::new(transport),
client: Some(client::GraphServiceImpl::new(transport)),
})
}

/// Create connection with the specified [host:port]
pub async fn new(host: &str, port: i32) -> Result<Connection> {
let address = format!("{}:{}", host, port);
Connection::new_from_address(&address).await
}

/// Authenticate by username and password
/// The returned error of `Result` only means the request/response status
/// The error from Nebula Graph is still in `error_code` field in response, so you need check it
/// to known wether authenticate succeeded
pub async fn authenticate(
&self,
username: &str,
password: &str,
) -> std::result::Result<graph::types::AuthResponse, common::types::ErrorCode> {
let result = self
.client
.as_ref()
.unwrap()
.authenticate(
&username.to_string().into_bytes(),
&password.to_string().into_bytes(),
Expand All @@ -56,25 +73,31 @@ impl Connection {
}

/// Sign out the authentication by session id which got by authenticating previous
/// The returned error of `Result` only means the request/response status
pub async fn signout(
&self,
session_id: i64,
) -> std::result::Result<(), common::types::ErrorCode> {
let result = self.client.signout(session_id).await;
let result = self.client.as_ref().unwrap().signout(session_id).await;
if let Err(_) = result {
return Err(common::types::ErrorCode::E_RPC_FAILURE);
}
Ok(())
}

/// Execute the query with current session id which got by authenticating previous
/// The returned error of `Result` only means the request/response status
/// The error from Nebula Graph is still in `error_code` field in response, so you need check it
/// to known wether the query execute succeeded
pub async fn execute(
&self,
session_id: i64,
query: &str,
) -> std::result::Result<graph::types::ExecutionResponse, common::types::ErrorCode> {
let result = self
.client
.as_ref()
.unwrap()
.execute(session_id, &query.to_string().into_bytes())
.await;
if let Err(_) = result {
Expand Down
141 changes: 140 additions & 1 deletion nebula_rust/src/graph_client/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,143 @@
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

mod graph_client {};
use crate::graph_client::connection::Connection;
use crate::graph_client::pool_config::PoolConfig;
use crate::graph_client::session::Session;

/// The pool of connection to server, it's MT-safe to access.
pub struct ConnectionPool {
/// The connections
/// The interior mutable to enable could get multiple sessions in one scope
conns: std::sync::Mutex<std::cell::RefCell<std::collections::LinkedList<Connection>>>,
/// It should be immutable
config: PoolConfig,
/// Address cursor
cursor: std::cell::RefCell<std::sync::atomic::AtomicUsize>,
/// The total count of connections, contains which hold by session
conns_count: std::cell::RefCell<std::sync::atomic::AtomicUsize>,
}

impl ConnectionPool {
/// Construct pool by the configuration
pub async fn new(conf: &PoolConfig) -> Self {
let conns = std::collections::LinkedList::<Connection>::new();
let pool = ConnectionPool {
conns: std::sync::Mutex::new(std::cell::RefCell::new(conns)),
config: conf.clone(),
cursor: std::cell::RefCell::new(std::sync::atomic::AtomicUsize::new(0)),
conns_count: std::cell::RefCell::new(std::sync::atomic::AtomicUsize::new(0)),
};
assert!(pool.config.min_connection_pool_size <= pool.config.max_connection_pool_size);
pool.new_connection(pool.config.min_connection_pool_size)
.await;
pool
}

/// Get a session authenticated by username and password
/// retry_connect means keep the connection available if true
pub async fn get_session(
&self,
username: &str,
password: &str,
retry_connect: bool,
) -> std::result::Result<Session<'_>, common::types::ErrorCode> {
if self.conns.lock().unwrap().borrow_mut().is_empty() {
self.new_connection(1).await;
}
let conn = self.conns.lock().unwrap().borrow_mut().pop_back();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the connection pop from the list, and if the user does no return it, the pool couldn't close it. You need to make the pool can manage it when the use get it from the pool.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to make the connection is ok, here you can get a bad connection

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pool can't access the connection owned by session. It's designed to avoid some data race

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The keep-alive will add in later pr.

if let Some(conn) = conn {
let resp = conn.authenticate(username, password).await?;
if resp.error_code != common::types::ErrorCode::SUCCEEDED {
return Err(resp.error_code);
}
Ok(Session::new(
resp.session_id.unwrap(),
conn,
self,
username.to_string(),
password.to_string(),
if let Some(time_zone_name) = resp.time_zone_name {
std::str::from_utf8(&time_zone_name).unwrap().to_string()
} else {
String::new()
},
resp.time_zone_offset_seconds.unwrap(),
retry_connect,
))
} else {
Err(common::types::ErrorCode::E_UNKNOWN)
}
}

/// Give back the connection to pool
#[inline]
pub fn give_back(&self, conn: Connection) {
self.conns.lock().unwrap().borrow_mut().push_back(conn);
}

/// Get the count of connections
#[inline]
pub fn len(&self) -> usize {
self.conns.lock().unwrap().borrow().len()
}

// Add new connection to pool
// inc is the count of new connection created, which shouldn't be zero
// the incremental count maybe can't fit when occurs error in connection creating
async fn new_connection(&self, inc: u32) {
assert!(inc != 0);
// TODO concurrent these
let mut count = 0;
let mut loop_count = 0;
let loop_limit = inc as usize * self.config.addresses.len();
while count < inc {
if count as usize
+ self
.conns_count
.borrow()
.load(std::sync::atomic::Ordering::Acquire)
>= self.config.max_connection_pool_size as usize
{
// Reach the pool size limit
break;
}
let cursor = { self.cursor() };
match Connection::new_from_address(&self.config.addresses[cursor]).await {
Ok(conn) => {
self.conns.lock().unwrap().borrow_mut().push_back(conn);
count += 1;
}
Err(_) => (),
};
loop_count += 1;
if loop_count > loop_limit {
// Can't get so many connections, avoid dead loop
break;
}
}
// Release ordering make sure inc happened after creating new connections
self.conns_count
.borrow_mut()
.fetch_add(count as usize, std::sync::atomic::Ordering::Release);
}

// cursor on the server addresses
fn cursor(&self) -> usize {
if self
.cursor
.borrow()
.load(std::sync::atomic::Ordering::Relaxed)
>= self.config.addresses.len()
{
self.cursor
.borrow_mut()
.store(0, std::sync::atomic::Ordering::Relaxed);
0
} else {
self.cursor
.borrow_mut()
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
}
}
3 changes: 3 additions & 0 deletions nebula_rust/src/graph_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@
*/

pub mod connection;
pub mod connection_pool;
pub mod pool_config;
pub mod session;
mod transport_response_handler;
56 changes: 55 additions & 1 deletion nebula_rust/src/graph_client/pool_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,58 @@
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

mod graph_client {};
#[derive(Debug, Default, Clone)]
pub struct PoolConfig {
/// connection timeout in ms
pub timeout: u32,
pub idle_time: u32,
/// max limit count of connections in pool
pub max_connection_pool_size: u32,
/// min limit count of connections in pool, also the initial count if works well
pub min_connection_pool_size: u32,
/// address of graph server
pub addresses: std::vec::Vec<String>,
}

impl PoolConfig {
#[inline]
pub fn new() -> Self {
Self::default()
}

#[inline]
pub fn timeout(&mut self, timeout: u32) -> &mut Self {
self.timeout = timeout;
self
}

#[inline]
pub fn idle_time(&mut self, idle_time: u32) -> &mut Self {
self.idle_time = idle_time;
self
}

#[inline]
pub fn max_connection_pool_size(&mut self, size: u32) -> &mut Self {
self.max_connection_pool_size = size;
self
}

#[inline]
pub fn min_connection_pool_size(&mut self, size: u32) -> &mut Self {
self.min_connection_pool_size = size;
self
}

#[inline]
pub fn addresses(&mut self, addresses: std::vec::Vec<String>) -> &mut Self {
self.addresses = addresses;
self
}

#[inline]
pub fn address(&mut self, address: String) -> &mut Self {
self.addresses.push(address);
self
}
}
Loading