Skip to content

Commit

Permalink
fix scaling down to 0 ML pods if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmuhs committed Sep 22, 2023
1 parent b2df331 commit f2aefb3
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,48 +48,78 @@ class MlAutoscalingContext {
final PersistentTasksCustomMetadata persistentTasks;

MlAutoscalingContext() {
anomalyDetectionTasks = List.of();
snapshotUpgradeTasks = List.of();
dataframeAnalyticsTasks = List.of();
modelAssignments = Map.of();

waitingAnomalyJobs = List.of();
waitingSnapshotUpgrades = List.of();
waitingAnalyticsJobs = List.of();
waitingAllocatedModels = List.of();
this(List.of(), List.of(), List.of(), Map.of(), List.of(), null);
}

mlNodes = List.of();
persistentTasks = null;
MlAutoscalingContext(
final Collection<PersistentTasksCustomMetadata.PersistentTask<?>> anomalyDetectionTasks,
final Collection<PersistentTasksCustomMetadata.PersistentTask<?>> snapshotUpgradeTasks,
final Collection<PersistentTasksCustomMetadata.PersistentTask<?>> dataframeAnalyticsTasks,
final Map<String, TrainedModelAssignment> modelAssignments,
final List<DiscoveryNode> mlNodes,
final PersistentTasksCustomMetadata persistentTasks
) {
this.anomalyDetectionTasks = anomalyDetectionTasks;
this.snapshotUpgradeTasks = snapshotUpgradeTasks;
this.dataframeAnalyticsTasks = dataframeAnalyticsTasks;
this.modelAssignments = modelAssignments;
this.mlNodes = mlNodes;
this.persistentTasks = persistentTasks;

waitingAnomalyJobs = waitingAnomalyJobs(anomalyDetectionTasks);
waitingSnapshotUpgrades = getWaitingSnapshotUpgrades(snapshotUpgradeTasks);
waitingAnalyticsJobs = getWaitingAnalyticsJobs(dataframeAnalyticsTasks);
waitingAllocatedModels = getWaitingAllocatedModels(modelAssignments);
}

MlAutoscalingContext(ClusterState clusterState) {
PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
anomalyDetectionTasks = anomalyDetectionTasks(tasks);
snapshotUpgradeTasks = snapshotUpgradeTasks(tasks);
dataframeAnalyticsTasks = dataframeAnalyticsTasks(tasks);
persistentTasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);

anomalyDetectionTasks = anomalyDetectionTasks(persistentTasks);
snapshotUpgradeTasks = snapshotUpgradeTasks(persistentTasks);
dataframeAnalyticsTasks = dataframeAnalyticsTasks(persistentTasks);
modelAssignments = TrainedModelAssignmentMetadata.fromState(clusterState).allAssignments();

waitingAnomalyJobs = anomalyDetectionTasks.stream()
waitingAnomalyJobs = waitingAnomalyJobs(anomalyDetectionTasks);
waitingSnapshotUpgrades = getWaitingSnapshotUpgrades(snapshotUpgradeTasks);
waitingAnalyticsJobs = getWaitingAnalyticsJobs(dataframeAnalyticsTasks);
waitingAllocatedModels = getWaitingAllocatedModels(modelAssignments);

mlNodes = getMlNodes(clusterState);
}

private static List<String> getWaitingAllocatedModels(Map<String, TrainedModelAssignment> modelAssignments) {
return modelAssignments.entrySet()
.stream()
// TODO: Eventually care about those that are STARTED but not FULLY_ALLOCATED
.filter(e -> e.getValue().getAssignmentState().equals(AssignmentState.STARTING) && e.getValue().getNodeRoutingTable().isEmpty())
.map(Map.Entry::getKey)
.toList();
}

private static List<String> getWaitingAnalyticsJobs(
Collection<PersistentTasksCustomMetadata.PersistentTask<?>> dataframeAnalyticsTasks
) {
return dataframeAnalyticsTasks.stream()
.filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
.map(t -> ((OpenJobAction.JobParams) t.getParams()).getJobId())
.map(t -> ((StartDataFrameAnalyticsAction.TaskParams) t.getParams()).getId())
.toList();
waitingSnapshotUpgrades = snapshotUpgradeTasks.stream()
}

private static List<String> getWaitingSnapshotUpgrades(
Collection<PersistentTasksCustomMetadata.PersistentTask<?>> snapshotUpgradeTasks
) {
return snapshotUpgradeTasks.stream()
.filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
.map(t -> ((SnapshotUpgradeTaskParams) t.getParams()).getJobId())
.toList();
waitingAnalyticsJobs = dataframeAnalyticsTasks.stream()
}

private static List<String> waitingAnomalyJobs(Collection<PersistentTasksCustomMetadata.PersistentTask<?>> anomalyDetectionTasks) {
return anomalyDetectionTasks.stream()
.filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
.map(t -> ((StartDataFrameAnalyticsAction.TaskParams) t.getParams()).getId())
.toList();
waitingAllocatedModels = modelAssignments.entrySet()
.stream()
// TODO: Eventually care about those that are STARTED but not FULLY_ALLOCATED
.filter(e -> e.getValue().getAssignmentState().equals(AssignmentState.STARTING) && e.getValue().getNodeRoutingTable().isEmpty())
.map(Map.Entry::getKey)
.map(t -> ((OpenJobAction.JobParams) t.getParams()).getJobId())
.toList();

mlNodes = getMlNodes(clusterState);
persistentTasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
}

private static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> anomalyDetectionTasks(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ static void getMemoryAndProcessors(
&& perNodeAvailableModelMemoryInBytes > 0
&& extraModelMemoryInBytes == 0
&& extraProcessors == 0
&& modelMemoryBytesSum < perNodeMemoryInBytes * (osStatsPerNode.size() - 1)
&& modelMemoryBytesSum <= perNodeMemoryInBytes * (osStatsPerNode.size() - 1)
&& (perNodeModelMemoryInBytes.size() < osStatsPerNode.size() // a node has no assigned jobs
|| checkIfOneNodeCouldBeRemoved(
perNodeModelMemoryInBytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,30 @@
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.VersionInformation;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.monitor.os.OsStats;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.autoscaling.MlAutoscalingStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;

import java.net.InetAddress;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -867,6 +879,209 @@ public void testCheckIfOneNodeCouldBeRemovedProcessorAndMemory() {
);
}

public void testGetMemoryAndProcessorsScaleDownToZero() throws InterruptedException {
MlAutoscalingContext mlAutoscalingContext = new MlAutoscalingContext();
MlMemoryTracker mockTracker = mock(MlMemoryTracker.class);

long memory = randomLongBetween(100, 1_000_000);
long perNodeAvailableModelMemoryInBytes = memory / 2;

// scale to zero
this.<MlAutoscalingStats>assertAsync(
listener -> MlAutoscalingResourceTracker.getMemoryAndProcessors(
mlAutoscalingContext,
mockTracker,
Map.of(
"ml-1",
new OsStats(
randomNonNegativeLong(),
new OsStats.Cpu(randomShort(), null),
new OsStats.Mem(memory, memory, randomLongBetween(0, memory)),
new OsStats.Swap(randomNonNegativeLong(), randomNonNegativeLong()),
null
)
),
perNodeAvailableModelMemoryInBytes,
10,
MachineLearning.DEFAULT_MAX_OPEN_JOBS_PER_NODE,
listener
),
stats -> {
assertEquals(memory, stats.perNodeMemoryInBytes());
assertEquals(1, stats.nodes());
assertEquals(0, stats.minNodes());
assertEquals(0, stats.extraSingleNodeProcessors());
assertEquals(memory, stats.removeNodeMemoryInBytes());
assertEquals(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(), stats.perNodeMemoryOverheadInBytes());
}
);

// 3 nodes with no jobs
this.<MlAutoscalingStats>assertAsync(
listener -> MlAutoscalingResourceTracker.getMemoryAndProcessors(
mlAutoscalingContext,
mockTracker,
Map.of(
"ml-1",
new OsStats(
randomNonNegativeLong(),
new OsStats.Cpu(randomShort(), null),
new OsStats.Mem(memory, memory, randomLongBetween(0, memory)),
new OsStats.Swap(randomNonNegativeLong(), randomNonNegativeLong()),
null
),
"ml-2",
new OsStats(
randomNonNegativeLong(),
new OsStats.Cpu(randomShort(), null),
new OsStats.Mem(memory, memory, randomLongBetween(0, memory)),
new OsStats.Swap(randomNonNegativeLong(), randomNonNegativeLong()),
null
),
"ml-3",
new OsStats(
randomNonNegativeLong(),
new OsStats.Cpu(randomShort(), null),
new OsStats.Mem(memory, memory, randomLongBetween(0, memory)),
new OsStats.Swap(randomNonNegativeLong(), randomNonNegativeLong()),
null
)
),
perNodeAvailableModelMemoryInBytes,
10,
MachineLearning.DEFAULT_MAX_OPEN_JOBS_PER_NODE,
listener
),
stats -> {
assertEquals(memory, stats.perNodeMemoryInBytes());
assertEquals(3, stats.nodes());
assertEquals(0, stats.minNodes());
assertEquals(0, stats.extraSingleNodeProcessors());
assertEquals(memory, stats.removeNodeMemoryInBytes());
assertEquals(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(), stats.perNodeMemoryOverheadInBytes());
}
);
}

// scenario: 3 ml nodes, but only 2 have assigned models
public void testGetMemoryAndProcessorsScaleDown() throws InterruptedException {
Map<String, String> nodeAttr = Map.of(
MachineLearning.MACHINE_MEMORY_NODE_ATTR,
"1000000000",
MachineLearning.MAX_JVM_SIZE_NODE_ATTR,
"400000000",
MachineLearning.ML_CONFIG_VERSION_NODE_ATTR,
"7.2.0"
);

MlAutoscalingContext mlAutoscalingContext = new MlAutoscalingContext(
List.of(),
List.of(),
List.of(),
Map.of(
"model-1",
TrainedModelAssignment.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(
"model-1",
"model-1-deployment",
400,
1,
2,
100,
null,
Priority.NORMAL,
0L,
0L
)
).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(),
"model-2",
TrainedModelAssignment.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(
"model-2",
"model-2-deployment",
400,
1,
2,
100,
null,
Priority.NORMAL,
0L,
0L
)
).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build()
),

List.of(
new DiscoveryNode(
"ml-node-name-1",
"ml-node-1",
new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
nodeAttr,
Set.of(DiscoveryNodeRole.ML_ROLE),
VersionInformation.CURRENT
),
new DiscoveryNode(
"ml-node-name-3",
"ml-node-3",
new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
nodeAttr,
Set.of(DiscoveryNodeRole.ML_ROLE),
VersionInformation.CURRENT
)
),
PersistentTasksCustomMetadata.builder().build()
);
MlMemoryTracker mockTracker = mock(MlMemoryTracker.class);

long memory = 1000000000;
long perNodeAvailableModelMemoryInBytes = 600000000;

this.<MlAutoscalingStats>assertAsync(
listener -> MlAutoscalingResourceTracker.getMemoryAndProcessors(
mlAutoscalingContext,
mockTracker,
Map.of(
"ml-node-1",
new OsStats(
randomNonNegativeLong(),
new OsStats.Cpu(randomShort(), null),
new OsStats.Mem(memory, memory, randomLongBetween(0, memory)),
new OsStats.Swap(randomNonNegativeLong(), randomNonNegativeLong()),
null
),
"ml-node-2",
new OsStats(
randomNonNegativeLong(),
new OsStats.Cpu(randomShort(), null),
new OsStats.Mem(memory, memory, randomLongBetween(0, memory)),
new OsStats.Swap(randomNonNegativeLong(), randomNonNegativeLong()),
null
),
"ml-node-3",
new OsStats(
randomNonNegativeLong(),
new OsStats.Cpu(randomShort(), null),
new OsStats.Mem(memory, memory, randomLongBetween(0, memory)),
new OsStats.Swap(randomNonNegativeLong(), randomNonNegativeLong()),
null
)
),
perNodeAvailableModelMemoryInBytes,
10,
MachineLearning.DEFAULT_MAX_OPEN_JOBS_PER_NODE,
listener
),
stats -> {
assertEquals(memory, stats.perNodeMemoryInBytes());
assertEquals(3, stats.nodes());
assertEquals(1, stats.minNodes());
assertEquals(0, stats.extraSingleNodeProcessors());
assertEquals(memory, stats.removeNodeMemoryInBytes());
assertEquals(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(), stats.perNodeMemoryOverheadInBytes());
}
);
}

private <T> void assertAsync(Consumer<ActionListener<T>> function, Consumer<T> furtherTests) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
AtomicBoolean listenerCalled = new AtomicBoolean(false);
Expand Down

0 comments on commit f2aefb3

Please sign in to comment.