diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md index 17948bc1112f4..ad4c2a365c38b 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md @@ -1167,6 +1167,8 @@ If your training job or model uses resources from AWS Marketplace, [network isolation is required](https://docs.aws.amazon.com/sagemaker/latest/dg/mkt-algo-model-internet-free.html). To do so, set the `enableNetworkIsolation` property to `true` for `SageMakerCreateModel` or `SageMakerCreateTrainingJob`. +To set environment variables for the Docker container use the `environment` property. + ### Create Training Job You can call the [`CreateTrainingJob`](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html) API from a `Task` state. diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index 64680dc357747..ba3274579eb37 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -5,7 +5,7 @@ import { Duration, Lazy, Size, Stack } from '@aws-cdk/core'; import { Construct } from 'constructs'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, S3DataType, StoppingCondition, VpcConfig } from './base-types'; -import { renderTags } from './private/utils'; +import { renderEnvironment, renderTags } from './private/utils'; /** * Properties for creating an Amazon SageMaker training job @@ -85,6 +85,13 @@ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps * @default - No VPC */ readonly vpcConfig?: VpcConfig; + + /** + * Environment variables to set in the Docker container. + * + * @default - No environment variables + */ + readonly environment?: { [key: string]: string }; } /** @@ -234,6 +241,7 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam ...this.renderHyperparameters(this.props.hyperparameters), ...renderTags(this.props.tags), ...this.renderVpcConfig(this.props.vpcConfig), + ...renderEnvironment(this.props.environment), }; } diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts index cc108c45e4439..84c08fa22529d 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts @@ -5,7 +5,7 @@ import { Size, Stack, Token } from '@aws-cdk/core'; import { Construct } from 'constructs'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; import { BatchStrategy, ModelClientOptions, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types'; -import { renderTags } from './private/utils'; +import { renderEnvironment, renderTags } from './private/utils'; /** * Properties for creating an Amazon SageMaker transform job task @@ -166,7 +166,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { private renderParameters(): { [key: string]: any } { return { ...(this.props.batchStrategy ? { BatchStrategy: this.props.batchStrategy } : {}), - ...this.renderEnvironment(this.props.environment), + ...renderEnvironment(this.props.environment), ...(this.props.maxConcurrentTransforms ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}), ...(this.props.maxPayload ? { MaxPayloadInMB: this.props.maxPayload.toMebibytes() } : {}), ...this.props.modelClientOptions ? this.renderModelClientOptions(this.props.modelClientOptions) : {}, @@ -234,10 +234,6 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { }; } - private renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } { - return environment ? { Environment: environment } : {}; - } - private makePolicyStatements(): iam.PolicyStatement[] { const stack = Stack.of(this); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/private/utils.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/private/utils.ts index e308fd890864c..bbcaed118a083 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/private/utils.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/private/utils.ts @@ -1,4 +1,7 @@ - export function renderTags(tags: { [key: string]: any } | undefined): { [key: string]: any } { return tags ? { Tags: Object.keys(tags).map((key) => ({ Key: key, Value: tags[key] })) } : {}; -} \ No newline at end of file +} + +export function renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } { + return environment ? { Environment: environment } : {}; +} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts index 7e7a5eff26348..9d5ebb1f9711a 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts @@ -192,6 +192,9 @@ test('create complex training job', () => { vpcConfig: { vpc, }, + environment: { + SOMEVAR: 'myvalue', + }, }); trainTask.addSecurityGroup(securityGroup); @@ -285,6 +288,9 @@ test('create complex training job', () => { { Ref: 'VPCPrivateSubnet2SubnetCFCDAA7A' }, ], }, + Environment: { + SOMEVAR: 'myvalue', + }, }, }); });