Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: refactor rust SDK #79

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ name = "milvus"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tonic = {version = "0.8.2", features = ["tls", "tls-roots"]}
tonic = { version = "0.8.2", features = ["tls", "tls-roots"] }
prost = "0.11.0"
tokio = { version = "1.17.0", features = ["full"] }
thiserror = "1.0"
Expand All @@ -21,9 +21,12 @@ anyhow = "1.0"
strum = "0.24"
strum_macros = "0.24"
base64 = "0.21.0"
dashmap = "5.5.3"

[build-dependencies]
tonic-build = "0.8.2"
tonic-build = { version = "0.8.2", default-features = false, features = [
"prost",
] }

[dev-dependencies]
rand = "0.8.5"
27 changes: 15 additions & 12 deletions examples/collection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use milvus::index::{IndexParams, IndexType};
use milvus::schema::CollectionSchemaBuilder;
use milvus::options::LoadOptions;
use milvus::query::QueryOptions;
use milvus::schema::{CollectionSchemaBuilder, CollectionSchema};
use milvus::{
client::Client, collection::Collection, data::FieldColumn, error::Error, schema::FieldSchema,
};
Expand Down Expand Up @@ -29,42 +31,43 @@ async fn main() -> Result<(), Error> {
DIM,
))
.build()?;
let collection = client.create_collection(schema.clone(), None).await?;
client.create_collection(schema.clone(), None).await?;

if let Err(err) = hello_milvus(&collection).await {
if let Err(err) = hello_milvus(&client, &schema).await {
println!("failed to run hello milvus: {:?}", err);
}
collection.drop().await?;
client.drop_collection(schema.name()).await?;

Ok(())
}

async fn hello_milvus(collection: &Collection) -> Result<(), Error> {
async fn hello_milvus(client: &Client, collection: &CollectionSchema) -> Result<(), Error> {
let mut embed_data = Vec::<f32>::new();
for _ in 1..=DIM * 1000 {
let mut rng = rand::thread_rng();
let embed = rng.gen();
embed_data.push(embed);
}
let embed_column = FieldColumn::new(
collection.schema().get_field(DEFAULT_VEC_FIELD).unwrap(),
collection.get_field(DEFAULT_VEC_FIELD).unwrap(),
embed_data,
);

collection.insert(vec![embed_column], None).await?;
collection.flush().await?;
client.insert(collection.name(), vec![embed_column], None).await?;
client.flush(collection.name()).await?;
let index_params = IndexParams::new(
"feature_index".to_owned(),
IndexType::IvfFlat,
milvus::index::MetricType::L2,
HashMap::from([("nlist".to_owned(), "32".to_owned())]),
);
collection
.create_index(DEFAULT_VEC_FIELD, index_params)
client
.create_index(collection.name(), DEFAULT_VEC_FIELD, index_params)
.await?;
collection.load(1).await?;
client.load_collection(collection.name(), Some(LoadOptions::default())).await?;

let result = collection.query::<_, [&str; 0]>("id > 0", []).await?;
let options = QueryOptions::default();
let result = client.query(collection.name(), "id > 0", &options).await?;

println!(
"result num: {}",
Expand Down
133 changes: 11 additions & 122 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::collection::CollectionCache;
use crate::config::RPC_TIMEOUT;
use crate::error::{Error, Result};
use crate::options::CreateCollectionOptions;
pub use crate::proto::common::ConsistencyLevel;
use crate::proto::common::{MsgBase, MsgType};
use crate::proto::milvus::milvus_service_client::MilvusServiceClient;
use crate::proto::milvus::{
CreateCollectionRequest, DescribeCollectionRequest, DropCollectionRequest, FlushRequest,
HasCollectionRequest, ShowCollectionsRequest,
};
use crate::schema::CollectionSchema;
use crate::proto::milvus::FlushRequest;
use crate::utils::status_to_result;
use crate::{collection::Collection, proto::common::ErrorCode};
use base64::engine::general_purpose;
use base64::Engine;
use prost::bytes::BytesMut;
use prost::Message;
use std::collections::HashMap;
use std::convert::TryInto;
use std::time::Duration;
Expand Down Expand Up @@ -95,9 +88,10 @@ where
}
}

#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Client {
client: MilvusServiceClient<InterceptedService<Channel, AuthInterceptor>>,
pub(crate) client: MilvusServiceClient<InterceptedService<Channel, AuthInterceptor>>,
pub(crate) collection_cache: CollectionCache,
}

impl Client {
Expand Down Expand Up @@ -142,110 +136,10 @@ impl Client {

let client = MilvusServiceClient::with_interceptor(conn, auth_interceptor);

Ok(Self { client })
}

pub async fn create_collection(
&self,
schema: CollectionSchema,
options: Option<CreateCollectionOptions>,
) -> Result<Collection> {
let options = options.unwrap_or_default();
let schema: crate::proto::schema::CollectionSchema = schema.into();
let mut buf = BytesMut::new();

schema.encode(&mut buf)?;

let status = self
.client
.clone()
.create_collection(CreateCollectionRequest {
base: Some(MsgBase::new(MsgType::CreateCollection)),
collection_name: schema.name.to_string(),
schema: buf.to_vec(),
shards_num: options.shard_num,
consistency_level: options.consistency_level as i32,
..Default::default()
})
.await?
.into_inner();

status_to_result(&Some(status))?;

self.get_collection(&schema.name).await
}

pub async fn get_collection(&self, collection_name: &str) -> Result<Collection> {
let resp = self
.client
.clone()
.describe_collection(DescribeCollectionRequest {
base: Some(MsgBase::new(MsgType::DescribeCollection)),
db_name: "".to_owned(),
collection_name: collection_name.to_owned(),
collection_id: 0,
time_stamp: 0,
})
.await?
.into_inner();

status_to_result(&resp.status)?;

Ok(Collection::new(self.client.clone(), resp))
}

pub async fn has_collection<S>(&self, name: S) -> Result<bool>
where
S: Into<String>,
{
let name = name.into();
let res = self
.client
.clone()
.has_collection(HasCollectionRequest {
base: Some(MsgBase::new(MsgType::HasCollection)),
db_name: "".to_string(),
collection_name: name.clone(),
time_stamp: 0,
})
.await?
.into_inner();

status_to_result(&res.status)?;

Ok(res.value)
}

pub async fn drop_collection<S>(&self, name: S) -> Result<()>
where
S: Into<String>,
{
status_to_result(&Some(
self.client
.clone()
.drop_collection(DropCollectionRequest {
base: Some(MsgBase::new(MsgType::DropCollection)),
collection_name: name.into(),
..Default::default()
})
.await?
.into_inner(),
))
}

pub async fn list_collections(&self) -> Result<Vec<String>> {
let response = self
.client
.clone()
.show_collections(ShowCollectionsRequest {
base: Some(MsgBase::new(MsgType::ShowCollections)),
..Default::default()
})
.await?
.into_inner();

status_to_result(&response.status)?;
Ok(response.collection_names)
Ok(Self {
client: client.clone(),
collection_cache: CollectionCache::new(client),
})
}

pub async fn flush_collections<C>(&self, collections: C) -> Result<HashMap<String, Vec<i64>>>
Expand Down Expand Up @@ -285,10 +179,7 @@ impl Client {
/// # Returns
///
/// Returns a `Result` indicating success or failure.
pub async fn create_alias<S>(&self, collection_name: S, alias: S) -> Result<()>
where
S: Into<String>,
{
pub async fn create_alias(&self, collection_name: impl Into<String>, alias: impl Into<String>) -> Result<()>{
let collection_name = collection_name.into();
let alias = alias.into();
status_to_result(&Some(
Expand Down Expand Up @@ -342,9 +233,7 @@ impl Client {
/// # Returns
///
/// Returns a `Result` indicating success or failure.
pub async fn alter_alias<S>(&self, collection_name: S, alias: S) -> Result<()>
where
S: Into<String>,
pub async fn alter_alias(&self, collection_name: impl Into<String>, alias: impl Into<String>) -> Result<()>
{
let collection_name = collection_name.into();
let alias = alias.into();
Expand Down
Loading
Loading