diff --git a/crates/hotfix/src/lib.rs b/crates/hotfix/src/lib.rs index 573c29f..a25e568 100644 --- a/crates/hotfix/src/lib.rs +++ b/crates/hotfix/src/lib.rs @@ -25,7 +25,6 @@ pub mod application; pub mod config; pub mod initiator; pub mod message; -pub mod message_utils; pub mod session; mod session_schedule; pub mod store; diff --git a/crates/hotfix/src/message.rs b/crates/hotfix/src/message.rs index 77f2e59..9f75b87 100644 --- a/crates/hotfix/src/message.rs +++ b/crates/hotfix/src/message.rs @@ -2,7 +2,9 @@ use hotfix_message::error::EncodingError as EncodeError; pub use hotfix_message::field_types::Timestamp; pub(crate) use hotfix_message::message::{Config, Message}; -use hotfix_message::session_fields::{MSG_SEQ_NUM, SENDER_COMP_ID, SENDING_TIME, TARGET_COMP_ID}; +use hotfix_message::session_fields::{ + MSG_SEQ_NUM, ORIG_SENDING_TIME, POSS_DUP_FLAG, SENDER_COMP_ID, SENDING_TIME, TARGET_COMP_ID, +}; pub use hotfix_message::{Part, RepeatingGroup}; pub mod business_reject; @@ -20,12 +22,75 @@ pub mod verification_error; pub use parser::RawFixMessage; pub use resend_request::ResendRequest; +use heartbeat::Heartbeat; +use logon::Logon; +use logout::Logout; +use reject::Reject; +use sequence_reset::SequenceReset; +use test_request::TestRequest; + +static ADMIN_TYPES: [&str; 7] = [ + Logon::MSG_TYPE, + Heartbeat::MSG_TYPE, + TestRequest::MSG_TYPE, + ResendRequest::MSG_TYPE, + Reject::MSG_TYPE, + SequenceReset::MSG_TYPE, + Logout::MSG_TYPE, +]; + +pub fn is_admin(message_type: &str) -> bool { + ADMIN_TYPES.contains(&message_type) +} + pub trait OutboundMessage: Clone + Send + 'static { fn write(&self, msg: &mut Message); fn message_type(&self) -> &str; } +/// Prepares a FIX message for resend per the FIX spec (PossDupFlag logic). +/// +/// Behaviour: +/// - On first resend (no PossDupFlag Y / no OrigSendingTime): +/// * Move current SendingTime(52) to OrigSendingTime(122) +/// * Set SendingTime(52) to the current sending time (may be equal if clock granularity causes no change) +/// * Set PossDupFlag(43)=Y +/// - On subsequent resends (already marked possible duplicate and has OrigSendingTime): +/// * Refresh SendingTime(52) to current time (value may or may not differ from previous) +pub fn prepare_message_for_resend(msg: &mut Message) -> Result<(), &'static str> { + let header = msg.header_mut(); + + if header.get_raw(SENDING_TIME).is_none() { + return Err("Missing SendingTime (52)"); + } + + let already_poss_dup = header.get::(POSS_DUP_FLAG).unwrap_or(false); + let has_orig_sending_time = header.get_raw(ORIG_SENDING_TIME).is_some(); + + if already_poss_dup && has_orig_sending_time { + // Subsequent resend: refresh SendingTime only + return if header.pop(SENDING_TIME).is_some() { + header.set(SENDING_TIME, Timestamp::utc_now()); + Ok(()) + } else { + Err("Failed to extract previous SendingTime") + }; + } + + // First resend path + if let Some(original_sending_time_field) = header.pop(SENDING_TIME) { + let original_ts = Timestamp::parse(&original_sending_time_field.data) + .ok_or("Invalid original SendingTime format")?; + header.set(ORIG_SENDING_TIME, original_ts); + header.set(SENDING_TIME, Timestamp::utc_now()); + header.set(POSS_DUP_FLAG, true); + Ok(()) + } else { + Err("Failed to extract original SendingTime") + } +} + pub fn generate_message( begin_string: &str, sender_comp_id: &str, diff --git a/crates/hotfix/src/message/business_reject.rs b/crates/hotfix/src/message/business_reject.rs index 302a2e0..9b745da 100644 --- a/crates/hotfix/src/message/business_reject.rs +++ b/crates/hotfix/src/message/business_reject.rs @@ -49,6 +49,8 @@ pub(crate) struct BusinessReject { } impl BusinessReject { + pub(crate) const MSG_TYPE: &str = "j"; + pub(crate) fn new(ref_msg_type: &str, reason: BusinessRejectReason) -> Self { Self { ref_msg_type: ref_msg_type.to_string(), @@ -100,7 +102,7 @@ impl OutboundMessage for BusinessReject { } fn message_type(&self) -> &str { - "j" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/heartbeat.rs b/crates/hotfix/src/message/heartbeat.rs index 56470fe..2157571 100644 --- a/crates/hotfix/src/message/heartbeat.rs +++ b/crates/hotfix/src/message/heartbeat.rs @@ -9,6 +9,8 @@ pub struct Heartbeat { } impl Heartbeat { + pub const MSG_TYPE: &str = "0"; + pub fn for_request(test_req_id: String) -> Self { Self { test_req_id: Some(test_req_id), @@ -24,6 +26,6 @@ impl OutboundMessage for Heartbeat { } fn message_type(&self) -> &str { - "0" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/logon.rs b/crates/hotfix/src/message/logon.rs index 8571618..c3d5fd0 100644 --- a/crates/hotfix/src/message/logon.rs +++ b/crates/hotfix/src/message/logon.rs @@ -19,6 +19,8 @@ pub enum ResetSeqNumConfig { } impl Logon { + pub const MSG_TYPE: &str = "A"; + pub fn new(heartbeat_interval: u64, reset_config: ResetSeqNumConfig) -> Self { let (reset_seq_num_flag, next_expected_msg_seq_num) = match reset_config { ResetSeqNumConfig::Reset => (ResetSeqNumFlag::Yes, None), @@ -45,7 +47,7 @@ impl OutboundMessage for Logon { } fn message_type(&self) -> &str { - "A" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/logout.rs b/crates/hotfix/src/message/logout.rs index 12141af..88b195e 100644 --- a/crates/hotfix/src/message/logout.rs +++ b/crates/hotfix/src/message/logout.rs @@ -9,6 +9,8 @@ pub struct Logout { } impl Logout { + pub const MSG_TYPE: &str = "5"; + pub fn with_reason(reason: String) -> Self { Self { text: Some(reason) } } @@ -22,6 +24,6 @@ impl OutboundMessage for Logout { } fn message_type(&self) -> &str { - "5" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/reject.rs b/crates/hotfix/src/message/reject.rs index d0c1e54..fd751a0 100644 --- a/crates/hotfix/src/message/reject.rs +++ b/crates/hotfix/src/message/reject.rs @@ -16,6 +16,8 @@ pub(crate) struct Reject { } impl Reject { + pub(crate) const MSG_TYPE: &str = "3"; + pub(crate) fn new(ref_seq_num: u64) -> Self { Self { ref_seq_num, @@ -85,7 +87,7 @@ impl OutboundMessage for Reject { } fn message_type(&self) -> &str { - "3" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/resend_request.rs b/crates/hotfix/src/message/resend_request.rs index 1aca0ce..2c904f7 100644 --- a/crates/hotfix/src/message/resend_request.rs +++ b/crates/hotfix/src/message/resend_request.rs @@ -10,6 +10,8 @@ pub struct ResendRequest { } impl ResendRequest { + pub const MSG_TYPE: &str = "2"; + pub fn new(begin: u64, end: u64) -> Self { Self { begin_seq_no: begin, @@ -25,6 +27,6 @@ impl OutboundMessage for ResendRequest { } fn message_type(&self) -> &str { - "2" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/sequence_reset.rs b/crates/hotfix/src/message/sequence_reset.rs index 2167fc0..f0cffc7 100644 --- a/crates/hotfix/src/message/sequence_reset.rs +++ b/crates/hotfix/src/message/sequence_reset.rs @@ -12,6 +12,10 @@ pub struct SequenceReset { pub new_seq_no: u64, } +impl SequenceReset { + pub const MSG_TYPE: &str = "4"; +} + impl OutboundMessage for SequenceReset { fn write(&self, msg: &mut Message) { msg.set(GAP_FILL_FLAG, self.gap_fill); @@ -25,6 +29,6 @@ impl OutboundMessage for SequenceReset { } fn message_type(&self) -> &str { - "4" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/test_request.rs b/crates/hotfix/src/message/test_request.rs index 0dac034..4e0762a 100644 --- a/crates/hotfix/src/message/test_request.rs +++ b/crates/hotfix/src/message/test_request.rs @@ -9,6 +9,8 @@ pub struct TestRequest { } impl TestRequest { + pub const MSG_TYPE: &str = "1"; + pub fn new(test_req_id: String) -> Self { Self { test_req_id } } @@ -20,6 +22,6 @@ impl OutboundMessage for TestRequest { } fn message_type(&self) -> &str { - "1" + Self::MSG_TYPE } } diff --git a/crates/hotfix/src/message/verification.rs b/crates/hotfix/src/message/verification.rs index 9e884fd..123f378 100644 --- a/crates/hotfix/src/message/verification.rs +++ b/crates/hotfix/src/message/verification.rs @@ -17,6 +17,8 @@ pub(crate) fn verify_message( message: &Message, config: &SessionConfig, expected_seq_number: Option, + check_too_high: bool, + check_too_low: bool, ) -> Result<(), MessageVerificationError> { check_begin_string(message, config.begin_string.as_str())?; let actual_seq_number: u64 = message.header().get(MSG_SEQ_NUM).unwrap_or_default(); @@ -37,7 +39,13 @@ pub(crate) fn verify_message( } if let Some(expected_seq_number) = expected_seq_number { - check_sequence_number(actual_seq_number, expected_seq_number, possible_duplicate)?; + check_sequence_number( + actual_seq_number, + expected_seq_number, + possible_duplicate, + check_too_high, + check_too_low, + )?; } Ok(()) @@ -141,15 +149,17 @@ fn check_sequence_number( actual_seq_number: u64, expected_seq_number: u64, possible_duplicate: bool, + check_too_high: bool, + check_too_low: bool, ) -> Result<(), MessageVerificationError> { match actual_seq_number.cmp(&expected_seq_number) { - Ordering::Greater => { + Ordering::Greater if check_too_high => { return Err(MessageVerificationError::SeqNumberTooHigh { expected: expected_seq_number, actual: actual_seq_number, }); } - Ordering::Less => { + Ordering::Less if check_too_low => { return Err(MessageVerificationError::SeqNumberTooLow { expected: expected_seq_number, actual: actual_seq_number, @@ -223,7 +233,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 42); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(result.is_ok()); } @@ -233,7 +243,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.2", "TARGET", "SENDER", 42); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -249,7 +259,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.4", "WRONG_SENDER", "SENDER", 42); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -275,7 +285,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.4", "TARGET", "WRONG_TARGET", 42); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -301,7 +311,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 40); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -327,7 +337,7 @@ mod tests { msg.header_mut().set(fix44::POSS_DUP_FLAG, true); msg.header_mut().set(fix44::ORIG_SENDING_TIME, sending_time); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -350,7 +360,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 50); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -369,7 +379,7 @@ mod tests { msg.header_mut().set(fix44::POSS_DUP_FLAG, true); // Don't set OrigSendingTime - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -394,7 +404,7 @@ mod tests { msg.header_mut().pop(fix44::SENDING_TIME); msg.header_mut().set(fix44::SENDING_TIME, sending_time); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(result.is_ok()); } @@ -413,7 +423,7 @@ mod tests { msg.header_mut().pop(fix44::SENDING_TIME); msg.header_mut().set(fix44::SENDING_TIME, sending_time); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -443,7 +453,7 @@ mod tests { msg.header_mut().pop(fix44::SENDING_TIME); msg.header_mut().set(fix44::SENDING_TIME, timestamp); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); // equal timestamps should be valid (orig <= sending) assert!(result.is_ok()); @@ -461,7 +471,7 @@ mod tests { // remove begin string, which is automatically added by `Message::new` msg.header_mut().pop(fix44::BEGIN_STRING); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -477,7 +487,7 @@ mod tests { msg.set(fix44::MSG_SEQ_NUM, 42u64); msg.set(fix44::SENDING_TIME, Timestamp::utc_now()); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -496,7 +506,7 @@ mod tests { msg.set(fix44::MSG_SEQ_NUM, 42u64); msg.set(fix44::SENDING_TIME, Timestamp::utc_now()); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -515,7 +525,7 @@ mod tests { msg.set(fix44::TARGET_COMP_ID, "SENDER"); msg.set(fix44::SENDING_TIME, Timestamp::utc_now()); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); // missing seq num defaults to 0, which will be too low assert!(matches!( @@ -529,7 +539,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 0); - let result = verify_message(&msg, &config, Some(1)); + let result = verify_message(&msg, &config, Some(1), true, true); assert!(matches!( result, @@ -542,7 +552,7 @@ mod tests { let config = build_test_config(); let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 1); - let result = verify_message(&msg, &config, Some(1)); + let result = verify_message(&msg, &config, Some(1), true, true); assert!(result.is_ok()); } @@ -553,7 +563,7 @@ mod tests { // wrong begin string AND wrong seq num - begin string error should come first let msg = build_test_message("FIX.4.2", "TARGET", "SENDER", 100); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -567,7 +577,7 @@ mod tests { // wrong sender and wrong target - sender error should come first let msg = build_test_message("FIX.4.4", "WRONG_SENDER", "WRONG_TARGET", 42); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -586,7 +596,7 @@ mod tests { msg.set(fix44::TARGET_COMP_ID, "SENDER"); msg.set(fix44::MSG_SEQ_NUM, 42u64); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -607,13 +617,14 @@ mod tests { msg.set(fix44::TARGET_COMP_ID, "SENDER"); msg.set(fix44::MSG_SEQ_NUM, 42u64); - // set sending time to 121 seconds in the past (beyond the threshold) + // set sending time to 122 seconds in the past (beyond the 120 second threshold, + // with margin to account for millisecond truncation in Timestamp) let now = chrono::Utc::now(); - let past_time = now - Duration::seconds(121); + let past_time = now - Duration::seconds(122); let past_timestamp: Timestamp = past_time.naive_utc().into(); msg.set(fix44::SENDING_TIME, past_timestamp); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -634,13 +645,14 @@ mod tests { msg.set(fix44::TARGET_COMP_ID, "SENDER"); msg.set(fix44::MSG_SEQ_NUM, 42u64); - // set sending time to 121 seconds in the future (beyond the threshold) + // set sending time to 122 seconds in the future (beyond the 120 second threshold, + // with margin to account for millisecond truncation in Timestamp) let now = chrono::Utc::now(); - let future_time = now + Duration::seconds(121); + let future_time = now + Duration::seconds(122); let future_timestamp: Timestamp = future_time.naive_utc().into(); msg.set(fix44::SENDING_TIME, future_timestamp); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(matches!( result, @@ -667,7 +679,7 @@ mod tests { let boundary_timestamp: Timestamp = boundary_time.naive_utc().into(); msg.set(fix44::SENDING_TIME, boundary_timestamp); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); assert!(result.is_ok()); } @@ -688,8 +700,93 @@ mod tests { let valid_timestamp: Timestamp = valid_time.naive_utc().into(); msg.set(fix44::SENDING_TIME, valid_timestamp); - let result = verify_message(&msg, &config, Some(42)); + let result = verify_message(&msg, &config, Some(42), true, true); + + assert!(result.is_ok()); + } + + #[test] + fn test_seq_number_too_high_skipped_when_check_too_high_false() { + let config = build_test_config(); + let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 50); + + // With check_too_high=false, seq 50 > expected 42 should be OK + let result = verify_message(&msg, &config, Some(42), false, true); + + assert!(result.is_ok()); + } + + #[test] + fn test_seq_number_too_low_skipped_when_check_too_low_false() { + let config = build_test_config(); + let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 40); + + // With check_too_low=false, seq 40 < expected 42 should be OK + let result = verify_message(&msg, &config, Some(42), true, false); assert!(result.is_ok()); } + + #[test] + fn test_both_checks_disabled() { + let config = build_test_config(); + // Seq number too high + let msg_high = build_test_message("FIX.4.4", "TARGET", "SENDER", 50); + assert!(verify_message(&msg_high, &config, Some(42), false, false).is_ok()); + + // Seq number too low + let msg_low = build_test_message("FIX.4.4", "TARGET", "SENDER", 40); + assert!(verify_message(&msg_low, &config, Some(42), false, false).is_ok()); + + // Seq number matches + let msg_match = build_test_message("FIX.4.4", "TARGET", "SENDER", 42); + assert!(verify_message(&msg_match, &config, Some(42), false, false).is_ok()); + } + + #[test] + fn test_check_too_high_true_still_catches_too_high() { + let config = build_test_config(); + let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 50); + + let result = verify_message(&msg, &config, Some(42), true, false); + + assert!(matches!( + result, + Err(MessageVerificationError::SeqNumberTooHigh { .. }) + )); + } + + #[test] + fn test_check_too_low_true_still_catches_too_low() { + let config = build_test_config(); + let msg = build_test_message("FIX.4.4", "TARGET", "SENDER", 40); + + let result = verify_message(&msg, &config, Some(42), false, true); + + assert!(matches!( + result, + Err(MessageVerificationError::SeqNumberTooLow { .. }) + )); + } + + #[test] + fn test_non_seq_checks_still_applied_when_seq_checks_disabled() { + let config = build_test_config(); + + // Wrong sender comp ID should still be caught even with both seq checks disabled + let msg = build_test_message("FIX.4.4", "WRONG_SENDER", "SENDER", 42); + let result = verify_message(&msg, &config, Some(42), false, false); + assert!(matches!( + result, + Err(MessageVerificationError::IncorrectCompId { .. }) + )); + + // Wrong begin string should still be caught + let msg = build_test_message("FIX.4.2", "TARGET", "SENDER", 42); + let result = verify_message(&msg, &config, Some(42), false, false); + assert!(matches!( + result, + Err(MessageVerificationError::IncorrectBeginString(_)) + )); + } } diff --git a/crates/hotfix/src/message_utils.rs b/crates/hotfix/src/message_utils.rs deleted file mode 100644 index e4f6d88..0000000 --- a/crates/hotfix/src/message_utils.rs +++ /dev/null @@ -1,127 +0,0 @@ -static ADMIN_TYPES: [&str; 7] = ["A", "0", "1", "2", "3", "4", "5"]; - -pub fn is_admin(message_type: &str) -> bool { - ADMIN_TYPES.contains(&message_type) -} - -use hotfix_message::Part; -use hotfix_message::field_types::Timestamp; -use hotfix_message::message::Message; -use hotfix_message::session_fields::{ORIG_SENDING_TIME, POSS_DUP_FLAG, SENDING_TIME}; - -/// Prepares a FIX message for resend per the FIX spec (PossDupFlag logic). -/// -/// Behaviour: -/// - On first resend (no PossDupFlag Y / no OrigSendingTime): -/// * Move current SendingTime(52) to OrigSendingTime(122) -/// * Set SendingTime(52) to the current sending time (may be equal if clock granularity causes no change) -/// * Set PossDupFlag(43)=Y -/// - On subsequent resends (already marked possible duplicate and has OrigSendingTime): -/// * Refresh SendingTime(52) to current time (value may or may not differ from previous) -pub fn prepare_message_for_resend(msg: &mut Message) -> Result<(), &'static str> { - let header = msg.header_mut(); - - if header.get_raw(SENDING_TIME).is_none() { - return Err("Missing SendingTime (52)"); - } - - let already_poss_dup = header.get::(POSS_DUP_FLAG).unwrap_or(false); - let has_orig_sending_time = header.get_raw(ORIG_SENDING_TIME).is_some(); - - if already_poss_dup && has_orig_sending_time { - // Subsequent resend: refresh SendingTime only - return if header.pop(SENDING_TIME).is_some() { - header.set(SENDING_TIME, Timestamp::utc_now()); - Ok(()) - } else { - Err("Failed to extract previous SendingTime") - }; - } - - // First resend path - if let Some(original_sending_time_field) = header.pop(SENDING_TIME) { - let original_ts = Timestamp::parse(&original_sending_time_field.data) - .ok_or("Invalid original SendingTime format")?; - header.set(ORIG_SENDING_TIME, original_ts); - header.set(SENDING_TIME, Timestamp::utc_now()); - header.set(POSS_DUP_FLAG, true); - Ok(()) - } else { - Err("Failed to extract original SendingTime") - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hotfix_message::fix44; - - fn build_test_message() -> Message { - let mut msg = Message::new("FIX.4.4", "D"); - msg.set(fix44::SENDER_COMP_ID, "SND"); - msg.set(fix44::TARGET_COMP_ID, b"TGT"); - msg.set(fix44::MSG_SEQ_NUM, 1u64); - msg.set(fix44::SENDING_TIME, Timestamp::utc_now()); - msg - } - - #[test] - fn first_resend_sets_poss_dup_and_orig_sending_time() { - let mut msg = build_test_message(); - prepare_message_for_resend(&mut msg).unwrap(); - let header = msg.header(); - assert!( - header.get::(fix44::POSS_DUP_FLAG).unwrap(), - "PossDupFlag must be set on first resend" - ); - // Presence checks only (values may be equal or different depending on clock granularity) - assert!( - header.get_raw(fix44::ORIG_SENDING_TIME).is_some(), - "OrigSendingTime must be present" - ); - assert!( - header.get_raw(fix44::SENDING_TIME).is_some(), - "SendingTime must be present after resend" - ); - } - - #[test] - fn subsequent_resend_preserves_orig_sending_time() { - let mut msg = build_test_message(); - prepare_message_for_resend(&mut msg).unwrap(); - let orig_first = msg - .header() - .get::(fix44::ORIG_SENDING_TIME) - .unwrap(); - let sending_first = msg.header().get::(fix44::SENDING_TIME).unwrap(); - assert!( - msg.header().get::(fix44::POSS_DUP_FLAG).unwrap(), - "PossDupFlag must be set after first resend" - ); - - // Second resend - prepare_message_for_resend(&mut msg).unwrap(); - let orig_second = msg - .header() - .get::(fix44::ORIG_SENDING_TIME) - .unwrap(); - let sending_second = msg.header().get::(fix44::SENDING_TIME).unwrap(); - assert!( - msg.header().get::(fix44::POSS_DUP_FLAG).unwrap(), - "PossDupFlag must remain set on subsequent resends" - ); - - assert_eq!( - orig_first, orig_second, - "OrigSendingTime must remain constant across resends" - ); - assert!( - sending_first >= orig_first, - "First resend SendingTime must be >= original" - ); - assert!( - sending_second >= sending_first, - "Second resend SendingTime must be >= first resend SendingTime" - ); - } -} diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index bbfa3e9..3926bc7 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -32,7 +32,7 @@ use crate::message::sequence_reset::SequenceReset; use crate::message::test_request::TestRequest; use crate::message::verification::verify_message; use crate::message::verification_error::{CompIdType, MessageVerificationError}; -use crate::message_utils::{is_admin, prepare_message_for_resend}; +use crate::message::{is_admin, prepare_message_for_resend}; use crate::session::admin_request::AdminRequest; use crate::session::error::SessionCreationError; use crate::session::error::{InternalSendError, InternalSendResultExt, SessionOperationError}; @@ -194,7 +194,7 @@ where if let SessionState::AwaitingResend(state) = &mut self.state { let seq_number = get_msg_seq_num(&message); - if seq_number > state.end_seq_number { + if seq_number > state.end_seq_number && message_type != ResendRequest::MSG_TYPE { state.inbound_queue.push_back(message); return Ok(()); } @@ -202,32 +202,32 @@ where if let SessionState::AwaitingLogon { .. } = &mut self.state { // TODO: should this (and all inbound message processing) logic be pushed into the state? - if message_type != "A" { + if message_type != Logon::MSG_TYPE { self.state.disconnect_writer().await; return Ok(()); } } match message_type { - "0" => { + Heartbeat::MSG_TYPE => { self.on_heartbeat(&message).await?; } - "1" => { + TestRequest::MSG_TYPE => { self.on_test_request(&message).await?; } - "2" => { + ResendRequest::MSG_TYPE => { self.on_resend_request(&message).await?; } - "3" => { + Reject::MSG_TYPE => { self.on_reject(&message).await?; } - "4" => { + SequenceReset::MSG_TYPE => { self.on_sequence_reset(&message).await?; } - "5" => { - self.on_logout().await?; + Logout::MSG_TYPE => { + self.on_logout(&message).await?; } - "A" => { + Logon::MSG_TYPE => { self.on_logon(&message).await?; } _ => self.process_app_message(&message).await?, @@ -240,7 +240,7 @@ where &mut self, message: &Message, ) -> Result<(), SessionOperationError> { - match self.verify_message(message, true) { + match self.verify_message(message, true, true) { Ok(_) => { match self.application.on_inbound_message(message).await { InboundDecision::Accept => {} @@ -293,8 +293,17 @@ where error!("failed to get seq number: {:?}", e); 0 }); - debug!(seq_number, "processing queued message"); - self.process_message(msg).await?; + let msg_type: &str = msg.header().get(MSG_TYPE).unwrap_or(""); + debug!(seq_number, msg_type, "processing queued message"); + + if msg_type == ResendRequest::MSG_TYPE { + // ResendRequest was already processed when it arrived (it bypasses + // the queue in process_message). Just increment the target seq number + // for sequence accounting purposes. + self.store.increment_target_seq_number().await?; + } else { + self.process_message(msg).await?; + } } debug!("resend backlog is cleared, resuming normal operation"); } @@ -305,14 +314,21 @@ where fn verify_message( &self, message: &Message, - verify_target_seq_number: bool, + check_too_high: bool, + check_too_low: bool, ) -> Result<(), MessageVerificationError> { - let expected_seq_number = if verify_target_seq_number { + let expected_seq_number = if check_too_high || check_too_low { Some(self.store.next_target_seq_number()) } else { None }; - verify_message(message, &self.config, expected_seq_number) + verify_message( + message, + &self.config, + expected_seq_number, + check_too_high, + check_too_low, + ) } async fn on_connect(&mut self, writer: WriterRef) -> Result<(), SessionOperationError> { @@ -346,7 +362,7 @@ where async fn on_logon(&mut self, message: &Message) -> Result<(), SessionOperationError> { if let SessionState::AwaitingLogon { writer, .. } = &self.state { - match self.verify_message(message, true) { + match self.verify_message(message, true, true) { Ok(_) => { // happy logon flow, the session is now active self.state = @@ -363,7 +379,12 @@ where Ok(()) } - async fn on_logout(&mut self) -> Result<(), SessionOperationError> { + async fn on_logout(&mut self, message: &Message) -> Result<(), SessionOperationError> { + if let Err(err) = self.verify_message(message, false, false) { + self.handle_verification_error(err).await?; + return Ok(()); + } + if self.state.is_logged_on() { self.send_logout("Logout acknowledged").await?; } @@ -390,6 +411,11 @@ where } async fn on_heartbeat(&mut self, message: &Message) -> Result<(), SessionOperationError> { + if let Err(err) = self.verify_message(message, true, true) { + self.handle_verification_error(err).await?; + return Ok(()); + } + if let (Some(expected_req_id), Ok(message_req_id)) = ( &self.state.expected_test_response_id(), message.get::<&str>(TEST_REQ_ID), @@ -404,6 +430,11 @@ where } async fn on_test_request(&mut self, message: &Message) -> Result<(), SessionOperationError> { + if let Err(err) = self.verify_message(message, true, true) { + self.handle_verification_error(err).await?; + return Ok(()); + } + let req_id: &str = message.get(TEST_REQ_ID).unwrap_or_else(|_| { // TODO: send reject? todo!() @@ -421,12 +452,36 @@ where async fn on_resend_request(&mut self, message: &Message) -> Result<(), SessionOperationError> { if !self.state.is_connected() { warn!("received resend request while disconnected, ignoring"); + return Ok(()); + } + + // Verify with check_too_high=false so ResendRequest is never blocked by seq-too-high. + // This is the key part of the QFJ-673 deadlock fix: when both sides send ResendRequest + // simultaneously, each side's ResendRequest will have a seq number higher than expected. + // By not treating that as an error, we allow the ResendRequest to be processed. + match self.verify_message(message, false, true) { + Ok(_) => {} + Err(err) => { + self.handle_verification_error(err).await?; + return Ok(()); + } + } + + let msg_seq_num = get_msg_seq_num(message); + let expected = self.store.next_target_seq_number(); + + // If seq is too high and we're in AwaitingResend, queue it for seq accounting + // when the gap fill catches up, but still process the resend below. + if msg_seq_num > expected + && let SessionState::AwaitingResend(state) = &mut self.state + { + state.inbound_queue.push_back(message.clone()); } let begin_seq_number: u64 = match message.get(BEGIN_SEQ_NO) { Ok(seq_number) => seq_number, Err(_) => { - let reject = Reject::new(get_msg_seq_num(message)) + let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::RequiredTagMissing) .text("missing begin sequence number for resend request"); self.send_message(reject) @@ -446,7 +501,7 @@ where } } Err(_) => { - let reject = Reject::new(get_msg_seq_num(message)) + let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::RequiredTagMissing) .text("missing end sequence number for resend request"); self.send_message(reject) @@ -456,7 +511,10 @@ where } }; - self.store.increment_target_seq_number().await?; + // Only increment target seq if seq matches expected + if msg_seq_num == expected { + self.store.increment_target_seq_number().await?; + } self.resend_messages(begin_seq_number, end_seq_number, message) .await?; @@ -465,23 +523,20 @@ where } /// Handle Reject messages. - /// - /// Returns whether the message should be processed as usual - /// and whether the target sequence number should be incremented. async fn on_reject(&mut self, message: &Message) -> Result<(), SessionOperationError> { - if let Ok(seq_num) = message.get::(MSG_SEQ_NUM) - && seq_num == self.store.next_target_seq_number() - { - self.store.increment_target_seq_number().await?; + if let Err(err) = self.verify_message(message, false, true) { + self.handle_verification_error(err).await?; + return Ok(()); } + self.store.increment_target_seq_number().await?; Ok(()) } async fn on_sequence_reset(&mut self, message: &Message) -> Result<(), SessionOperationError> { let msg_seq_num = get_msg_seq_num(message); let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); - if let Err(err) = self.verify_message(message, is_gap_fill) { + if let Err(err) = self.verify_message(message, is_gap_fill, is_gap_fill) { self.handle_verification_error(err).await?; return Ok(()); } @@ -746,7 +801,7 @@ where .map_err(|_| SessionOperationError::MissingField("MSG_TYPE"))? .to_string(); - if is_admin(message_type.as_str()) { + if is_admin(&message_type) { if reset_start.is_none() { reset_start = Some(sequence_number); } @@ -766,11 +821,8 @@ where "failed to prepare message for resend, sending original" ); } - self.send_raw( - message_type.as_bytes(), - message.encode(&self.message_config)?, - ) - .await; + self.send_raw(&message_type, message.encode(&self.message_config)?) + .await; if enabled!(tracing::Level::DEBUG) && let Ok(m) = String::from_utf8(msg.clone()) @@ -833,7 +885,7 @@ where message: impl OutboundMessage, ) -> Result { let seq_num = self.store.next_sender_seq_number(); - let msg_type = message.message_type().as_bytes().to_vec(); + let msg_type = message.message_type().to_string(); let msg = generate_message( &self.config.begin_string, &self.config.sender_comp_id, @@ -863,7 +915,7 @@ where Ok(seq_num) } - async fn send_raw(&mut self, message_type: &[u8], data: Vec) { + async fn send_raw(&mut self, message_type: &str, data: Vec) { self.state .send_message(message_type, RawFixMessage::new(data)) .await; @@ -887,7 +939,7 @@ where sequence_reset, )?; - self.send_raw(b"4", raw_message).await; + self.send_raw(SequenceReset::MSG_TYPE, raw_message).await; debug!(begin, end, "sent reset sequence"); Ok(()) diff --git a/crates/hotfix/src/session/state.rs b/crates/hotfix/src/session/state.rs index 414b8d0..fa84472 100644 --- a/crates/hotfix/src/session/state.rs +++ b/crates/hotfix/src/session/state.rs @@ -1,3 +1,5 @@ +use crate::message::logon::Logon; +use crate::message::logout::Logout; use crate::message::parser::RawFixMessage; use crate::session::event::AwaitingActiveSessionResponse; use crate::session::info::Status as SessionInfoStatus; @@ -61,11 +63,11 @@ impl SessionState { } } - pub async fn send_message(&mut self, message_type: &[u8], message: RawFixMessage) { + pub async fn send_message(&mut self, message_type: &str, message: RawFixMessage) { match self { Self::Active(ActiveState { writer, .. }) | Self::AwaitingResend(AwaitingResendState { writer, .. }) => { - if message_type == b"A" { + if message_type == Logon::MSG_TYPE { error!("logon message is invalid for active sessions") } else { writer.send_raw_message(message).await @@ -73,28 +75,24 @@ impl SessionState { } Self::AwaitingLogon { writer, logon_sent, .. - } => { - match message_type { - b"A" => { - // Logon message - if *logon_sent { - error!("trying to send logon twice"); - } else { - writer.send_raw_message(message).await; - *logon_sent = true; - } - } - b"5" => { - // Logout message + } => match message_type { + Logon::MSG_TYPE => { + if *logon_sent { + error!("trying to send logon twice"); + } else { writer.send_raw_message(message).await; + *logon_sent = true; } - _ => error!("invalid outgoing message for AwaitingLogon state"), } - } + Logout::MSG_TYPE => { + writer.send_raw_message(message).await; + } + _ => error!("invalid outgoing message for AwaitingLogon state"), + }, Self::AwaitingLogout { writer, .. } => { // Logout messages are allowed because we first transition into AwaitingLogout // and only then send the logout message - if message_type == b"5" { + if message_type == Logout::MSG_TYPE { writer.send_raw_message(message).await } } diff --git a/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs b/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs index d4738c8..219d7a1 100644 --- a/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs +++ b/crates/hotfix/tests/session_test_cases/common/fakes/fake_counterparty.rs @@ -110,7 +110,7 @@ where _ => panic!("trying to resend invalid message"), }; - if let Err(err) = hotfix::message_utils::prepare_message_for_resend(&mut message) { + if let Err(err) = hotfix::message::prepare_message_for_resend(&mut message) { panic!("failed to prepare message for resend: {err:?}"); } diff --git a/crates/hotfix/tests/session_test_cases/common/test_messages.rs b/crates/hotfix/tests/session_test_cases/common/test_messages.rs index fbfa9fd..33b56ff 100644 --- a/crates/hotfix/tests/session_test_cases/common/test_messages.rs +++ b/crates/hotfix/tests/session_test_cases/common/test_messages.rs @@ -389,6 +389,22 @@ pub fn build_invalid_resend_request( msg.encode(&Config::default()).unwrap() } +/// A Reject message (MsgType=3) for testing counterparty-initiated rejects. +#[derive(Clone)] +pub struct TestReject { + pub ref_seq_num: u64, +} + +impl OutboundMessage for TestReject { + fn write(&self, msg: &mut Message) { + msg.set(fix44::REF_SEQ_NUM, self.ref_seq_num); + } + + fn message_type(&self) -> &str { + "3" + } +} + pub fn build_sequence_reset_without_new_seq_no(msg_seq_num: u64) -> Vec { let mut msg = Message::new("FIX.4.4", "4"); // MsgType 4 = SequenceReset msg.set(fix44::SENDER_COMP_ID, COUNTERPARTY_COMP_ID); diff --git a/crates/hotfix/tests/session_test_cases/heartbeat_tests.rs b/crates/hotfix/tests/session_test_cases/heartbeat_tests.rs index 2a60ffa..4b2147d 100644 --- a/crates/hotfix/tests/session_test_cases/heartbeat_tests.rs +++ b/crates/hotfix/tests/session_test_cases/heartbeat_tests.rs @@ -2,6 +2,7 @@ use crate::common::actions::when; use crate::common::assertions::{assert_msg_type, then}; use crate::common::cleanup::finally; use crate::common::setup::{HEARTBEAT_INTERVAL, given_an_active_session}; +use hotfix::message::heartbeat::Heartbeat; use hotfix::message::test_request::TestRequest; use hotfix_message::Part; use hotfix_message::fix44::{MsgType, TEST_REQ_ID}; @@ -17,7 +18,7 @@ use std::time::Duration; /// periodic heartbeat messages when no other messages are being exchanged, /// as required by the FIX protocol to prevent timeout disconnections. #[tokio::test(start_paused = true)] -async fn test_heartbeats() { +async fn test_heartbeat_is_sent() { let (session, mut counterparty) = given_an_active_session().await; // let's wait enough time for a heartbeat and assert that the heartbeat was sent @@ -85,3 +86,38 @@ async fn test_heartbeat_in_response_to_test_request() { finally(&session, &mut counterparty).disconnect().await; } + +/// Tests that receiving a heartbeat from the counterparty resets the peer timer. +/// +/// Without the counterparty heartbeat, the peer deadline would expire and a TestRequest +/// would be sent (as demonstrated by `test_peer_timeout`). By sending a counterparty +/// heartbeat after our first heartbeat, the peer timer resets, so advancing to our next +/// heartbeat produces a Heartbeat — not a TestRequest. +#[tokio::test(start_paused = true)] +async fn test_receiving_heartbeat_resets_peer_timer() { + let (session, mut counterparty) = given_an_active_session().await; + + // Wait for our first heartbeat + when(Duration::from_secs(HEARTBEAT_INTERVAL + 1)) + .elapses() + .await; + then(&mut counterparty) + .receives(|msg| assert_msg_type(msg, MsgType::Heartbeat)) + .await; + + // Counterparty sends a heartbeat, which should reset the peer timer + when(&mut counterparty) + .sends_message(Heartbeat::default()) + .await; + + // Advance to our next heartbeat. Without the peer timer reset above, + // a TestRequest would arrive before this heartbeat, failing the assertion. + when(Duration::from_secs(HEARTBEAT_INTERVAL + 1)) + .elapses() + .await; + then(&mut counterparty) + .receives(|msg| assert_msg_type(msg, MsgType::Heartbeat)) + .await; + + finally(&session, &mut counterparty).disconnect().await; +} diff --git a/crates/hotfix/tests/session_test_cases/invalid_message_tests.rs b/crates/hotfix/tests/session_test_cases/invalid_message_tests.rs index d4a1fb4..bb06759 100644 --- a/crates/hotfix/tests/session_test_cases/invalid_message_tests.rs +++ b/crates/hotfix/tests/session_test_cases/invalid_message_tests.rs @@ -3,7 +3,7 @@ use crate::common::assertions::{assert_msg_type, then}; use crate::common::cleanup::finally; use crate::common::setup::{COUNTERPARTY_COMP_ID, OUR_COMP_ID, given_an_active_session}; use crate::common::test_messages::{ - ExecutionReportWithInvalidField, TestMessage, build_execution_report_with_comp_id, + ExecutionReportWithInvalidField, TestMessage, TestReject, build_execution_report_with_comp_id, build_execution_report_with_custom_msg_type, build_execution_report_with_incorrect_begin_string, build_execution_report_with_incorrect_body_length, @@ -402,3 +402,34 @@ async fn test_scenario_2g_possdup_without_orig_sending_time() { finally(&session, &mut counterparty).disconnect().await; } + +/// Tests that a Reject (MsgType=3) from the counterparty is processed correctly. +/// +/// The session should increment the target sequence number and remain active, +/// continuing to accept subsequent messages. +#[tokio::test] +async fn test_processing_reject_from_counterparty() { + let (mut session, mut counterparty) = given_an_active_session().await; + + // Counterparty sends a Reject referencing our logon (seq 1) + let reject_seq_num = counterparty.next_target_sequence_number(); + when(&mut counterparty) + .sends_message(TestReject { ref_seq_num: 1 }) + .await; + + // The reject should be processed, incrementing the target sequence number + then(&mut session) + .target_sequence_number_reaches(reject_seq_num) + .await; + + // The session should remain active and accept further messages + let next_seq_num = counterparty.next_target_sequence_number(); + when(&mut counterparty) + .sends_message(TestMessage::dummy_execution_report()) + .await; + then(&mut session) + .target_sequence_number_reaches(next_seq_num) + .await; + + finally(&session, &mut counterparty).disconnect().await; +} diff --git a/crates/hotfix/tests/session_test_cases/resend_tests.rs b/crates/hotfix/tests/session_test_cases/resend_tests.rs index a54246f..dd2c8b1 100644 --- a/crates/hotfix/tests/session_test_cases/resend_tests.rs +++ b/crates/hotfix/tests/session_test_cases/resend_tests.rs @@ -220,3 +220,60 @@ async fn test_resend_request_with_gap_fill_for_admin_messages() { finally(&session, &mut counterparty).disconnect().await; } + +/// Tests that when both sides detect a sequence gap simultaneously (each sends a ResendRequest), +/// the session processes the counterparty's ResendRequest immediately instead of queueing it, +/// preventing a deadlock where neither side processes the other's ResendRequest. +#[tokio::test] +async fn test_resend_request_not_deadlocked_when_both_sides_detect_gap() { + let (mut session, mut counterparty) = given_an_active_session().await; + + // The counterparty previously sent an execution report which we missed + when(&mut counterparty) + .has_previously_sent(TestMessage::dummy_execution_report()) + .await; + + // The counterparty sends another message which we do receive, creating a gap + when(&mut counterparty) + .sends_message(TestMessage::dummy_execution_report()) + .await; + + // We detect the gap and enter AwaitingResend state, sending a ResendRequest + then(&mut session) + .status_changes_to(Status::AwaitingResend { + begin: 2, + end: 3, + attempts: 1, + }) + .await; + then(&mut counterparty) + .receives(|msg| assert_msg_type(msg, MsgType::ResendRequest)) + .await; + + // Now the counterparty also sends a ResendRequest (simulating they also detected a gap). + // This ResendRequest has seq number 4 (> end_seq_number 3), which would previously + // be queued — causing a deadlock. With the fix, it should be processed immediately. + let resend_request = ResendRequest::new(1, 0); + when(&mut counterparty).sends_message(resend_request).await; + + // The session should respond to the counterparty's ResendRequest with a SequenceReset-GapFill + // for the logon message (admin messages are gap-filled). + // This proves the ResendRequest was processed and not stuck in the queue. + then(&mut counterparty) + .receives(|msg| { + let msg_type: &str = msg.header().get(MSG_TYPE).unwrap(); + assert!( + msg_type == MsgType::SequenceReset.to_string() + || msg_type == MsgType::ExecutionReport.to_string(), + "expected SequenceReset or resent message in response to ResendRequest, got {msg_type}" + ); + }) + .await; + + // Now the counterparty fulfills our resend request so we can resume + when(&mut counterparty).resends_message(2).await; + when(&mut counterparty).resends_message(3).await; + then(&mut session).status_changes_to(Status::Active).await; + + finally(&session, &mut counterparty).disconnect().await; +}