Skip to content

Commit

Permalink
Use AlignmentBuffer<> for MDx_HashFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
reneme committed Sep 12, 2023
1 parent 78c09ba commit 53a1765
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 34 deletions.
55 changes: 23 additions & 32 deletions src/lib/hash/mdx_hash/mdx_hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ MDx_HashFunction::MDx_HashFunction(size_t block_len, bool byte_big_endian, bool
m_block_bits(ceil_log2(block_len)),
m_count_big_endian(byte_big_endian),
m_count(0),
m_buffer(block_len),
m_position(0) {
m_buffer(block_len) {
if(!is_power_of_2(block_len)) {
throw Invalid_Argument("MDx_HashFunction block length must be a power of 2");
}
Expand All @@ -40,8 +39,8 @@ MDx_HashFunction::MDx_HashFunction(size_t block_len, bool byte_big_endian, bool
* Clear memory of sensitive data
*/
void MDx_HashFunction::clear() {
zeroise(m_buffer);
m_count = m_position = 0;
m_buffer.clear();
m_count = 0;
}

/*
Expand All @@ -50,57 +49,49 @@ void MDx_HashFunction::clear() {
void MDx_HashFunction::add_data(std::span<const uint8_t> input) {
const size_t block_len = static_cast<size_t>(1) << m_block_bits;

m_count += input.size();

if(m_position) {
buffer_insert(m_buffer, m_position, input.data(), input.size());
BufferSlicer in(input);

if(m_position + input.size() >= block_len) {
compress_n(m_buffer.data(), 1);
input = input.last(input.size() - block_len + m_position);
m_position = 0;
while(!in.empty()) {
if(m_buffer.store_unaligned_data(in); m_buffer.ready_to_consume()) {
compress_n(m_buffer.take().data(), 1);
}
}

// Just in case the compiler can't figure out block_len is a power of 2
const size_t full_blocks = input.size() >> m_block_bits;

BufferSlicer in(input);
if(full_blocks > 0) {
compress_n(in.take(full_blocks * block_len).data(), full_blocks);
if(m_buffer.in_alignment()) {
// Just in case the compiler can't figure out block_len is a power of 2
const size_t full_blocks = in.remaining() >> m_block_bits;
if(full_blocks > 0) {
compress_n(in.take(full_blocks * block_len).data(), full_blocks);
}
}
}

const auto remaining = in.take(in.remaining());
buffer_insert(m_buffer, m_position, remaining.data(), remaining.size());
m_position += remaining.size();
m_count += input.size();
}

/*
* Finalize a hash
*/
void MDx_HashFunction::final_result(std::span<uint8_t> output) {
const size_t block_len = static_cast<size_t>(1) << m_block_bits;

clear_mem(&m_buffer[m_position], block_len - m_position);
m_buffer[m_position] = m_pad_char;
m_buffer.append({&m_pad_char, 1});

if(m_position >= block_len - m_counter_size) {
compress_n(m_buffer.data(), 1);
zeroise(m_buffer);
if(m_buffer.elements_until_alignment() < m_counter_size) {
m_buffer.fill_up_with_zeros();
compress_n(m_buffer.take().data(), 1);
}

BOTAN_ASSERT_NOMSG(m_counter_size <= output_length());
BOTAN_ASSERT_NOMSG(m_counter_size >= 8);

const uint64_t bit_count = m_count * 8;

m_buffer.fill_up_with_zeros();
if(m_count_big_endian) {
store_be(bit_count, &m_buffer[block_len - 8]);
store_be(bit_count, m_buffer.directly_modify_last(8).data());
} else {
store_le(bit_count, &m_buffer[block_len - 8]);
store_le(bit_count, m_buffer.directly_modify_last(8).data());
}

compress_n(m_buffer.data(), 1);
compress_n(m_buffer.take().data(), 1);
copy_out(output.data());
clear();
}
Expand Down
5 changes: 3 additions & 2 deletions src/lib/hash/mdx_hash/mdx_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <botan/hash.h>

#include <botan/internal/alignment_buffer.h>

namespace Botan {

/**
Expand Down Expand Up @@ -54,8 +56,7 @@ class MDx_HashFunction : public HashFunction {
const bool m_count_big_endian;

uint64_t m_count;
secure_vector<uint8_t> m_buffer;
size_t m_position;
AlignmentBuffer<secure_vector<uint8_t>> m_buffer;
};

} // namespace Botan
Expand Down

0 comments on commit 53a1765

Please sign in to comment.