Skip to content

Commit

Permalink
feat(aws-stepfunctions-tasks): add environment property for SageMaker…
Browse files Browse the repository at this point in the history
…CreateTrainingJob (#18976)

Add environment property for SageMakerCreateTrainingJob. Fixes issue #18919.

----

*By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
  • Loading branch information
mprencipe authored Feb 17, 2022
1 parent f8d8fe4 commit 60d6e66
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 9 deletions.
2 changes: 2 additions & 0 deletions packages/@aws-cdk/aws-stepfunctions-tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,8 @@ If your training job or model uses resources from AWS Marketplace,
[network isolation is required](https://docs.aws.amazon.com/sagemaker/latest/dg/mkt-algo-model-internet-free.html).
To do so, set the `enableNetworkIsolation` property to `true` for `SageMakerCreateModel` or `SageMakerCreateTrainingJob`.

To set environment variables for the Docker container use the `environment` property.

### Create Training Job

You can call the [`CreateTrainingJob`](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html) API from a `Task` state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Duration, Lazy, Size, Stack } from '@aws-cdk/core';
import { Construct } from 'constructs';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, S3DataType, StoppingCondition, VpcConfig } from './base-types';
import { renderTags } from './private/utils';
import { renderEnvironment, renderTags } from './private/utils';

/**
* Properties for creating an Amazon SageMaker training job
Expand Down Expand Up @@ -85,6 +85,13 @@ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps
* @default - No VPC
*/
readonly vpcConfig?: VpcConfig;

/**
* Environment variables to set in the Docker container.
*
* @default - No environment variables
*/
readonly environment?: { [key: string]: string };
}

/**
Expand Down Expand Up @@ -234,6 +241,7 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
...this.renderHyperparameters(this.props.hyperparameters),
...renderTags(this.props.tags),
...this.renderVpcConfig(this.props.vpcConfig),
...renderEnvironment(this.props.environment),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Size, Stack, Token } from '@aws-cdk/core';
import { Construct } from 'constructs';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
import { BatchStrategy, ModelClientOptions, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types';
import { renderTags } from './private/utils';
import { renderEnvironment, renderTags } from './private/utils';

/**
* Properties for creating an Amazon SageMaker transform job task
Expand Down Expand Up @@ -166,7 +166,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
private renderParameters(): { [key: string]: any } {
return {
...(this.props.batchStrategy ? { BatchStrategy: this.props.batchStrategy } : {}),
...this.renderEnvironment(this.props.environment),
...renderEnvironment(this.props.environment),
...(this.props.maxConcurrentTransforms ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}),
...(this.props.maxPayload ? { MaxPayloadInMB: this.props.maxPayload.toMebibytes() } : {}),
...this.props.modelClientOptions ? this.renderModelClientOptions(this.props.modelClientOptions) : {},
Expand Down Expand Up @@ -234,10 +234,6 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
};
}

private renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } {
return environment ? { Environment: environment } : {};
}

private makePolicyStatements(): iam.PolicyStatement[] {
const stack = Stack.of(this);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

export function renderTags(tags: { [key: string]: any } | undefined): { [key: string]: any } {
return tags ? { Tags: Object.keys(tags).map((key) => ({ Key: key, Value: tags[key] })) } : {};
}
}

export function renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } {
return environment ? { Environment: environment } : {};
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ test('create complex training job', () => {
vpcConfig: {
vpc,
},
environment: {
SOMEVAR: 'myvalue',
},
});
trainTask.addSecurityGroup(securityGroup);

Expand Down Expand Up @@ -285,6 +288,9 @@ test('create complex training job', () => {
{ Ref: 'VPCPrivateSubnet2SubnetCFCDAA7A' },
],
},
Environment: {
SOMEVAR: 'myvalue',
},
},
});
});
Expand Down

0 comments on commit 60d6e66

Please sign in to comment.