diff --git a/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts b/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts index 00334a0934323..c43a7bec159a8 100644 --- a/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts +++ b/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts @@ -247,6 +247,8 @@ export abstract class State extends cdk.Construct implements IChainable { * @internal */ protected _addRetry(props: RetryProps = {}) { + validateErrors(props.errors); + this.retries.push({ ...props, errors: props.errors ? props.errors : [Errors.ALL], @@ -258,6 +260,8 @@ export abstract class State extends cdk.Construct implements IChainable { * @internal */ protected _addCatch(handler: State, props: CatchProps = {}) { + validateErrors(props.errors); + this.catches.push({ next: handler, props: { @@ -386,8 +390,8 @@ export abstract class State extends cdk.Construct implements IChainable { */ protected renderRetryCatch(): any { return { - Retry: renderList(this.retries, renderRetry), - Catch: renderList(this.catches, renderCatch), + Retry: renderList(this.retries, renderRetry, (a, b) => compareErrors(a.errors, b.errors)), + Catch: renderList(this.catches, renderCatch, (a, b) => compareErrors(a.props.errors, b.props.errors)), }; } @@ -501,12 +505,38 @@ function renderCatch(c: CatchTransition) { }; } +/** + * Compares a list of Errors to move Errors.ALL last in a sort function + */ +function compareErrors(a?: string[], b?: string[]) { + if (a?.includes(Errors.ALL)) { + return 1; + } + if (b?.includes(Errors.ALL)) { + return -1; + } + return 0; +} + +/** + * Validates an errors list + */ +function validateErrors(errors?: string[]) { + if (errors?.includes(Errors.ALL) && errors.length > 1) { + throw new Error(`${Errors.ALL} must appear alone in an error list`); + } +} + /** * Render a list or return undefined for an empty list */ -export function renderList(xs: T[], fn: (x: T) => any): any { +export function renderList(xs: T[], mapFn: (x: T) => any, sortFn?: (a: T, b: T) => number): any { if (xs.length === 0) { return undefined; } - return xs.map(fn); + let list = xs; + if (sortFn) { + list = xs.sort(sortFn); + } + return list.map(mapFn); } /** diff --git a/packages/@aws-cdk/aws-stepfunctions/test/task-base.test.ts b/packages/@aws-cdk/aws-stepfunctions/test/task-base.test.ts index a57e78ae85c6d..eb14a414e20fc 100644 --- a/packages/@aws-cdk/aws-stepfunctions/test/task-base.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions/test/task-base.test.ts @@ -76,6 +76,69 @@ describe('Task base', () => { }); }); + test('States.ALL catch appears at end of list', () => { + // GIVEN + const httpFailure = new sfn.Fail(stack, 'http', { error: 'HTTP' }); + const otherFailure = new sfn.Fail(stack, 'other', { error: 'Other' }); + const allFailure = new sfn.Fail(stack, 'all'); + + // WHEN + task + .addCatch(httpFailure, { errors: ['HTTPError'] }) + .addCatch(allFailure) + .addCatch(otherFailure, { errors: ['OtherError'] }); + + // THEN + expect(render(task)).toEqual({ + StartAt: 'my-task', + States: { + 'all': { + Type: 'Fail', + }, + 'http': { + Error: 'HTTP', + Type: 'Fail', + }, + 'my-task': { + End: true, + Catch: [ + { + ErrorEquals: ['HTTPError'], + Next: 'http', + }, + { + ErrorEquals: ['OtherError'], + Next: 'other', + }, + { + ErrorEquals: ['States.ALL'], + Next: 'all', + }, + ], + Type: 'Task', + Resource: 'my-resource', + Parameters: { MyParameter: 'myParameter' }, + }, + 'other': { + Error: 'Other', + Type: 'Fail', + }, + }, + }); + }); + + test('addCatch throws when errors are combined with States.ALL', () => { + // GIVEN + const failure = new sfn.Fail(stack, 'failed', { + error: 'DidNotWork', + cause: 'We got stuck', + }); + + expect(() => task.addCatch(failure, { + errors: ['States.ALL', 'HTTPError'], + })).toThrow(/must appear alone/); + }); + test('add retry configuration', () => { // WHEN task.addRetry({ errors: ['HTTPError'], maxAttempts: 2 }) @@ -104,6 +167,44 @@ describe('Task base', () => { }); }); + test('States.ALL retry appears at end of list', () => { + // WHEN + task + .addRetry({ errors: ['HTTPError'] }) + .addRetry() + .addRetry({ errors: ['OtherError'] }); + + // THEN + expect(render(task)).toEqual({ + StartAt: 'my-task', + States: { + 'my-task': { + End: true, + Retry: [ + { + ErrorEquals: ['HTTPError'], + }, + { + ErrorEquals: ['OtherError'], + }, + { + ErrorEquals: ['States.ALL'], + }, + ], + Type: 'Task', + Resource: 'my-resource', + Parameters: { MyParameter: 'myParameter' }, + }, + }, + }); + }); + + test('addRetry throws when errors are combined with States.ALL', () => { + expect(() => task.addRetry({ + errors: ['States.ALL', 'HTTPError'], + })).toThrow(/must appear alone/); + }); + test('add a next state to the task in the chain', () => { // WHEN task.next(new sfn.Pass(stack, 'passState'));