diff --git a/aes.c b/aes.c index e2e420f..505bc80 100644 --- a/aes.c +++ b/aes.c @@ -69,6 +69,8 @@ void aes_encrypt(aes_ctx_t *ctx, void *dst, const void *src, size_t l) { /* Prepare context */ mbedtls_cipher_reset(&ctx->cipher_enc); + int offset = 0; + /* XTS doesn't need per-block updating */ if (mbedtls_cipher_get_cipher_mode(&ctx->cipher_enc) == MBEDTLS_MODE_XTS || mbedtls_cipher_get_cipher_mode(&ctx->cipher_enc) == MBEDTLS_MODE_CBC) mbedtls_cipher_update(&ctx->cipher_enc, (const unsigned char * )src, l, (unsigned char *)dst, &out_len); @@ -77,7 +79,7 @@ void aes_encrypt(aes_ctx_t *ctx, void *dst, const void *src, size_t l) { unsigned int blk_size = mbedtls_cipher_get_block_size(&ctx->cipher_enc); /* Do per-block updating */ - for (int offset = 0; (unsigned int)offset < l; offset += blk_size) + for (; (unsigned int)offset < l; offset += blk_size) { int len = ((unsigned int)(l - offset) > blk_size) ? blk_size : (unsigned int) (l - offset); mbedtls_cipher_update(&ctx->cipher_enc, (const unsigned char * )src + offset, len, (unsigned char *)dst + offset, &out_len); @@ -85,7 +87,7 @@ void aes_encrypt(aes_ctx_t *ctx, void *dst, const void *src, size_t l) { } /* Flush all data */ - mbedtls_cipher_finish(&ctx->cipher_enc, NULL, NULL); + mbedtls_cipher_finish(&ctx->cipher_enc, (unsigned char * )dst + offset + out_len, &out_len); } /* Decrypt with context. */ @@ -109,6 +111,8 @@ void aes_decrypt(aes_ctx_t *ctx, void *dst, const void *src, size_t l) /* Prepare context */ mbedtls_cipher_reset(&ctx->cipher_dec); + int offset = 0; + /* XTS doesn't need per-block updating */ if (mbedtls_cipher_get_cipher_mode(&ctx->cipher_dec) == MBEDTLS_MODE_XTS || mbedtls_cipher_get_cipher_mode(&ctx->cipher_enc) == MBEDTLS_MODE_CBC) mbedtls_cipher_update(&ctx->cipher_dec, (const unsigned char * )src, l, (unsigned char *)dst, &out_len); @@ -117,7 +121,7 @@ void aes_decrypt(aes_ctx_t *ctx, void *dst, const void *src, size_t l) unsigned int blk_size = mbedtls_cipher_get_block_size(&ctx->cipher_dec); /* Do per-block updating */ - for (int offset = 0; (unsigned int)offset < l; offset += blk_size) + for (; (unsigned int)offset < l; offset += blk_size) { int len = ((unsigned int)(l - offset) > blk_size) ? blk_size : (unsigned int) (l - offset); mbedtls_cipher_update(&ctx->cipher_dec, (const unsigned char * )src + offset, len, (unsigned char *)dst + offset, &out_len); @@ -125,7 +129,7 @@ void aes_decrypt(aes_ctx_t *ctx, void *dst, const void *src, size_t l) } /* Flush all data */ - mbedtls_cipher_finish(&ctx->cipher_dec, NULL, NULL); + mbedtls_cipher_finish(&ctx->cipher_dec, (unsigned char * )dst + offset + out_len, &out_len); if (src_equals_dst) {