Skip to content

Commit

Permalink
[ML] Downloaded and write model parts using multiple streams (elastic…
Browse files Browse the repository at this point in the history
…#111684)

Uses the range header to split the model download into multiple streams
using a separate thread for each stream
# Conflicts:
#	x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java
#	x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java
  • Loading branch information
davidkyle committed Sep 13, 2024
1 parent 84b4c59 commit 47c6502
Show file tree
Hide file tree
Showing 12 changed files with 897 additions and 178 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/111684.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111684
summary: Write downloaded model parts async
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,23 @@
import java.util.Locale;

public enum Level {
INFO,
WARNING,
ERROR;
INFO {
public org.apache.logging.log4j.Level log4jLevel() {
return org.apache.logging.log4j.Level.INFO;
}
},
WARNING {
public org.apache.logging.log4j.Level log4jLevel() {
return org.apache.logging.log4j.Level.WARN;
}
},
ERROR {
public org.apache.logging.log4j.Level log4jLevel() {
return org.apache.logging.log4j.Level.ERROR;
}
};

public abstract org.apache.logging.log4j.Level log4jLevel();

/**
* Case-insensitive from string method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;

Expand All @@ -44,16 +49,15 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin
Setting.Property.Dynamic
);

// re-using thread pool setup by the ml plugin
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

// This link will be invalid for serverless, but serverless will never be
// air-gapped, so this message should never be needed.
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format(
"https://www.elastic.co/guide/en/machine-learning/%s/ml-nlp-elser.html#air-gapped-install",
Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1")
);

public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download";

public MachineLearningPackageLoader() {}

@Override
Expand Down Expand Up @@ -81,6 +85,24 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return List.of(modelDownloadExecutor(settings));
}

public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) {
// Threadpool with a fixed number of threads for
// downloading the model definition files
return new FixedExecutorBuilder(
settings,
MODEL_DOWNLOAD_THREADPOOL_NAME,
ModelImporter.NUMBER_OF_STREAMS,
-1, // unbounded queue size
"xpack.ml.model_download_thread_pool",
EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
);
}

@Override
public List<BootstrapCheck> getBootstrapChecks() {
return List.of(new BootstrapCheck() {
Expand Down
Loading

0 comments on commit 47c6502

Please sign in to comment.