Skip to content

Commit

Permalink
Merge pull request #243 from str4d/header-decoding-perf
Browse files Browse the repository at this point in the history
Header decoding performance improvements
  • Loading branch information
str4d authored Aug 8, 2021
2 parents e608cf3 + 0fe89b9 commit 8d02ba2
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 89 deletions.
3 changes: 3 additions & 0 deletions age-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ to 1.0.0 are beta releases.
## [Unreleased]
### Changed
- MSRV is now 1.51.0.
- The `body` property of `age_core::format::AgeStanza` has been replaced by the
`AgeStanza::body` method, to enable enclosing parsers to defer Base64 decoding
until the very end.

## [0.6.0] - 2021-05-02
### Security
Expand Down
181 changes: 99 additions & 82 deletions age-core/src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,30 @@ pub struct AgeStanza<'a> {
/// Zero or more arguments.
pub args: Vec<&'a str>,
/// The body of the stanza, containing a wrapped [`FileKey`].
pub body: Vec<u8>,
///
/// Represented as the set of Base64-encoded lines for efficiency (so the caller can
/// defer the cost of decoding until the structure containing this stanza has been
/// fully-parsed).
body: Vec<&'a [u8]>,
}

impl<'a> AgeStanza<'a> {
    /// Decodes and returns the body of this stanza.
    ///
    /// The body is stored as borrowed Base64 lines. This joins them into one
    /// contiguous buffer and decodes that buffer in a single pass.
    ///
    /// # Panics
    ///
    /// Panics if the stanza holds no chunks, if any chunk other than the last is
    /// not exactly 64 bytes long, or if the joined chunks are not valid unpadded
    /// standard Base64 (for example, when the final chunk's length is 1 modulo 4).
    /// The parsers that construct an `AgeStanza` are responsible for ruling these
    /// cases out before this method is called.
    pub fn body(&self) -> Vec<u8> {
        // Number of Base64 characters in every chunk except the last.
        const FULL_CHUNK_LEN: usize = 64;

        let (partial_chunk, full_chunks) = self
            .body
            .split_last()
            .expect("an AgeStanza always contains at least one body chunk");
        let full_len = full_chunks.len() * FULL_CHUNK_LEN;

        // Pre-sizing the buffer and copying chunk-by-chunk is faster than
        // collecting from a flattened iterator.
        let mut data = vec![0; full_len + partial_chunk.len()];
        for (dst, chunk) in data.chunks_exact_mut(FULL_CHUNK_LEN).zip(full_chunks) {
            // `copy_from_slice` panics on a length mismatch, so this also
            // enforces that every non-final chunk is full.
            dst.copy_from_slice(chunk);
        }
        data[full_len..].copy_from_slice(partial_chunk);

        // Containing only Base64 characters is necessary but not sufficient for
        // decoding to succeed; the length and trailing bits must be valid too.
        base64::decode_config(&data, base64::STANDARD_NO_PAD)
            .expect("stanza body chunks must form canonical unpadded Base64")
    }
}

/// A section of the age header that encapsulates the file key as encrypted to a specific
Expand All @@ -55,10 +78,11 @@ pub struct Stanza {

impl From<AgeStanza<'_>> for Stanza {
fn from(stanza: AgeStanza<'_>) -> Self {
let body = stanza.body();
Stanza {
tag: stanza.tag.to_string(),
args: stanza.args.into_iter().map(|s| s.to_string()).collect(),
body: stanza.body,
body,
}
}
}
Expand Down Expand Up @@ -105,102 +129,66 @@ pub fn grease_the_joint() -> Stanza {
pub mod read {
use nom::{
branch::alt,
bytes::streaming::{tag, take_while, take_while1},
bytes::streaming::{tag, take_while1, take_while_m_n},
character::streaming::newline,
combinator::{map, map_opt, opt, verify},
multi::{many0, separated_list1},
combinator::{map, map_opt, opt},
multi::{many_till, separated_list1},
sequence::{pair, preceded, terminated},
IResult,
};

use super::{AgeStanza, STANZA_TAG};

/// Returns `true` if `c` is an ASCII byte from the standard Base64 alphabet
/// (`A-Z`, `a-z`, `0-9`, `+`, `/`).
///
/// The padding character `=` is not part of the accepted set.
fn is_base64_char(c: u8) -> bool {
    matches!(c, b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'+' | b'/')
}

/// From the age specification:
/// ```text
/// ... an arbitrary string is a sequence of ASCII characters with values 33 to 126.
/// ```
pub fn arbitrary_string(input: &[u8]) -> IResult<&[u8], &str> {
map(take_while1(|c| c >= 33 && c <= 126), |bytes| {
std::str::from_utf8(bytes).expect("ASCII is valid UTF-8")
})(input)
}

/// Returns the slice of input up to (but not including) the first LF
/// character, if that slice is entirely Base64 characters
///
/// # Errors
///
/// - Returns Incomplete(1) if a LF is not found.
fn take_b64_line(input: &[u8]) -> IResult<&[u8], &[u8]> {
verify(take_while(|c| c != b'\n'), |bytes: &[u8]| {
// STANDARD_NO_PAD only differs from STANDARD during serialization; the base64
// crate always allows padding during parsing. We require canonical
// serialization, so we explicitly reject padding characters here.
base64::decode_config(bytes, base64::STANDARD_NO_PAD).is_ok() && !bytes.contains(&b'=')
})(input)
}

/// Returns the slice of input up to (but not including) the first LF
/// character, if that slice is entirely Base64 characters
///
/// # Errors
///
/// - Returns Failure on an empty slice.
/// - Returns Incomplete(1) if a LF is not found.
fn take_b64_line1(input: &[u8]) -> IResult<&[u8], &[u8]> {
verify(take_while1(|c| c != b'\n'), |bytes: &[u8]| {
// STANDARD_NO_PAD only differs from STANDARD during serialization; the base64
// crate always allows padding during parsing. We require canonical
// serialization, so we explicitly reject padding characters here.
base64::decode_config(bytes, base64::STANDARD_NO_PAD).is_ok() && !bytes.contains(&b'=')
map(take_while1(|c| (33..=126).contains(&c)), |bytes| {
// Safety: ASCII bytes are valid UTF-8
unsafe { std::str::from_utf8_unchecked(bytes) }
})(input)
}

fn wrapped_encoded_data(input: &[u8]) -> IResult<&[u8], Vec<u8>> {
map_opt(
pair(
fn wrapped_encoded_data(input: &[u8]) -> IResult<&[u8], Vec<&[u8]>> {
map(
many_till(
// Any body lines before the last MUST be full-length.
many0(map_opt(terminated(take_b64_line, newline), |chunk| {
if chunk.len() != 64 {
None
} else {
Some(chunk)
}
})),
terminated(take_while_m_n(64, 64, is_base64_char), newline),
// Last body line MUST be short (empty if necessary).
map_opt(terminated(take_b64_line, newline), |chunk| {
if chunk.len() < 64 {
Some(chunk)
} else {
None
}
}),
terminated(take_while_m_n(0, 63, is_base64_char), newline),
),
|(full_chunks, partial_chunk)| {
let data: Vec<u8> = full_chunks
.into_iter()
.chain(Some(partial_chunk))
.flatten()
.cloned()
.collect();
base64::decode_config(&data, base64::STANDARD_NO_PAD).ok()
|(full_chunks, partial_chunk): (Vec<&[u8]>, &[u8])| {
let mut chunks = full_chunks;
chunks.push(partial_chunk);
chunks
},
)(input)
}

fn legacy_wrapped_encoded_data(input: &[u8]) -> IResult<&[u8], Vec<u8>> {
map_opt(separated_list1(newline, take_b64_line1), |chunks| {
// Enforce that the only chunk allowed to be shorter than 64 characters
// is the last chunk.
if chunks.iter().rev().skip(1).any(|s| s.len() != 64)
|| chunks.last().map(|s| s.len() > 64) == Some(true)
{
None
} else {
let data: Vec<u8> = chunks.into_iter().flatten().cloned().collect();
base64::decode_config(&data, base64::STANDARD_NO_PAD).ok()
}
})(input)
fn legacy_wrapped_encoded_data(input: &[u8]) -> IResult<&[u8], Vec<&[u8]>> {
map_opt(
separated_list1(newline, take_while1(is_base64_char)),
|chunks: Vec<&[u8]>| {
// Enforce that the only chunk allowed to be shorter than 64 characters
// is the last chunk.
let (partial_chunk, full_chunks) = chunks.split_last().unwrap();
if full_chunks.iter().any(|s| s.len() != 64) || partial_chunk.len() > 64 {
None
} else {
Some(chunks)
}
},
)(input)
}

/// Reads an age stanza.
Expand Down Expand Up @@ -240,7 +228,7 @@ pub mod read {
AgeStanza {
tag,
args,
body: body.unwrap_or_default(),
body: body.unwrap_or_else(|| vec![&[]]),
}
},
)(input)
Expand Down Expand Up @@ -276,8 +264,11 @@ pub mod read {

#[test]
fn base64_padding_rejected() {
assert!(take_b64_line(b"Tm8gcGFkZGluZyE\n").is_ok());
assert!(take_b64_line(b"Tm8gcGFkZGluZyE=\n").is_err());
assert!(wrapped_encoded_data(b"Tm8gcGFkZGluZyE\n").is_ok());
assert!(wrapped_encoded_data(b"Tm8gcGFkZGluZyE=\n").is_err());
// Internal padding is also rejected.
assert!(wrapped_encoded_data(b"SW50ZXJuYWwUGFk\n").is_ok());
assert!(wrapped_encoded_data(b"SW50ZXJuYWw=UGFk\n").is_err());
}
}
}
Expand Down Expand Up @@ -356,7 +347,7 @@ C3ZAeY64NXS4QFrksLm3EGz+uPRyI0eQsWw7LWbbYig
let (_, stanza) = read::age_stanza(test_stanza.as_bytes()).unwrap();
assert_eq!(stanza.tag, test_tag);
assert_eq!(stanza.args, test_args);
assert_eq!(stanza.body, test_body);
assert_eq!(stanza.body(), test_body);

let mut buf = vec![];
cookie_factory::gen_simple(write::age_stanza(test_tag, test_args, &test_body), &mut buf)
Expand All @@ -378,7 +369,7 @@ C3ZAeY64NXS4QFrksLm3EGz+uPRyI0eQsWw7LWbbYig
let (_, stanza) = read::age_stanza(test_stanza.as_bytes()).unwrap();
assert_eq!(stanza.tag, test_tag);
assert_eq!(stanza.args, test_args);
assert_eq!(stanza.body, test_body);
assert_eq!(stanza.body(), test_body);

let mut buf = vec![];
cookie_factory::gen_simple(write::age_stanza(test_tag, test_args, test_body), &mut buf)
Expand All @@ -405,11 +396,37 @@ xD7o4VEOu1t7KZQ1gDgq2FPzBEeSRqbnqvQEXdLRYy143BxR6oFxsUUJCRB0ErXA
let (_, stanza) = read::age_stanza(test_stanza.as_bytes()).unwrap();
assert_eq!(stanza.tag, test_tag);
assert_eq!(stanza.args, test_args);
assert_eq!(stanza.body, test_body);
assert_eq!(stanza.body(), test_body);

let mut buf = vec![];
cookie_factory::gen_simple(write::age_stanza(test_tag, test_args, &test_body), &mut buf)
.unwrap();
assert_eq!(buf, test_stanza.as_bytes());
}

#[test]
fn age_stanza_with_legacy_full_body() {
    // Exactly 64 Base64 characters, i.e. a single full-length body line.
    let encoded_body = "xD7o4VEOu1t7KZQ1gDgq2FPzBEeSRqbnqvQEXdLRYy143BxR6oFxsUUJCRB0ErXA";

    let expected_tag = "full-body";
    let expected_args = &["some", "arguments"];
    let expected_body = base64::decode_config(encoded_body, base64::STANDARD_NO_PAD).unwrap();

    // The full-length body line is followed directly by the next header line,
    // with no empty line in between.
    let stanza_text = format!(
        "-> full-body some arguments\n{}\n--- header end\n",
        encoded_body
    );
    let stanza_bytes = stanza_text.as_bytes();

    // The normal parser rejects this stanza.
    assert!(read::age_stanza(stanza_bytes).is_err());

    // The legacy parser accepts it and yields the expected parts.
    let (_, parsed) = read::legacy_age_stanza(stanza_bytes).unwrap();
    assert_eq!(parsed.tag, expected_tag);
    assert_eq!(parsed.args, expected_args);
    assert_eq!(parsed.body(), expected_body);
}
}
1 change: 0 additions & 1 deletion age-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#![forbid(unsafe_code)]
// Catch documentation errors caused by code changes.
#![deny(broken_intra_doc_links)]

Expand Down
8 changes: 2 additions & 6 deletions age/src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,12 @@ mod read {
use super::*;
use crate::util::read::base64_arg;

fn recipient_stanza(input: &[u8]) -> IResult<&[u8], Stanza> {
map(legacy_age_stanza, Stanza::from)(input)
}

fn header_v1(input: &[u8]) -> IResult<&[u8], HeaderV1> {
preceded(
pair(tag(V1_MAGIC), newline),
map(
pair(
many1(recipient_stanza),
many1(legacy_age_stanza),
preceded(
pair(tag(MAC_TAG), tag(b" ")),
terminated(
Expand All @@ -184,7 +180,7 @@ mod read {
),
),
|(recipients, mac)| HeaderV1 {
recipients,
recipients: recipients.into_iter().map(Stanza::from).collect(),
mac,
encoded_bytes: None,
},
Expand Down

0 comments on commit 8d02ba2

Please sign in to comment.