Skip to content

Commit

Permalink
release
Browse files Browse the repository at this point in the history
  • Loading branch information
hhbyyh committed Aug 3, 2019
1 parent 60368ad commit b8e7afa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_Pyt

// create Input tuple
int input_size = jenv -> GetArrayLength(input_jstorage);
jint* c_input_offsets = (jint*) jenv -> GetPrimitiveArrayCritical(input_joffset, JNI_FALSE);
jint* c_input_offsets = (jint*) jenv -> GetIntArrayElements(input_joffset, JNI_FALSE);
std::vector<c10::IValue> input_vector;

for (int i = 0; i < input_size; i++) {
Expand All @@ -228,7 +228,6 @@ JNIEXPORT jobjectArray JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_Pyt
j_shape_vector.push_back(tensor_shape);
c_shape_vector.push_back(input_c_shape);
}
auto input_tuple = torch::jit::Tuple::create(input_vector);

// Execute the model
std::shared_ptr<torch::jit::script::Module> model_ptr = modelHandles[nativeRef];
Expand All @@ -237,12 +236,15 @@ JNIEXPORT jobjectArray JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_Pyt

if (isTraining) {
mtx.lock();
auto input_tuple = torch::jit::Tuple::create(input_vector);
modelInputs[nativeRef] = input_tuple;
modelOutputs[nativeRef] = output;
mtx.unlock();
}

// TODO check if the release will affect cached modelInputs[nativeRef]
// Release critical part
jenv -> ReleaseIntArrayElements(input_joffset, c_input_offsets, JNI_ABORT);
for (int i = 0; i < input_size; i++) {
jenv -> ReleasePrimitiveArrayCritical(j_data_vector[i], c_data_vector[i], 0);
jenv -> ReleasePrimitiveArrayCritical(j_shape_vector[i], c_shape_vector[i], 0);
Expand Down Expand Up @@ -291,7 +293,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_Pyt

// create gradOutput tuple
int input_size = jenv -> GetArrayLength(input_jstorage);
jint* c_input_offsets = (jint*) jenv -> GetPrimitiveArrayCritical(input_joffset, JNI_FALSE);
jint* c_input_offsets = (jint*) jenv -> GetIntArrayElements(input_joffset, JNI_FALSE);
std::vector<c10::IValue> input_tuple;

for (int i = 0; i < input_size; i++) {
Expand All @@ -316,22 +318,22 @@ JNIEXPORT jobjectArray JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_Pyt
j_shape_vector.push_back(tensor_shape);
c_shape_vector.push_back(input_c_shape);
}
auto gradOutput_table = torch::jit::Tuple::create(input_tuple);

auto y = modelOutputs[nativeRef];
if (y.isTuple()) {
auto yTuple = y.toTuple();
assert (input_size == yTuple -> elements().size());
for (int i = 0; i < input_size; i++) {
auto gradTensor = gradOutput_table -> elements()[i].toTensor();
auto gradTensor = input_tuple[i].toTensor();
auto outputTensor = yTuple -> elements()[i].toTensor();
outputTensor.backward(gradTensor);
}
} else {
y.toTensor().backward(gradOutput_table -> elements()[0].toTensor());
y.toTensor().backward(input_tuple[0].toTensor());
}

// Release critical part
jenv -> ReleaseIntArrayElements(input_joffset, c_input_offsets, JNI_ABORT);
for (int i = 0; i < input_size; i++) {
jenv -> ReleasePrimitiveArrayCritical(j_data_vector[i], c_data_vector[i], 0);
jenv -> ReleasePrimitiveArrayCritical(j_shape_vector[i], c_shape_vector[i], 0);
Expand Down Expand Up @@ -360,7 +362,7 @@ JNIEXPORT jobject JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_PytorchM

// create input tuple
int input_size = jenv -> GetArrayLength(input_jstorage);
jint* c_input_offsets = (jint*) jenv -> GetPrimitiveArrayCritical(input_joffset, JNI_FALSE);
jint* c_input_offsets = (jint*) jenv -> GetIntArrayElements(input_joffset, JNI_FALSE);
std::vector<c10::IValue> input_tuple;

for (int i = 0; i < input_size; i++) {
Expand Down Expand Up @@ -390,7 +392,7 @@ JNIEXPORT jobject JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_PytorchM

// create label tuple
int label_size = jenv -> GetArrayLength(label_jstorage);
jint* c_label_offsets = (jint*) jenv -> GetPrimitiveArrayCritical(label_joffset, JNI_FALSE);
jint* c_label_offsets = (jint*) jenv -> GetIntArrayElements(label_joffset, JNI_FALSE);
std::vector<c10::IValue> label_tuple;
for (int i = 0; i < label_size; i++) {
jfloatArray tensor_storage = (jfloatArray)jenv->GetObjectArrayElement(label_jstorage, i);
Expand Down Expand Up @@ -445,6 +447,8 @@ JNIEXPORT jobject JNICALL Java_com_intel_analytics_zoo_pipeline_api_net_PytorchM
mtx.unlock();

// Release critical part
jenv -> ReleaseIntArrayElements(input_joffset, c_input_offsets, JNI_ABORT);
jenv -> ReleaseIntArrayElements(label_joffset, c_label_offsets, JNI_ABORT);
for (int i = 0; i < input_size + label_size; i++) {
jenv -> ReleasePrimitiveArrayCritical(j_data_vector[i], c_data_vector[i], 0);
jenv -> ReleasePrimitiveArrayCritical(j_shape_vector[i], c_shape_vector[i], 0);
Expand Down
Binary file modified zoo/src/main/resources/pytorch/libpytorch-engine.so
Binary file not shown.

0 comments on commit b8e7afa

Please sign in to comment.