Skip to content

Commit

Permalink
[Backport 2.x] Add model version to model metadata and change model m…
Browse files Browse the repository at this point in the history
…etadata reads to be from cluster metadata (#2063)

* Add model version to model metadata and change model metadata reads to be from cluster metadata (#2005)

* Add model version to model metadata

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add model version to model metadata and change model metadata reads to be from cluster metadata

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add changelog entry

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Set version from config context

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix spotless

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Update model index mappings

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Change field mapper to read model version

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix tests

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* remove println

Signed-off-by: John Mazanec <jmazane@amazon.com>

---------

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
Signed-off-by: John Mazanec <jmazane@amazon.com>
Co-authored-by: John Mazanec <jmazane@amazon.com>
(cherry picked from commit 6814c8f)

* Fix tests

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

---------

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
  • Loading branch information
ryanbogan committed Sep 9, 2024
1 parent ed5d7d1 commit d69a038
Show file tree
Hide file tree
Showing 27 changed files with 490 additions and 321 deletions.
1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Compatible with OpenSearch 2.17.0
* Add spaceType as a top level optional parameter while creating vector field. [#2044](https://github.com/opensearch-project/k-NN/pull/2044)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
* Add model version to model metadata and change model metadata reads to be from cluster metadata [#2005](https://github.com/opensearch-project/k-NN/pull/2005)
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public class KNNConstants {
public static final String TOP_LEVEL_SPACE_TYPE_FEATURE = "top_level_space_type_feature";

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String MODEL_VERSION = "model_version";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate";

// Lucene specific constants
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata
return KNNMethodConfigContext.builder()
.vectorDataType(modelMetadata.getVectorDataType())
.dimension(modelMetadata.getDimension())
.versionCreated(Version.V_2_14_0)
.versionCreated(modelMetadata.getModelVersion())
.mode(modelMetadata.getMode())
.compressionLevel(modelMetadata.getCompressionLevel())
.build();
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class IndexUtil {
private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_MODE_AND_COMPRESSION_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION = Version.V_2_17_0;
// public so neural search can access it
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();
public static final Set<VectorDataType> VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE);
Expand Down Expand Up @@ -394,6 +395,7 @@ private static Map<String, Version> initializeMinimalRequiredVersionMap() {
put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE);
put(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE, MINIMAL_MODE_AND_COMPRESSION_FEATURE);
put(KNNConstants.TOP_LEVEL_SPACE_TYPE_FEATURE, MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE);
put(KNNConstants.MODEL_VERSION, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION);
}
};

Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
if (CompressionLevel.isConfigured(modelMetadata.getCompressionLevel())) {
put(KNNConstants.COMPRESSION_LEVEL_PARAMETER, modelMetadata.getCompressionLevel().getName());
}
put(KNNConstants.MODEL_VERSION, modelMetadata.getModelVersion());
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (!methodComponentContext.getName().isEmpty()) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
Expand Down
56 changes: 46 additions & 10 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.Version;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -59,14 +60,14 @@ public class ModelMetadata implements Writeable, ToXContentObject {
private String error;
@Getter
private final CompressionLevel compressionLevel;
private final Version version;

/**
* Constructor
*
* @param in Stream input
*/
public ModelMetadata(StreamInput in) throws IOException {
String tempTrainingNodeAssignment;
this.knnEngine = KNNEngine.getEngine(in.readString());
this.spaceType = SpaceType.getSpace(in.readString());
this.dimension = in.readInt();
Expand Down Expand Up @@ -96,7 +97,6 @@ public ModelMetadata(StreamInput in) throws IOException {
} else {
this.vectorDataType = VectorDataType.DEFAULT;
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) {
this.mode = Mode.fromName(in.readOptionalString());
this.compressionLevel = CompressionLevel.fromName(in.readOptionalString());
Expand All @@ -105,6 +105,11 @@ public ModelMetadata(StreamInput in) throws IOException {
this.compressionLevel = CompressionLevel.NOT_CONFIGURED;
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VERSION)) {
this.version = Version.fromString(in.readString());
} else {
this.version = Version.V_EMPTY;
}
}

/**
Expand Down Expand Up @@ -133,7 +138,8 @@ public ModelMetadata(
MethodComponentContext methodComponentContext,
VectorDataType vectorDataType,
Mode mode,
CompressionLevel compressionLevel
CompressionLevel compressionLevel,
Version version
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
Expand All @@ -159,6 +165,7 @@ public ModelMetadata(
this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null");
this.mode = Objects.requireNonNull(mode, "Mode must not be null");
this.compressionLevel = Objects.requireNonNull(compressionLevel, "Compression level must not be null");
this.version = Objects.requireNonNull(version, "model version must not be null");
}

/**
Expand Down Expand Up @@ -246,6 +253,14 @@ public VectorDataType getVectorDataType() {
return vectorDataType;
}

/**
* Getter for the model version
* @return version
*/
public Version getModelVersion() {
return version;
}

/**
* setter for model's state
*
Expand Down Expand Up @@ -279,7 +294,8 @@ public String toString() {
methodComponentContext.toClusterStateString(),
vectorDataType.getValue(),
mode.getName(),
compressionLevel.getName()
compressionLevel.getName(),
version.toString()
);
}

Expand Down Expand Up @@ -317,6 +333,7 @@ public int hashCode() {
.append(getVectorDataType())
.append(getMode())
.append(getCompressionLevel())
.append(getModelVersion())
.toHashCode();
}

Expand All @@ -329,15 +346,15 @@ public int hashCode() {
public static ModelMetadata fromString(String modelMetadataString) {
String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1);
int length = modelMetadataArray.length;

if (length < 7 || length > 12) {
if (length < 7 || length > 13) {
throw new IllegalArgumentException(
"Illegal format for model metadata. Must be of the form "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>\". or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>\"."
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>,<Version>\"."
);
}

Expand All @@ -357,6 +374,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
CompressionLevel compressionLevel = length > 11
? CompressionLevel.fromName(modelMetadataArray[11])
: CompressionLevel.NOT_CONFIGURED;
Version version = length > 12 ? Version.fromString(modelMetadataArray[12]) : Version.V_EMPTY;

log.debug(getLogMessage(length));

Expand All @@ -372,7 +390,8 @@ public static ModelMetadata fromString(String modelMetadataString) {
methodComponentContext,
vectorDataType,
mode,
compressionLevel
compressionLevel,
version
);
}

Expand All @@ -386,9 +405,10 @@ private static String getLogMessage(int length) {
return "Model metadata contains training node assignment and method context.";
case 10:
return "Model metadata contains training node assignment, method context and vector data type.";
case 11:
case 12:
return "Model metadata contains mode and compression level";
case 13:
return "Model metadata contains training node assignment, method context, vector data type, and version";
default:
throw new IllegalArgumentException("Unexpected metadata array length: " + length);
}
Expand Down Expand Up @@ -423,6 +443,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD);
Object mode = modelSourceMap.get(KNNConstants.MODE_PARAMETER);
Object compressionLevel = modelSourceMap.get(KNNConstants.COMPRESSION_LEVEL_PARAMETER);
Object version = modelSourceMap.get(KNNConstants.MODEL_VERSION);

if (trainingNodeAssignment == null) {
trainingNodeAssignment = "";
Expand All @@ -447,6 +468,10 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
vectorDataType = VectorDataType.DEFAULT.getValue();
}

if (version == null) {
version = Version.V_EMPTY;
}

ModelMetadata modelMetadata = new ModelMetadata(
KNNEngine.getEngine(objectToString(engine)),
SpaceType.getSpace(objectToString(space)),
Expand All @@ -459,7 +484,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
(MethodComponentContext) methodComponentContext,
VectorDataType.get(objectToString(vectorDataType)),
Mode.fromName(objectToString(mode)),
CompressionLevel.fromName(objectToString(compressionLevel))
CompressionLevel.fromName(objectToString(compressionLevel)),
Version.fromString(version.toString())
);
return modelMetadata;
}
Expand All @@ -486,6 +512,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(mode.getName());
out.writeOptionalString(compressionLevel.getName());
}
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VERSION)) {
out.writeString(version.toString());
}
}

@Override
Expand Down Expand Up @@ -517,6 +546,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(KNNConstants.COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName());
}
}
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VERSION)) {
String versionString = "unknown";
if (version != Version.V_EMPTY) {
versionString = version.toString();
}
builder.field(KNNConstants.MODEL_VERSION, versionString);
}
return builder;
}
}
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public static ModelMetadata getModelMetadata(final String modelId) {
if (StringUtils.isEmpty(modelId)) {
return null;
}
final Model model = ModelCache.getInstance().get(modelId);
final ModelMetadata modelMetadata = model.getModelMetadata();
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
final ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (isModelCreated(modelMetadata) == false) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
}
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ public TrainingJob(
knnMethodContext.getMethodComponentContext(),
knnMethodConfigContext.getVectorDataType(),
mode,
compressionLevel
compressionLevel,
knnMethodConfigContext.getVersionCreated()
),
null,
this.modelId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
Expand Down Expand Up @@ -166,11 +165,11 @@ private void train(TrainingJob trainingJob) {
private void serializeModel(TrainingJob trainingJob, ActionListener<IndexResponse> listener, boolean update) throws IOException,
ExecutionException, InterruptedException {
if (update) {
Model model = modelDao.get(trainingJob.getModelId());
if (model.getModelMetadata().getState().equals(ModelState.TRAINING)) {
ModelMetadata modelMetadata = modelDao.getMetadata(trainingJob.getModelId());
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
modelDao.update(trainingJob.getModel(), listener);
} else {
logger.info("Model state is {}. Skipping serialization of trained data", model.getModelMetadata().getState());
logger.info("Model state is {}. Skipping serialization of trained data", modelMetadata.getState());
}
} else {
modelDao.put(trainingJob.getModel(), listener);
Expand Down
3 changes: 3 additions & 0 deletions src/main/resources/mappings/model-index.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
},
"compression_level": {
"type": "keyword"
},
"model_version": {
"type": "keyword"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import org.opensearch.Version;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder;
import org.opensearch.common.settings.Settings;
Expand Down Expand Up @@ -69,7 +70,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException
MethodComponentContext.EMPTY,
VectorDataType.FLOAT,
Mode.NOT_CONFIGURED,
CompressionLevel.NOT_CONFIGURED
CompressionLevel.NOT_CONFIGURED,
Version.V_EMPTY
);

Model model = new Model(modelMetadata, modelBlob, modelId);
Expand Down
Loading

0 comments on commit d69a038

Please sign in to comment.