From 3f9a5c8298ae08b02b0c52bb24aa2675f2c0db43 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 30 Sep 2024 09:26:23 +0100 Subject: [PATCH] [ML] Fix check on E5 model platform compatibility (#113437) Creating an endpoint for the built-in multilingual E5 model failed for the Linux-optimised version due to an error in the logic that checks model compatibility. # Conflicts: # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java --- docs/changelog/113437.yaml | 6 ++ .../xpack/inference/TextEmbeddingCrudIT.java | 11 ++- .../ElasticsearchInternalService.java | 27 ++++--- .../ElasticsearchInternalServiceTests.java | 72 ++++++++++++++++--- 4 files changed, 86 insertions(+), 30 deletions(-) create mode 100644 docs/changelog/113437.yaml diff --git a/docs/changelog/113437.yaml b/docs/changelog/113437.yaml new file mode 100644 index 0000000000000..98831958e63f8 --- /dev/null +++ b/docs/changelog/113437.yaml @@ -0,0 +1,6 @@ +pr: 113437 +summary: Fix check on E5 model platform compatibility +area: Machine Learning +type: bug +issues: + - 113577 diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java index 6c15b42dc65d5..01e8c30e3bf27 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java @@ -24,7 +24,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest { public void testPutE5Small_withNoModelVariant() { { - String 
inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "testPutE5Small_withNoModelVariant"; expectThrows( org.elasticsearch.client.ResponseException.class, () -> putTextEmbeddingModel(inferenceEntityId, noModelIdVariantJsonEntity()) @@ -33,7 +33,7 @@ public void testPutE5Small_withNoModelVariant() { } public void testPutE5Small_withPlatformAgnosticVariant() throws IOException { - String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "teste5mall_withplatformagnosticvariant"; putTextEmbeddingModel(inferenceEntityId, platformAgnosticModelVariantJsonEntity()); var models = getTrainedModel("_all"); assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId)); @@ -50,9 +50,8 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException { deleteTextEmbeddingModel(inferenceEntityId); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198") public void testPutE5Small_withPlatformSpecificVariant() throws IOException { - String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "teste5mall_withplatformspecificvariant"; if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) { putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity()); var models = getTrainedModel("_all"); @@ -77,7 +76,7 @@ public void testPutE5Small_withPlatformSpecificVariant() throws IOException { } public void testPutE5Small_withFakeModelVariant() { - String inferenceEntityId = randomAlphaOfLength(10).toLowerCase(); + String inferenceEntityId = "teste5mall_withfakevariant"; expectThrows( org.elasticsearch.client.ResponseException.class, () -> putTextEmbeddingModel(inferenceEntityId, fakeModelVariantJsonEntity()) @@ -112,7 +111,7 @@ private Map putTextEmbeddingModel(String inferenceEntityId, Stri private String noModelIdVariantJsonEntity() { return """ { - "service": "text_embedding", + "service": "elasticsearch", 
"service_settings": { "num_allocations": 1, "num_threads": 1 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index d5401f61823db..2de88aa23c88b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -169,16 +169,14 @@ private void e5Case( Map serviceSettingsMap, ActionListener modelListener ) { - var e5ServiceSettings = MultilingualE5SmallInternalServiceSettings.fromMap(serviceSettingsMap); + var esServiceSettingsBuilder = MultilingualE5SmallInternalServiceSettings.fromMap(serviceSettingsMap); - if (e5ServiceSettings.getModelId() == null) { - e5ServiceSettings.setModelId(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures)); - } - - if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, e5ServiceSettings)) { + if (esServiceSettingsBuilder.getModelId() == null) { + esServiceSettingsBuilder.setModelId(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures)); + } else if (modelVariantValidForArchitecture(platformArchitectures, esServiceSettingsBuilder.getModelId()) == false) { throw new IllegalArgumentException( "Error parsing request config, model id does not match any models available on this platform. 
Was [" - + e5ServiceSettings.getModelId() + + esServiceSettingsBuilder.getModelId() + "]" ); } @@ -191,17 +189,18 @@ private void e5Case( inferenceEntityId, taskType, NAME, - (MultilingualE5SmallInternalServiceSettings) e5ServiceSettings.build() + (MultilingualE5SmallInternalServiceSettings) esServiceSettingsBuilder.build() ) ); } - private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic( - Set platformArchitectures, - InternalServiceSettings.Builder e5ServiceSettings - ) { - return e5ServiceSettings.getModelId().equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures)) == false - && e5ServiceSettings.getModelId().equals(MULTILINGUAL_E5_SMALL_MODEL_ID) == false; + static boolean modelVariantValidForArchitecture(Set platformArchitectures, String modelId) { + if (modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID)) { + // platform agnostic model is always compatible + return true; + } + + return modelId.equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures)); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 3bec202ed9e5e..e88572e4a6361 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -64,6 +64,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; +import static 
org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -165,6 +167,36 @@ public void testParseRequestConfig() { service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); } + } + + public void testParseRequestConfig_E5() { + { + var service = createService(mock(Client.class)); + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + MULTILINGUAL_E5_SMALL_MODEL_ID + ) + ) + ); + + var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID); + + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.TEXT_EMBEDDING, + settings, + Set.of(), + getModelVerificationActionListener(e5ServiceSettings) + ); + } // Invalid service settings { @@ -178,9 +210,8 @@ public void testParseRequestConfig() { 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4, - InternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, // we can't directly test the eland case until we mock - // the threadpool within the client + ElasticsearchInternalServiceSettings.MODEL_ID, + MULTILINGUAL_E5_SMALL_MODEL_ID, "not_a_valid_service_setting", randomAlphaOfLength(10) ) @@ -419,19 +450,15 @@ public void testParsePersistedConfig() { 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4, - InternalServiceSettings.MODEL_ID, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + ElasticsearchInternalServiceSettings.MODEL_ID, + MULTILINGUAL_E5_SMALL_MODEL_ID, ServiceFields.DIMENSIONS, 1 ) ) ); - var e5ServiceSettings = new 
MultilingualE5SmallInternalServiceSettings( - 1, - 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID - ); + var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID); MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig( randomInferenceEntityId, @@ -860,6 +887,31 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { assertThat(model, is(expectedModel)); } + public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() { + { + var architectures = Set.of("Aarch64"); + assertFalse( + ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86) + ); + + assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID)); + } + { + var architectures = Set.of("linux-x86_64"); + assertTrue( + ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86) + ); + assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID)); + } + { + var architectures = Set.of("linux-x86_64", "Aarch64"); + assertFalse( + ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86) + ); + assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID)); + } + } + private ElasticsearchInternalService createService(Client client) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); return new ElasticsearchInternalService(context);