From 534a8b3714e3fbd5add1994aa0f887d77d8cd248 Mon Sep 17 00:00:00 2001 From: "Yang, Longlong" Date: Tue, 11 Jun 2024 20:28:23 -0400 Subject: [PATCH] remove unchecked unwrap. fix #99 Signed-off-by: Yang, Longlong --- spdmlib/src/common/key_schedule.rs | 35 ++++++- spdmlib/src/common/mod.rs | 27 +++-- spdmlib/src/error.rs | 2 +- spdmlib/src/requester/finish_req.rs | 25 +++-- spdmlib/src/requester/key_exchange_req.rs | 31 ++++-- spdmlib/src/requester/psk_exchange_req.rs | 28 +++-- spdmlib/src/requester/psk_finish_req.rs | 22 ++-- spdmlib/src/responder/challenge_rsp.rs | 11 +- spdmlib/src/responder/context.rs | 15 ++- spdmlib/src/responder/digest_rsp.rs | 10 +- spdmlib/src/responder/finish_rsp.rs | 91 +++++++++++----- spdmlib/src/responder/key_exchange_rsp.rs | 100 ++++++++++++++---- spdmlib/src/responder/key_update_rsp.rs | 10 +- spdmlib/src/responder/measurement_rsp.rs | 42 ++++++-- spdmlib/src/responder/psk_exchange_rsp.rs | 120 +++++++++++++++++----- spdmlib/src/responder/psk_finish_rsp.rs | 65 +++++++++--- 16 files changed, 480 insertions(+), 154 deletions(-) diff --git a/spdmlib/src/common/key_schedule.rs b/spdmlib/src/common/key_schedule.rs index 98634de..35adfe7 100644 --- a/spdmlib/src/common/key_schedule.rs +++ b/spdmlib/src/common/key_schedule.rs @@ -120,10 +120,15 @@ impl SpdmKeySchedule { return None; } } else { + let empty_pskhint = SpdmPskHintStruct::default(); secret::psk::handshake_secret_hkdf_expand( spdm_version, hash_algo, - psk_hint.unwrap(), + if let Some(hint) = psk_hint { + hint + } else { + &empty_pskhint + }, bin_str1, )? }; @@ -162,10 +167,15 @@ impl SpdmKeySchedule { return None; } } else { + let empty_pskhint = SpdmPskHintStruct::default(); secret::psk::handshake_secret_hkdf_expand( spdm_version, hash_algo, - psk_hint.unwrap(), + if let Some(hint) = psk_hint { + hint + } else { + &empty_pskhint + }, bin_str2, )? }; @@ -295,10 +305,15 @@ impl SpdmKeySchedule { return None; } } else { + let empty_pskhint = SpdmPskHintStruct::default(); secret::psk::master_secret_hkdf_expand( spdm_version, hash_algo, - psk_hint.unwrap(), + if let Some(hint) = psk_hint { + hint + } else { + &empty_pskhint + }, bin_str3, )? }; @@ -337,10 +352,15 @@ impl SpdmKeySchedule { return None; } } else { + let empty_pskhint = SpdmPskHintStruct::default(); secret::psk::master_secret_hkdf_expand( spdm_version, hash_algo, - psk_hint.unwrap(), + if let Some(hint) = psk_hint { + hint + } else { + &empty_pskhint + }, bin_str4, )? }; @@ -378,10 +398,15 @@ impl SpdmKeySchedule { return None; } } else { + let empty_pskhint = SpdmPskHintStruct::default(); secret::psk::master_secret_hkdf_expand( spdm_version, hash_algo, - psk_hint.unwrap(), + if let Some(hint) = psk_hint { + hint + } else { + &empty_pskhint + }, bin_str8, )? }; diff --git a/spdmlib/src/common/mod.rs b/spdmlib/src/common/mod.rs index 41a3c9b..473d55f 100644 --- a/spdmlib/src/common/mod.rs +++ b/spdmlib/src/common/mod.rs @@ -269,8 +269,7 @@ impl SpdmContext { crypto::cert_operation::get_cert_from_cert_chain( &cert_chain.data[..(cert_chain.data_size as usize)], 0, - ) - .unwrap(); + )?; let root_cert = &cert_chain.data[root_cert_begin..root_cert_end]; if let Some(root_hash) = crypto::hash::hash_all(self.negotiate_info.base_hash_sel, root_cert) @@ -509,7 +508,9 @@ impl SpdmContext { } pub fn append_message_k(&mut self, session_id: u32, new_message: &[u8]) -> SpdmResult { - let session = self.get_session_via_id(session_id).unwrap(); + let session = self + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; #[cfg(not(feature = "hashed-transcript-data"))] { @@ -574,7 +575,9 @@ impl SpdmContext { session_id: u32, new_message: &[u8], ) -> SpdmResult { - let session = self.get_session_via_id(session_id).unwrap(); + let session = self + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let _ = session .runtime_info .message_f @@ -590,7 +593,9 @@ impl SpdmContext { session_id: u32, new_message: &[u8], ) -> SpdmResult { - let session = self.get_immutable_session_via_id(session_id).unwrap(); + let session = self + .get_immutable_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; if session.runtime_info.digest_context_th.is_none() { return Err(SPDM_STATUS_INVALID_STATE_LOCAL); } @@ -631,18 +636,24 @@ impl SpdmContext { }; if let Some(mut_cert_digest) = mut_cert_digest { - let session = self.get_session_via_id(session_id).unwrap(); + let session = self + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; crypto::hash::hash_ctx_update( session.runtime_info.digest_context_th.as_mut().unwrap(), &mut_cert_digest.data[..mut_cert_digest.data_size as usize], )?; } - let session = self.get_session_via_id(session_id).unwrap(); + let session = self + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.runtime_info.message_f_initialized = true; } - let session = self.get_session_via_id(session_id).unwrap(); + let session = self + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; crypto::hash::hash_ctx_update( session.runtime_info.digest_context_th.as_mut().unwrap(), new_message, diff --git a/spdmlib/src/error.rs b/spdmlib/src/error.rs index d6ee8c1..3361e30 100644 --- a/spdmlib/src/error.rs +++ b/spdmlib/src/error.rs @@ -311,7 +311,7 @@ impl Codec for SpdmStatus { let mut sc = 0u32; sc += (((self.severity as u8) & 0x0F) as u32) << 28; sc += >::try_into(self.status_code) - .unwrap() //due to the design of encode, panic is allowed + .map_err(|_| codec::EncodeErr)? .get(); sc.encode(bytes)?; Ok(4) diff --git a/spdmlib/src/requester/finish_req.rs b/spdmlib/src/requester/finish_req.rs index 58d0695..1478045 100644 --- a/spdmlib/src/requester/finish_req.rs +++ b/spdmlib/src/requester/finish_req.rs @@ -77,7 +77,7 @@ impl RequesterContext { if res.is_err() { self.common .get_session_via_id(session_id) - .unwrap() + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)? .teardown(); return Err(res.err().unwrap()); } @@ -92,7 +92,7 @@ impl RequesterContext { if res.is_err() { self.common .get_session_via_id(session_id) - .unwrap() + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)? .teardown(); return res; } @@ -107,7 +107,7 @@ impl RequesterContext { if res.is_err() { self.common .get_session_via_id(session_id) - .unwrap() + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)? .teardown(); return Err(res.err().unwrap()); } @@ -186,13 +186,16 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let transcript_hash = self.common .calc_req_transcript_hash(false, req_slot_id, is_mut_auth, session)?; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let hmac = session.generate_hmac_with_request_finished_key(transcript_hash.as_ref())?; @@ -253,7 +256,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let transcript_hash = self.common.calc_req_transcript_hash( false, @@ -291,7 +294,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; // generate the data secret let th2 = self.common.calc_req_transcript_hash( @@ -303,7 +306,10 @@ impl RequesterContext { debug!("!!! th2 : {:02x?}\n", th2.as_ref()); let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; match session.generate_data_secret(spdm_version_sel, &th2) { Ok(_) => {} Err(e) => { @@ -424,8 +430,7 @@ impl RequesterContext { peer_cert, transcript_sign.as_ref(), &signature, - ) - .unwrap(); + )?; Ok(signature) } diff --git a/spdmlib/src/requester/key_exchange_req.rs b/spdmlib/src/requester/key_exchange_req.rs index 325e732..d784d5d 100644 --- a/spdmlib/src/requester/key_exchange_req.rs +++ b/spdmlib/src/requester/key_exchange_req.rs @@ -12,7 +12,6 @@ use crate::error::SPDM_STATUS_CRYPTO_ERROR; use crate::error::SPDM_STATUS_ERROR_PEER; use crate::error::SPDM_STATUS_INVALID_MSG_FIELD; use crate::error::SPDM_STATUS_INVALID_PARAMETER; -#[cfg(feature = "hashed-transcript-data")] use crate::error::SPDM_STATUS_INVALID_STATE_LOCAL; use crate::error::SPDM_STATUS_SESSION_NUMBER_EXCEED; use crate::error::SPDM_STATUS_VERIF_FAIL; @@ -296,7 +295,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; // verify signature if self @@ -321,7 +320,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; // generate the handshake secret (including finished_key) before verify HMAC let th1 = self @@ -329,14 +328,17 @@ impl RequesterContext { .calc_req_transcript_hash(false, slot_id, false, session)?; debug!("!!! th1 : {:02x?}\n", th1.as_ref()); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.generate_handshake_secret(spdm_version_sel, &th1)?; if !in_clear_text { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; // verify HMAC with finished_key let transcript_hash = self @@ -346,7 +348,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; if session .verify_hmac_with_response_finished_key( @@ -356,8 +358,10 @@ impl RequesterContext { .is_err() { error!("verify_hmac_with_response_finished_key fail"); - let session = - self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.teardown(); return Err(SPDM_STATUS_VERIF_FAIL); } else { @@ -373,15 +377,20 @@ impl RequesterContext { ) .is_err() { - let session = - self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.teardown(); return Err(SPDM_STATUS_BUFFER_FULL); } } // append verify_data after TH1 - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.secure_spdm_version_sel = secure_spdm_version_sel; session.heartbeat_period = key_exchange_rsp.heartbeat_period; diff --git a/spdmlib/src/requester/psk_exchange_req.rs b/spdmlib/src/requester/psk_exchange_req.rs index 5490fb5..19c24a5 100644 --- a/spdmlib/src/requester/psk_exchange_req.rs +++ b/spdmlib/src/requester/psk_exchange_req.rs @@ -5,11 +5,11 @@ use config::MAX_SPDM_PSK_CONTEXT_SIZE; use crate::crypto; -use crate::error::SPDM_STATUS_BUFFER_FULL; use crate::error::{ SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD, SPDM_STATUS_INVALID_PARAMETER, SPDM_STATUS_SESSION_NUMBER_EXCEED, SPDM_STATUS_VERIF_FAIL, }; +use crate::error::{SPDM_STATUS_BUFFER_FULL, SPDM_STATUS_INVALID_STATE_LOCAL}; use crate::message::*; use crate::protocol::*; use crate::requester::*; @@ -217,7 +217,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; // generate the handshake secret (including finished_key) before verify HMAC let th1 = self.common.calc_req_transcript_hash( @@ -228,13 +228,16 @@ impl RequesterContext { )?; debug!("!!! th1 : {:02x?}\n", th1.as_ref()); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.generate_handshake_secret(spdm_version_sel, &th1)?; let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; // verify HMAC with finished_key let transcript_hash = self.common.calc_req_transcript_hash( @@ -257,7 +260,10 @@ impl RequesterContext { .is_err() { error!("verify_hmac_with_response_finished_key fail"); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.teardown(); return Err(SPDM_STATUS_VERIF_FAIL); } else { @@ -289,7 +295,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let psk_without_context = self .common .negotiate_info @@ -306,14 +312,20 @@ impl RequesterContext { debug!("!!! th2 : {:02x?}\n", th2.as_ref()); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.generate_data_secret(spdm_version_sel, &th2)?; session.set_session_state( crate::common::session::SpdmSessionState::SpdmSessionEstablished, ); } - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.secure_spdm_version_sel = secure_spdm_version_sel; session.heartbeat_period = psk_exchange_rsp.heartbeat_period; diff --git a/spdmlib/src/requester/psk_finish_req.rs b/spdmlib/src/requester/psk_finish_req.rs index 0779465..a6f345a 100644 --- a/spdmlib/src/requester/psk_finish_req.rs +++ b/spdmlib/src/requester/psk_finish_req.rs @@ -4,7 +4,7 @@ use crate::error::{ SpdmResult, SPDM_STATUS_ERROR_PEER, SPDM_STATUS_INVALID_MSG_FIELD, - SPDM_STATUS_INVALID_PARAMETER, + SPDM_STATUS_INVALID_PARAMETER, SPDM_STATUS_INVALID_STATE_LOCAL, }; use crate::message::*; use crate::protocol::*; @@ -44,7 +44,7 @@ impl RequesterContext { if res.is_err() { self.common .get_session_via_id(session_id) - .unwrap() + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)? .teardown(); return Err(res.err().unwrap()); } @@ -55,7 +55,7 @@ impl RequesterContext { if res.is_err() { self.common .get_session_via_id(session_id) - .unwrap() + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)? .teardown(); return res; } @@ -67,7 +67,7 @@ impl RequesterContext { if res.is_err() { self.common .get_session_via_id(session_id) - .unwrap() + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)? .teardown(); return Err(res.err().unwrap()); } @@ -108,12 +108,15 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let transcript_hash = self.common .calc_req_transcript_hash(true, INVALID_SLOT, false, session)?; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let hmac = session.generate_hmac_with_request_finished_key(transcript_hash.as_ref())?; self.common @@ -153,7 +156,7 @@ impl RequesterContext { let session = self .common .get_immutable_session_via_id(session_id) - .unwrap(); + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; let th2 = self.common.calc_req_transcript_hash( true, @@ -164,7 +167,10 @@ impl RequesterContext { debug!("!!! th2 : {:02x?}\n", th2.as_ref()); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.generate_data_secret(spdm_version_sel, &th2)?; session.set_session_state( crate::common::session::SpdmSessionState::SpdmSessionEstablished, diff --git a/spdmlib/src/responder/challenge_rsp.rs b/spdmlib/src/responder/challenge_rsp.rs index 727c1f7..0cd85a0 100644 --- a/spdmlib/src/responder/challenge_rsp.rs +++ b/spdmlib/src/responder/challenge_rsp.rs @@ -147,8 +147,13 @@ impl ResponderContext { let cert_chain_hash = crypto::hash::hash_all( self.common.negotiate_info.base_hash_sel, my_cert_chain.as_ref(), - ) - .unwrap(); + ); + let cert_chain_hash = if let Some(hash) = cert_chain_hash { + hash + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice())); + }; let mut nonce = [0u8; SPDM_NONCE_SIZE]; let res = crypto::rand::get_random(&mut nonce); @@ -238,7 +243,7 @@ impl ResponderContext { .digest_context_m1m2 .as_ref() .cloned() - .unwrap(), + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?, ) .ok_or(SPDM_STATUS_CRYPTO_ERROR)?; diff --git a/spdmlib/src/responder/context.rs b/spdmlib/src/responder/context.rs index 9b3aaf7..5a9b6bc 100644 --- a/spdmlib/src/responder/context.rs +++ b/spdmlib/src/responder/context.rs @@ -127,7 +127,10 @@ impl ResponderContext { }; let heartbeat_period = { - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.set_session_state( crate::common::session::SpdmSessionState::SpdmSessionEstablished, ); @@ -151,7 +154,10 @@ impl ResponderContext { self.common.runtime_info.set_last_session_id(None); } else if opcode == SpdmRequestResponseCode::SpdmResponseEndSessionAck.get_u8() { - let session = self.common.get_session_via_id(session_id.unwrap()).unwrap(); + let session = self + .common + .get_session_via_id(session_id.unwrap()) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.teardown(); } else if (opcode == SpdmRequestResponseCode::SpdmResponseFinishRsp.get_u8() || opcode == SpdmRequestResponseCode::SpdmResponsePskFinishRsp.get_u8()) @@ -161,7 +167,10 @@ impl ResponderContext { let session_id = session_id.unwrap(); let heartbeat_period = { - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = self + .common + .get_session_via_id(session_id) + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)?; session.set_session_state( crate::common::session::SpdmSessionState::SpdmSessionEstablished, ); diff --git a/spdmlib/src/responder/digest_rsp.rs b/spdmlib/src/responder/digest_rsp.rs index 0de2df3..86599a2 100644 --- a/spdmlib/src/responder/digest_rsp.rs +++ b/spdmlib/src/responder/digest_rsp.rs @@ -5,6 +5,7 @@ use crate::common::SpdmCodec; use crate::common::SpdmConnectionState; use crate::crypto; +use crate::error::SPDM_STATUS_CRYPTO_ERROR; use crate::error::SPDM_STATUS_INVALID_MSG_FIELD; use crate::error::SPDM_STATUS_INVALID_STATE_LOCAL; use crate::error::SPDM_STATUS_INVALID_STATE_PEER; @@ -136,8 +137,13 @@ impl ResponderContext { let cert_chain_hash = crypto::hash::hash_all( self.common.negotiate_info.base_hash_sel, my_cert_chain.as_ref(), - ) - .unwrap(); + ); + + let cert_chain_hash = if let Some(hash) = cert_chain_hash { + hash + } else { + return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice())); + }; // patch the message before send let used = writer.used(); diff --git a/spdmlib/src/responder/finish_rsp.rs b/spdmlib/src/responder/finish_rsp.rs index 7a4993d..b5800a4 100644 --- a/spdmlib/src/responder/finish_rsp.rs +++ b/spdmlib/src/responder/finish_rsp.rs @@ -92,11 +92,17 @@ impl ResponderContext { ); } - let mut_auth_attributes = self - .common - .get_immutable_session_via_id(session_id) - .unwrap() - .get_mut_auth_requested(); + let mut_auth_attributes = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session.get_mut_auth_requested() + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; + let finish_request_attributes = finish_req.finish_request_attributes; if (!mut_auth_attributes.is_empty() @@ -114,10 +120,16 @@ impl ResponderContext { let is_mut_auth = !mut_auth_attributes.is_empty(); if is_mut_auth { - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if self .verify_finish_req_signature(&finish_req.signature, session) @@ -146,7 +158,15 @@ impl ResponderContext { let base_hash_size = self.common.negotiate_info.base_hash_sel.get_size() as usize; { - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if session.get_use_psk() { self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer); @@ -156,10 +176,16 @@ impl ResponderContext { ); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let slot_id = session.get_slot_id(); @@ -257,10 +283,16 @@ impl ResponderContext { ); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let slot_id = session.get_slot_id(); @@ -308,10 +340,15 @@ impl ResponderContext { } // generate the data secret - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let slot_id = session.get_slot_id(); let th2 = self .common @@ -324,7 +361,15 @@ impl ResponderContext { let th2 = th2.unwrap(); debug!("!!! th2 : {:02x?}\n", th2.as_ref()); let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if let Err(e) = session.generate_data_secret(spdm_version_sel, &th2) { self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return (Err(e), Some(writer.used_slice())); diff --git a/spdmlib/src/responder/key_exchange_rsp.rs b/spdmlib/src/responder/key_exchange_rsp.rs index eb69554..41c55fe 100644 --- a/spdmlib/src/responder/key_exchange_rsp.rs +++ b/spdmlib/src/responder/key_exchange_rsp.rs @@ -218,7 +218,12 @@ impl ResponderContext { .set_local_used_cert_chain_slot_id(key_exchange_req.slot_id); let (exchange, key_exchange_context) = - crypto::dhe::generate_key_pair(self.common.negotiate_info.dhe_sel).unwrap(); + if let Some(kp) = crypto::dhe::generate_key_pair(self.common.negotiate_info.dhe_sel) { + kp + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice())); + }; debug!("!!! exchange data : {:02x?}\n", exchange); @@ -394,10 +399,15 @@ impl ResponderContext { ); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let signature = self.generate_key_exchange_rsp_signature(slot_id as u8, session); if signature.is_err() { @@ -418,10 +428,15 @@ impl ResponderContext { ); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; // generate the handshake secret (including finished_key) before generate HMAC let th1 = self @@ -434,17 +449,31 @@ impl ResponderContext { let th1 = th1.unwrap(); debug!("!!! th1 : {:02x?}\n", th1.as_ref()); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if let Err(e) = session.generate_handshake_secret(spdm_version_sel, &th1) { self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return (Err(e), Some(writer.used_slice())); } if !in_clear_text { - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; // generate HMAC with finished_key let transcript_hash = @@ -456,7 +485,15 @@ impl ResponderContext { } let transcript_hash = transcript_hash.unwrap(); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let hmac = session.generate_hmac_with_response_finished_key(transcript_hash.as_ref()); if hmac.is_err() { @@ -472,7 +509,15 @@ impl ResponderContext { .append_message_k(session_id, hmac.as_ref()) .is_err() { - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; session.teardown(); self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return ( @@ -489,14 +534,29 @@ impl ResponderContext { } let heartbeat_period = self.common.config_info.heartbeat_period; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; session.heartbeat_period = heartbeat_period; if return_opaque.data_size != 0 { - session.secure_spdm_version_sel = SecuredMessageVersion::try_from( + session.secure_spdm_version_sel = if let Ok(svs) = SecuredMessageVersion::try_from( return_opaque.data[return_opaque.data_size as usize - 1], - ) - .unwrap(); + ) { + svs + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_MSG_FIELD), + Some(writer.used_slice()), + ); + }; } session.set_session_state(crate::common::session::SpdmSessionState::SpdmSessionHandshaking); diff --git a/spdmlib/src/responder/key_update_rsp.rs b/spdmlib/src/responder/key_update_rsp.rs index 3373052..780d907 100644 --- a/spdmlib/src/responder/key_update_rsp.rs +++ b/spdmlib/src/responder/key_update_rsp.rs @@ -63,7 +63,15 @@ impl ResponderContext { let key_update_req = key_update_req.unwrap(); let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; match key_update_req.key_update_operation { SpdmKeyUpdateOperation::SpdmUpdateSingleKey => { let _ = session.create_data_secret_update(spdm_version_sel, true, false); diff --git a/spdmlib/src/responder/measurement_rsp.rs b/spdmlib/src/responder/measurement_rsp.rs index 0b5de6b..6d662c0 100644 --- a/spdmlib/src/responder/measurement_rsp.rs +++ b/spdmlib/src/responder/measurement_rsp.rs @@ -150,25 +150,38 @@ impl ResponderContext { ); } - let number_of_measurement = secret::measurement::measurement_collection( + let number_of_measurement = if let Some(meas) = secret::measurement::measurement_collection( spdm_version_sel, measurement_specification_sel, measurement_hash_sel, SpdmMeasurementOperation::SpdmMeasurementQueryTotalNumber.get_u8() as usize, - ) - .unwrap() - .number_of_blocks; + ) { + meas.number_of_blocks + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let measurement_record = if get_measurements.measurement_operation == SpdmMeasurementOperation::SpdmMeasurementRequestAll { - secret::measurement::measurement_collection( + if let Some(meas) = secret::measurement::measurement_collection( spdm_version_sel, measurement_specification_sel, measurement_hash_sel, SpdmMeasurementOperation::SpdmMeasurementRequestAll.get_u8() as usize, - ) - .unwrap() + ) { + meas + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + } } else if let SpdmMeasurementOperation::Unknown(index) = get_measurements.measurement_operation { @@ -179,13 +192,20 @@ impl ResponderContext { Some(writer.used_slice()), ); } - secret::measurement::measurement_collection( + if let Some(meas) = secret::measurement::measurement_collection( spdm_version_sel, measurement_specification_sel, measurement_hash_sel, index as usize, - ) - .unwrap() + ) { + meas + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + } } else { SpdmMeasurementRecordStructure::default() }; @@ -299,7 +319,7 @@ impl ResponderContext { Some(session_id) => crypto::hash::hash_ctx_finalize( self.common .get_immutable_session_via_id(session_id) - .unwrap() + .ok_or(SPDM_STATUS_INVALID_STATE_LOCAL)? .runtime_info .digest_context_l1l2 .as_ref() diff --git a/spdmlib/src/responder/psk_exchange_rsp.rs b/spdmlib/src/responder/psk_exchange_rsp.rs index d556121..b0a2b76 100644 --- a/spdmlib/src/responder/psk_exchange_rsp.rs +++ b/spdmlib/src/responder/psk_exchange_rsp.rs @@ -326,10 +326,15 @@ impl ResponderContext { ); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; // create session - generate the handshake secret (including finished_key) let th1 = self @@ -342,16 +347,29 @@ impl ResponderContext { let th1 = th1.unwrap(); debug!("!!! th1 : {:02x?}\n", th1.as_ref()); - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if let Err(e) = session.generate_handshake_secret(spdm_version_sel, &th1) { self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return (Err(e), Some(writer.used_slice())); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; // generate HMAC with finished_key let transcript_hash = self.common @@ -364,7 +382,15 @@ impl ResponderContext { let hmac = session.generate_hmac_with_response_finished_key(transcript_hash.as_ref()); if hmac.is_err() { - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; session.teardown(); self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice())); @@ -377,7 +403,15 @@ impl ResponderContext { .append_message_k(session_id, hmac.as_ref()) .is_err() { - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; session.teardown(); self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return ( @@ -389,13 +423,26 @@ impl ResponderContext { // patch the message before send writer.mut_used_slice()[(used - base_hash_size)..used].copy_from_slice(hmac.as_ref()); let heartbeat_period = self.common.config_info.heartbeat_period; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; session.set_session_state(crate::common::session::SpdmSessionState::SpdmSessionHandshaking); - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if psk_without_context { // generate the data secret directly to skip PSK_FINISH @@ -410,10 +457,22 @@ impl ResponderContext { debug!("!!! th2 : {:02x?}\n", th2.as_ref()); let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; let heartbeat_period = { - let session = self.common.get_session_via_id(session_id).unwrap(); - session + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; + if session .generate_data_secret(spdm_version_sel, &th2) - .unwrap(); + .is_err() + { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return (Err(SPDM_STATUS_CRYPTO_ERROR), Some(writer.used_slice())); + } session.set_session_state( crate::common::session::SpdmSessionState::SpdmSessionEstablished, ); @@ -435,13 +494,28 @@ impl ResponderContext { } } - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; session.heartbeat_period = heartbeat_period; if return_opaque.data_size != 0 { - session.secure_spdm_version_sel = SecuredMessageVersion::try_from( + session.secure_spdm_version_sel = if let Ok(ssvs) = SecuredMessageVersion::try_from( return_opaque.data[return_opaque.data_size as usize - 1], - ) - .unwrap(); + ) { + ssvs + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_MSG_FIELD), + Some(writer.used_slice()), + ); + }; } (Ok(()), Some(writer.used_slice())) diff --git a/spdmlib/src/responder/psk_finish_rsp.rs b/spdmlib/src/responder/psk_finish_rsp.rs index 6d444de..41fd6f0 100644 --- a/spdmlib/src/responder/psk_finish_rsp.rs +++ b/spdmlib/src/responder/psk_finish_rsp.rs @@ -80,10 +80,16 @@ impl ResponderContext { let temp_used = read_used - base_hash_size; { - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if !session.get_use_psk() { self.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, writer); @@ -106,10 +112,16 @@ impl ResponderContext { ); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let transcript_hash = self.common @@ -120,10 +132,16 @@ impl ResponderContext { } let transcript_hash = transcript_hash.as_ref().unwrap(); - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = + if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; let res = session.verify_hmac_with_request_finished_key( transcript_hash.as_ref(), &psk_finish_req.verify_data, @@ -182,10 +200,15 @@ impl ResponderContext { ); } - let session = self - .common - .get_immutable_session_via_id(session_id) - .unwrap(); + let session = if let Some(session) = self.common.get_immutable_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; // generate the data secret let th2 = self .common @@ -198,7 +221,15 @@ impl ResponderContext { let th2 = th2.unwrap(); debug!("!!! th2 : {:02x?}\n", th2.as_ref()); let spdm_version_sel = self.common.negotiate_info.spdm_version_sel; - let session = self.common.get_session_via_id(session_id).unwrap(); + let session = if let Some(session) = self.common.get_session_via_id(session_id) { + session + } else { + self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(writer.used_slice()), + ); + }; if let Err(e) = session.generate_data_secret(spdm_version_sel, &th2) { self.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, writer); return (Err(e), Some(writer.used_slice()));