diff --git a/crates/obfs4/Cargo.toml b/crates/obfs4/Cargo.toml index 0c844f2..986d866 100644 --- a/crates/obfs4/Cargo.toml +++ b/crates/obfs4/Cargo.toml @@ -81,10 +81,10 @@ simple_asn1 = { version="0.6.1", optional=true} tracing-subscriber = "0.3.18" hex-literal = "0.4.1" tor-basic-utils = "0.20.0" +rand_distr = "0.4.3" # benches # criterion = "0.5" - # # o5 pqc test # pqc_kyber = {version="0.7.1", features=["kyber1024", "std"]} # ml-kem = "0.1.0" diff --git a/crates/obfs4/README.md b/crates/obfs4/README.md index c4675d6..1c8c389 100644 --- a/crates/obfs4/README.md +++ b/crates/obfs4/README.md @@ -65,22 +65,42 @@ tokio::spawn(async move { Server example using [ptrs](../ptrs) ```rs -use ptrs::{ServerBuilder, ServerTransport}; -... +use ptrs::{ServerBuilder as _, ServerTransport as _}; +use obfs4::Obfs4PT; + +let mut builder = Obfs4PT::server_builder(); +let server = if params.is_some() { + builder.options(¶ms.unwrap())?.build() +} else { + builder.build() +}; + +let listener = tokio::net::TcpListener::bind(listen_addrs).await?; +loop { + let (conn, _) = listener.accept()?; + let pt_conn = server.reveal(conn).await?; + + // pt_conn wraps conn and is usable as an `AsyncRead + AsyncWrite` object. + tokio::spawn( async move{ + // use the connection (e.g. to echo) + let (mut r, mut w) = tokio::io::split(pt_conn); + if let Err(e) = tokio::io::copy(&mut r, &mut w).await { + warn!("echo closed with error: {e}") + } + }); +} -// TODO fill out example ``` ### Loose Ends: - [X] server / client compatibility test go-to-rust and rust-to-go. +- [x] double check the bit randomization and clearing for high two bits in the `dalek` representative - [ ] length distribution things - [ ] iat mode handling -- [ ] double check the bit randomization and clearing for high two bits in the `dalek` representative ## Performance - comparison to golang -- comparison when kyber is enabled - NaCl encryption library(s) diff --git a/crates/obfs4/src/client.rs b/crates/obfs4/src/client.rs index c61c5a4..75203d1 100644 --- a/crates/obfs4/src/client.rs +++ b/crates/obfs4/src/client.rs @@ -140,9 +140,9 @@ impl Client { /// On a failed handshake the client will read for the remainder of the /// handshake timeout and then close the connection. - pub async fn wrap<'a, T>(self, mut stream: T) -> Result> + pub async fn wrap<'a, T>(self, mut stream: T) -> Result where - T: AsyncRead + AsyncWrite + Unpin + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let session = sessions::new_client_session(self.station_pubkey, self.iat_mode); @@ -156,9 +156,9 @@ impl Client { pub async fn establish<'a, T, E>( self, mut stream_fut: Pin>, - ) -> Result> + ) -> Result where - T: AsyncRead + AsyncWrite + Unpin + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, E: std::error::Error + Send + Sync + 'static, { let stream = stream_fut.await.map_err(|e| Error::Other(Box::new(e)))?; diff --git a/crates/obfs4/src/common/delay/README.md b/crates/obfs4/src/common/delay/README.md new file mode 100644 index 0000000..e697c17 --- /dev/null +++ b/crates/obfs4/src/common/delay/README.md @@ -0,0 +1,49 @@ +# Sink Delays + +Adding Structured Delays to rust sinks on event. + + +Example test using a sampled normal distribution for the delay after each +send (`start_send()` if not using `SinkExt`). + +```rs +#[cfg(test)] +mod testing { + use super::*; + use futures::sink::{self, SinkExt}; + use std::time::Instant; + use rand_distr::{Normal, Distribution}; + + #[tokio::test] + async fn delay_sink() { + let start = Instant::now(); + + let unfold = sink::unfold(0, |mut sum, i: i32| async move { + sum += i; + eprintln!("{} - {:?}", i, Instant::now().duration_since(start)); + Ok::<_, futures::never::Never>(sum) + }); + futures::pin_mut!(unfold); + + // let mut delayed_unfold = DelayedSink::new(unfold, || Duration::from_secs(1)); + let mut delayed_unfold = DelayedSink::new(unfold, delay_distribution); + delayed_unfold.send(5).await.unwrap(); + delayed_unfold.send(4).await.unwrap(); + delayed_unfold.send(3).await.unwrap(); + } + + fn delay_distribution() -> Duration { + let distr = Normal::new(500.0, 100.0).unwrap(); + let dur_ms = distr.sample(&mut rand::thread_rng()); + Duration::from_millis(dur_ms as u64) + } +} +``` + +--- + +But I wanna go fast! Why would I ever want this??? + +-> This lets us control the delays (or leave them out) in between sink events. +As an example, we can control the delay between network writes, which helps when +reshaping the traffic fingerprint of a proxy connection. diff --git a/crates/obfs4/src/common/delay/mod.rs b/crates/obfs4/src/common/delay/mod.rs new file mode 100644 index 0000000..6ad6b5a --- /dev/null +++ b/crates/obfs4/src/common/delay/mod.rs @@ -0,0 +1,107 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use futures::{sink::Sink, Future}; +use tokio::time::{Instant, Sleep}; + +use pin_project::pin_project; + +type DurationFn = fn() -> Duration; + +#[pin_project] +pub struct DelayedSink { + // #[pin] + // sink: Si, + // #[pin] + // sleep: Sleep, + sink: Pin>, + sleep: Pin>, + delay_fn: DurationFn, + _item: PhantomData, + _error: PhantomData, +} + +impl> DelayedSink { + pub fn new(sink: Si, delay_fn: DurationFn) -> Self { + let delay = delay_fn(); + let sleep = tokio::time::sleep(delay); + Self { + // sink, + // sleep, + sink: Box::pin(sink), + sleep: Box::pin(sleep), + delay_fn, + _item: PhantomData {}, + _error: PhantomData {}, + } + } +} + +impl> Sink for DelayedSink +where + J: Into, +{ + type Error = Si::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let s = self.project(); + match (s.sink.as_mut().poll_ready(cx), s.sleep.as_mut().poll(cx)) { + (Poll::Ready(k), Poll::Ready(_)) => Poll::Ready(k), + _ => Poll::Pending, + } + } + + fn start_send(self: Pin<&mut Self>, item: J) -> Result<(), Self::Error> { + let s = self.project(); + s.sink.as_mut().start_send(item.into())?; + + let delay = (*s.delay_fn)(); + + if delay.is_zero() { + s.sleep.as_mut().reset(Instant::now() + delay); + } + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.as_mut().poll_close(cx) + } +} + +#[cfg(test)] +mod testing { + use super::*; + use futures::sink::{self, SinkExt}; + use rand_distr::{Distribution, Normal}; + use std::time::Instant; + + #[tokio::test] + async fn delay_sink() { + let start = Instant::now(); + + let unfold = sink::unfold(0, |mut sum, i: i32| async move { + sum += i; + eprintln!("{} - {:?}", i, Instant::now().duration_since(start)); + Ok::<_, futures::never::Never>(sum) + }); + futures::pin_mut!(unfold); + + // let mut delayed_unfold = DelayedSink::new(unfold, || Duration::from_secs(1)); + let mut delayed_unfold = DelayedSink::new(unfold, delay_distribution); + delayed_unfold.send(5).await.unwrap(); + delayed_unfold.send(4).await.unwrap(); + delayed_unfold.send(3).await.unwrap(); + } + + fn delay_distribution() -> Duration { + let distr = Normal::new(500.0, 100.0).unwrap(); + let dur_ms = distr.sample(&mut rand::thread_rng()); + Duration::from_millis(dur_ms as u64) + } +} diff --git a/crates/obfs4/src/common/mod.rs b/crates/obfs4/src/common/mod.rs index 09f717d..ca19515 100644 --- a/crates/obfs4/src/common/mod.rs +++ b/crates/obfs4/src/common/mod.rs @@ -10,8 +10,9 @@ pub(crate) mod kdf; mod skip; pub use skip::discard; +pub(crate) mod delay; pub mod drbg; -// pub mod ntor; + pub mod ntor_arti; pub mod probdist; pub mod replay_filter; diff --git a/crates/obfs4/src/framing/codecs.rs b/crates/obfs4/src/framing/codecs.rs index f37e9c5..788fee1 100644 --- a/crates/obfs4/src/framing/codecs.rs +++ b/crates/obfs4/src/framing/codecs.rs @@ -2,6 +2,7 @@ use crate::{ common::drbg::{self, Drbg, Seed}, constants::MESSAGE_OVERHEAD, framing::{FrameError, Messages}, + Error, }; use bytes::{Buf, BufMut, BytesMut}; @@ -69,10 +70,28 @@ impl EncryptingCodec { pub(crate) fn handshake_complete(&mut self) { self.handshake_complete = true; } + + pub(crate) fn into_parts(self) -> (EncryptingEncoder, EncryptingDecoder) { + (self.encoder, self.decoder) + } + + #[allow(unused)] + pub(crate) fn from_parts( + e: EncryptingEncoder, + d: EncryptingDecoder, + hs_complete: bool, + ) -> Self { + Self { + // key, + encoder: e, + decoder: d, + handshake_complete: hs_complete, + } + } } ///Decoder is a frame decoder instance. -struct EncryptingDecoder { +pub(crate) struct EncryptingDecoder { key: [u8; KEY_LENGTH], nonce: NonceBox, drbg: Drbg, @@ -106,8 +125,19 @@ impl EncryptingDecoder { impl Decoder for EncryptingCodec { type Item = Messages; - type Error = FrameError; + type Error = Error; + + fn decode( + &mut self, + src: &mut BytesMut, + ) -> std::result::Result, Self::Error> { + self.decoder.decode(src) + } +} +impl Decoder for EncryptingDecoder { + type Item = Messages; + type Error = Error; // Decode decodes a stream of data and returns the length if any. ErrAgain is // a temporary failure, all other errors MUST be treated as fatal and the // session aborted. @@ -118,28 +148,28 @@ impl Decoder for EncryptingCodec { trace!( "decoding src:{}B {} {}", src.remaining(), - self.decoder.next_length, - self.decoder.next_length_invalid + self.next_length, + self.next_length_invalid ); // A length of 0 indicates that we do not know the expected size of // the next frame. we use this to store the length of a packet when we // receive the length at the beginning, but not the whole packet, since // future reads may not have the who packet (including length) available - if self.decoder.next_length == 0 { + if self.next_length == 0 { // Attempt to pull out the next frame length if LENGTH_LENGTH > src.remaining() { return Ok(None); } // derive the nonce that the peer would have used - self.decoder.next_nonce = self.decoder.nonce.next()?; + self.next_nonce = self.nonce.next()?; // Remove the field length from the buffer // let mut len_buf: [u8; LENGTH_LENGTH] = src[..LENGTH_LENGTH].try_into().unwrap(); let mut length = src.get_u16(); // De-obfuscate the length field - let length_mask = self.decoder.drbg.length_mask(); + let length_mask = self.drbg.length_mask(); trace!( "decoding {length:04x}^{length_mask:04x} {:04x}B", length ^ length_mask @@ -158,35 +188,35 @@ impl Decoder for EncryptingCodec { // paper. let invalid_length = length; - self.decoder.next_length_invalid = true; + self.next_length_invalid = true; length = rand::thread_rng().gen::() % (MAX_FRAME_LENGTH - MIN_FRAME_LENGTH) as u16 + MIN_FRAME_LENGTH as u16; error!( "invalid length {invalid_length} {length} {}", - self.decoder.next_length_invalid + self.next_length_invalid ); } - self.decoder.next_length = length; + self.next_length = length; } - let next_len = self.decoder.next_length as usize; + let next_len = self.next_length as usize; if next_len > src.len() { // The full string has not yet arrived. // // We reserve more space in the buffer. This is not strictly // necessary, but is a good idea performance-wise. - if !self.decoder.next_length_invalid { + if !self.next_length_invalid { src.reserve(next_len - src.len()); } trace!( "next_len > src.len --> reading more {} {}", - self.decoder.next_length, - self.decoder.next_length_invalid + self.next_length, + self.next_length_invalid ); // We inform the Framed that we need more bytes to form the next @@ -198,25 +228,25 @@ impl Decoder for EncryptingCodec { let data = src.get(..next_len).unwrap().to_vec(); // Unseal the frame - let key = GenericArray::from_slice(&self.decoder.key); + let key = GenericArray::from_slice(&self.key); let cipher = XSalsa20Poly1305::new(key); - let nonce = GenericArray::from_slice(&self.decoder.next_nonce); // unique per message + let nonce = GenericArray::from_slice(&self.next_nonce); // unique per message let res = cipher.decrypt(nonce, data.as_ref()); if res.is_err() { let e = res.unwrap_err(); trace!("failed to decrypt result: {e}"); - return Err(e.into()); + return Err(Error::Obfs4Framing(FrameError::from(e))); } - let plaintext = res?; + let plaintext = res.map_err(|e| Error::Obfs4Framing(FrameError::from(e)))?; if plaintext.len() < MESSAGE_OVERHEAD { - return Err(FrameError::InvalidMessage); + return Err(Error::Obfs4Framing(FrameError::InvalidMessage)); } // Clean up and prepare for the next frame // // we read a whole frame, we no longer know the size of the next pkt - self.decoder.next_length = 0; + self.next_length = 0; src.advance(next_len); debug!("decoding {next_len}B src:{}B", src.remaining()); @@ -224,13 +254,13 @@ impl Decoder for EncryptingCodec { Ok(Messages::Padding(_)) => Ok(None), Ok(m) => Ok(Some(m)), Err(FrameError::UnknownMessageType(_)) => Ok(None), - Err(e) => Err(e), + Err(e) => Err(Error::Obfs4Framing(e)), } } } /// Encoder is a frame encoder instance. -struct EncryptingEncoder { +pub(crate) struct EncryptingEncoder { key: [u8; KEY_LENGTH], nonce: NonceBox, drbg: Drbg, @@ -255,8 +285,15 @@ impl EncryptingEncoder { } impl Encoder for EncryptingCodec { - type Error = FrameError; + type Error = Error; + + fn encode(&mut self, plaintext: T, dst: &mut BytesMut) -> std::result::Result<(), Self::Error> { + self.encoder.encode(plaintext, dst) + } +} +impl Encoder for EncryptingEncoder { + type Error = Error; /// Encode encodes a single frame worth of payload and returns. Plaintext /// should either be a handshake message OR a buffer containing one or more /// [`Message`]s already properly marshalled. The proided plaintext can @@ -272,7 +309,7 @@ impl Encoder for EncryptingCodec { // Don't send a frame if it is longer than the other end will accept. if plaintext.remaining() > MAX_FRAME_PAYLOAD_LENGTH { - return Err(FrameError::InvalidPayloadLength(plaintext.remaining())); + return Err(FrameError::InvalidPayloadLength(plaintext.remaining()).into()); } let mut plaintext_frame = BytesMut::new(); @@ -280,18 +317,20 @@ impl Encoder for EncryptingCodec { plaintext_frame.put(plaintext); // Generate a new nonce - let nonce_bytes = self.encoder.nonce.next()?; + let nonce_bytes = self.nonce.next()?; // Encrypt and MAC payload - let key = GenericArray::from_slice(&self.encoder.key); + let key = GenericArray::from_slice(&self.key); let cipher = XSalsa20Poly1305::new(key); let nonce = GenericArray::from_slice(&nonce_bytes); // unique per message - let ciphertext = cipher.encrypt(nonce, plaintext_frame.as_ref())?; + let ciphertext = cipher + .encrypt(nonce, plaintext_frame.as_ref()) + .map_err(|e| Error::Obfs4Framing(FrameError::Crypto(e)))?; // Obfuscate the length let mut length = ciphertext.len() as u16; - let length_mask: u16 = self.encoder.drbg.length_mask(); + let length_mask: u16 = self.drbg.length_mask(); debug!( "encoding➡️ {length}B, {length:04x}^{length_mask:04x} {:04x}", length ^ length_mask diff --git a/crates/obfs4/src/framing/mod.rs b/crates/obfs4/src/framing/mod.rs index 1157683..b744050 100644 --- a/crates/obfs4/src/framing/mod.rs +++ b/crates/obfs4/src/framing/mod.rs @@ -49,7 +49,7 @@ pub use messages_base::*; mod messages_v1; pub use messages_v1::{MessageTypes, Messages}; -mod codecs; +pub(crate) mod codecs; pub use codecs::EncryptingCodec as Obfs4Codec; pub(crate) mod handshake; diff --git a/crates/obfs4/src/framing/testing.rs b/crates/obfs4/src/framing/testing.rs index 516fc9c..9bbb8fd 100644 --- a/crates/obfs4/src/framing/testing.rs +++ b/crates/obfs4/src/framing/testing.rs @@ -10,7 +10,7 @@ /// use super::*; use crate::test_utils::init_subscriber; -use crate::Result; +use crate::{Error, Result}; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; @@ -67,11 +67,21 @@ async fn oversized_flow() -> Result<()> { let mut src = Bytes::from(oversized_messsage); let res = codec.encode(&mut src, &mut b); - assert_eq!( - res.unwrap_err(), - FrameError::InvalidPayloadLength(frame_len) - ); - Ok(()) + let e = res.unwrap_err(); + assert!(matches!( + e, + Error::Obfs4Framing(FrameError::InvalidPayloadLength(_)) + )); + match e { + Error::Obfs4Framing(FrameError::InvalidPayloadLength(f)) => { + if f == frame_len { + Ok(()) + } else { + panic!("expected frame_length {}, got {}", frame_len, f); + } + } + _ => panic!("expected InvalidPayloadLength, got {}", e), + } } #[tokio::test] diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index 822a2f8..98bb806 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -1,22 +1,25 @@ use crate::{ common::{ - drbg, + delay, drbg, probdist::{self, WeightedDist}, }, constants::*, - framing, + framing::{self, codecs::EncryptingEncoder, FrameError, Messages, Obfs4Codec}, sessions::Session, Error, Result, }; use bytes::{Buf, BytesMut}; -use futures::{Sink, Stream}; +use futures::{ + sink::{Sink, SinkExt}, + stream::{Stream, StreamExt}, +}; use pin_project::pin_project; use ptrs::trace; use sha2::{Digest, Sha256}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::time::{Duration, Instant}; -use tokio_util::codec::Framed; +use tokio_util::codec::{FramedRead, FramedWrite}; use std::{ io::Error as IoError, @@ -25,8 +28,6 @@ use std::{ task::{Context, Poll}, }; -use super::framing::{FrameError, Messages}; - #[allow(dead_code, unused)] #[derive(Default, Debug, Clone, Copy, PartialEq)] pub enum IAT { @@ -84,21 +85,18 @@ impl MaybeTimeout { } } +type MsgStream = Box> + Send + Unpin>; +type BytesSink = Box + Send + Unpin>; + #[pin_project] -pub struct Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +pub struct Obfs4Stream { // s: Arc>>, #[pin] - s: O4Stream, + s: O4Stream, } -impl Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn from_o4(o4: O4Stream) -> Self { +impl Obfs4Stream { + pub(crate) fn from_o4(o4: O4Stream) -> Self { Obfs4Stream { // s: Arc::new(Mutex::new(o4)), s: o4, @@ -107,12 +105,13 @@ where } #[pin_project] -pub(crate) struct O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +pub(crate) struct O4Stream { #[pin] - pub stream: Framed, + // pub stream: Framed, + // pub stream: Box>, + pub stream: MsgStream, + #[pin] + pub sink: BytesSink, pub length_dist: probdist::WeightedDist, pub iat_dist: probdist::WeightedDist, @@ -120,17 +119,33 @@ where pub session: Session, } -impl O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn new( +impl O4Stream { + pub(crate) fn new( // inner: &'a mut dyn Stream<'a>, inner: T, - codec: framing::Obfs4Codec, - session: Session, - ) -> O4Stream { - let stream = Framed::new(inner, codec); + codec: Obfs4Codec, + mut session: Session, + ) -> Self + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let delay_fn = match session.get_iat_mode() { + IAT::Off => || Duration::ZERO, + IAT::Enabled | IAT::Paranoid => session.iat_duration_sampler(), + }; + + let (r, w) = tokio::io::split(inner); + let (e, d) = codec.into_parts(); + let encoding_sink = FramedWrite::new(w, e); + let sink = Box::new(delay::DelayedSink::< + FramedWrite, EncryptingEncoder>, + BytesMut, + Error, + >::new(encoding_sink, delay_fn)); + + let decoding_stream = FramedRead::new(r, d); + let stream: MsgStream = Box::new(decoding_stream); + let len_seed = session.len_seed(); let mut hasher = Sha256::new(); @@ -153,70 +168,16 @@ where ); Self { + sink, stream, session, length_dist, iat_dist, } } - - pub(crate) fn try_handle_non_payload_message(&mut self, msg: framing::Messages) -> Result<()> { - match msg { - Messages::Payload(_) => Err(FrameError::InvalidMessage.into()), - Messages::Padding(_) => Ok(()), - - // TODO: Handle other Messages - _ => Ok(()), - } - } - - /*// TODO Apply pad_burst logic and IAT policy to packet assembly (probably as part of AsyncRead / AsyncWrite impl) - /// Attempts to pad a burst of data so that the last packet is of the length - /// `to_pad_to`. This can involve creating multiple packets, making this - /// slightly complex. - /// - /// TODO: document logic more clearly - pub(crate) fn pad_burst(&self, buf: &mut BytesMut, to_pad_to: usize) -> Result<()> { - let tail_len = buf.len() % framing::MAX_SEGMENT_LENGTH; - - let pad_len: usize = if to_pad_to >= tail_len { - to_pad_to - tail_len - } else { - (framing::MAX_SEGMENT_LENGTH - tail_len) + to_pad_to - }; - - if pad_len > HEADER_LENGTH { - // pad_len > 19 - Ok(framing::build_and_marshall( - buf, - MessageTypes::Payload.into(), - vec![], - pad_len - HEADER_LENGTH, - )?) - } else if pad_len > 0 { - framing::build_and_marshall( - buf, - MessageTypes::Payload.into(), - vec![], - framing::MAX_MESSAGE_PAYLOAD_LENGTH, - )?; - // } else { - Ok(framing::build_and_marshall( - buf, - MessageTypes::Payload.into(), - vec![], - pad_len, - )?) - } else { - Ok(()) - } - } */ } -impl AsyncWrite for O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncWrite for O4Stream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -226,31 +187,30 @@ where let mut this = self.as_mut().project(); // determine if the stream is ready to send an event? - if futures::Sink::<&[u8]>::poll_ready(this.stream.as_mut(), cx) == Poll::Pending { - return Poll::Pending; + if futures::Sink::::poll_ready(this.sink.as_mut(), cx).is_pending() { + return Poll::Pending } // while we have bytes in the buffer write MAX_MESSAGE_PAYLOAD_LENGTH // chunks until we have less than that amount left. // TODO: asyncwrite - apply length_dist instead of just full payloads let mut len_sent: usize = 0; - let mut out_buf = BytesMut::with_capacity(framing::MAX_MESSAGE_PAYLOAD_LENGTH); while msg_len - len_sent > framing::MAX_MESSAGE_PAYLOAD_LENGTH { + let mut out_buf = BytesMut::with_capacity(framing::MAX_MESSAGE_PAYLOAD_LENGTH); // package one chunk of the mesage as a payload - let payload = framing::Messages::Payload( + let payload = Messages::Payload( buf[len_sent..len_sent + framing::MAX_MESSAGE_PAYLOAD_LENGTH].to_vec(), ); // send the marshalled payload payload.marshall(&mut out_buf)?; - this.stream.as_mut().start_send(&mut out_buf)?; + this.sink.as_mut().start_send(out_buf)?; len_sent += framing::MAX_MESSAGE_PAYLOAD_LENGTH; - out_buf.clear(); // determine if the stream is ready to send more data. if not back off - if futures::Sink::<&[u8]>::poll_ready(this.stream.as_mut(), cx) == Poll::Pending { - return Poll::Ready(Ok(len_sent)); + if futures::Sink::::poll_ready(this.sink.as_mut(), cx).is_pending() { + return Poll::Ready(Ok(len_sent)) } } @@ -258,7 +218,7 @@ where let mut out_buf = BytesMut::new(); payload.marshall(&mut out_buf)?; - this.stream.as_mut().start_send(out_buf)?; + this.sink.as_mut().start_send(out_buf)?; Poll::Ready(Ok(msg_len)) } @@ -266,7 +226,7 @@ where fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("{} flushing", self.session.id()); let mut this = self.project(); - match futures::Sink::<&[u8]>::poll_flush(this.stream.as_mut(), cx) { + match futures::Sink::::poll_flush(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Pending => Poll::Pending, @@ -276,7 +236,7 @@ where fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("{} shutting down", self.session.id()); let mut this = self.project(); - match futures::Sink::<&[u8]>::poll_close(this.stream.as_mut(), cx) { + match futures::Sink::::poll_close(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Pending => Poll::Pending, @@ -284,10 +244,7 @@ where } } -impl AsyncRead for O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncRead for O4Stream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -297,9 +254,9 @@ where // the network. Not all data received is guaranteed to be usable payload, // so do this in a loop until we would block on a read or an error occurs. loop { + let mut this = self.as_mut().project(); let msg = { // mutable borrow of self is dropped at the end of this block - let mut this = self.as_mut().project(); match this.stream.as_mut().poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(res) => { @@ -318,7 +275,7 @@ where } }; - if let framing::Messages::Payload(message) = msg { + if let Messages::Payload(message) = msg { buf.put_slice(&message); return Poll::Ready(Ok(())); } @@ -326,18 +283,18 @@ where continue; } - match self.as_mut().try_handle_non_payload_message(msg) { - Ok(_) => continue, - Err(e) => return Poll::Ready(Err(e.into())), + match msg { + Messages::Payload(_) => return Poll::Ready(Err(FrameError::InvalidMessage.into())), + Messages::Padding(_) => return Poll::Ready(Ok(())), + + // TODO: Handle other Messages + _ => return Poll::Ready(Ok(())), } } } } -impl AsyncWrite for Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncWrite for Obfs4Stream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -349,7 +306,11 @@ where fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - this.s.poll_flush(cx) + match Sink::poll_flush(this.s, cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -358,10 +319,7 @@ where } } -impl AsyncRead for Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncRead for Obfs4Stream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -371,3 +329,206 @@ where this.s.poll_read(cx, buf) } } + +// impl AsRef> for O4Stream { +// fn as_ref(&self) -> &dyn Sink { +// self.sink.as_ref() +// } +// } + +impl Sink for O4Stream { + type Error = Error; + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_ready_unpin(cx) + } + + fn start_send(self: Pin<&mut Self>, item: BytesMut) -> StdResult<(), Self::Error> { + self.project().sink.start_send_unpin(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_flush_unpin(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_close_unpin(cx) + } +} + +impl Stream for O4Stream { + type Item = Result; + + // Required method + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_next_unpin(cx) + } +} +// +// ======================================================================== + +// TODO Apply pad_burst logic and IAT policy to Message assembly (probably as part of AsyncRead / AsyncWrite impl) +/// Attempts to pad a burst of data so that the last [`Message`] is of the length +/// `to_pad_to`. This can involve creating multiple packets, making this +/// slightly complex. +/// +/// TODO: document logic more clearly +pub(crate) fn pad_burst(buf: &mut BytesMut, to_pad_to: usize) -> Result<()> { + let tail_len = buf.len() % framing::MAX_SEGMENT_LENGTH; + + let pad_len: usize = if to_pad_to >= tail_len { + to_pad_to - tail_len + } else { + (framing::MAX_SEGMENT_LENGTH - tail_len) + to_pad_to + }; + + if pad_len > HEADER_LENGTH { + // pad_len > 19 + Ok(framing::build_and_marshall( + buf, + framing::MessageTypes::Payload.into(), + vec![], + pad_len - HEADER_LENGTH, + )?) + } else if pad_len > 0 { + framing::build_and_marshall( + buf, + framing::MessageTypes::Payload.into(), + vec![], + framing::MAX_MESSAGE_PAYLOAD_LENGTH, + )?; + // } else { + Ok(framing::build_and_marshall( + buf, + framing::MessageTypes::Payload.into(), + vec![], + pad_len, + )?) + } else { + Ok(()) + } +} + +// ======================================================================== + +/* +/// +/// Off: +/// pad burst = send max-frame-length frames while available, pad the last with +/// send with no delay +/// [ msg ] +/// [ max-pkt ][ max-pkt ][ max-pkt ][ max-pkt ][pkt]{pad} +/// Enabled: +/// pad burst = send max-frame-length frames while available, pad the last with +/// send with sampled delay +/// [ msg ] +/// [ max-pkt ]... [ max-pkt ]. [ max-pkt ].. [ max-pkt ].... [pkt]{pad} +/// Paranoid: +/// ?? +/// send with sampled delay +/// [ msg ] +/// [ max-pkt ]... [ max-pkt ]. [ max-pkt ].. [ max-pkt ].... [pkt]{pad} +fn split_and_pad(iat: IAT) { + // Send maximum sized frames. while they are available + let payload_chunks = b.chunks(MAX_MESSAGE_PAYLOAD_LENGTH); + + match iat { + IAT::Off => {} + IAT::Enabled => {} + IAT::Paranoid => {} + } +} + + + if conn.iatMode != iatParanoid { + // For non-paranoid IAT, pad once per burst. Paranoid IAT handles + // things differently. + if err = conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { + return 0, err + } + } + + // Write the pending data onto the network. Partial writes are fatal, + // because the frame encoder state is advanced, and the code doesn't keep + // frameBuf around. In theory, write timeouts and whatnot could be + // supported if this wasn't the case, but that complicates the code. + if conn.iatMode != iatNone { + var iatFrame [framing.MaximumSegmentLength]byte + for frameBuf.Len() > 0 { + iatWrLen := 0 + + switch conn.iatMode { + case iatEnabled: + // Standard (ScrambleSuit-style) IAT obfuscation optimizes for + // bulk transport and will write ~MTU sized frames when + // possible. + iatWrLen, err = frameBuf.Read(iatFrame[:]) + + case iatParanoid: + // Paranoid IAT obfuscation throws performance out of the + // window and will sample the length distribution every time a + // write is scheduled. + targetLen := conn.lenDist.Sample() + if frameBuf.Len() < targetLen { + // There's not enough data buffered for the target write, + // so padding must be inserted. + if err = conn.padBurst(&frameBuf, targetLen); err != nil { + return 0, err + } + if frameBuf.Len() != targetLen { + // Ugh, padding came out to a value that required more + // than one frame, this is relatively unlikely so just + // resample since there's enough data to ensure that + // the next sample will be written. + continue + } + } + iatWrLen, err = frameBuf.Read(iatFrame[:targetLen]) + } + if err != nil { + return 0, err + } else if iatWrLen == 0 { + panic(fmt.Sprintf("BUG: Write(), iat length was 0")) + } + + // Calculate the delay. The delay resolution is 100 usec, leading + // to a maximum delay of 10 msec. + iatDelta := time.Duration(conn.iatDist.Sample() * 100) + + // Write then sleep. + _, err = conn.Conn.Write(iatFrame[:iatWrLen]) + if err != nil { + return 0, err + } + time.Sleep(iatDelta * time.Microsecond) + } + } else { + _, err = conn.Conn.Write(frameBuf.Bytes()) + } + + return +} + +/* + chopBuf := bytes.NewBuffer(b) + var payload [maxPacketPayloadLength]byte + var frameBuf bytes.Buffer + + + // Chop the pending data into payload frames. + for chopBuf.Len() > 0 { + rdLen := 0 + rdLen, err = chopBuf.Read(payload[:]) + if err != nil { + return 0, err + } else if rdLen == 0 { + panic(fmt.Sprintf("BUG: Write(), chopping length was 0")) + } + n += rdLen + + err = conn.makePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0) + if err != nil { + return 0, err + } + } +*/ +*/ diff --git a/crates/obfs4/src/pt.rs b/crates/obfs4/src/pt.rs index 7645213..7b2d181 100644 --- a/crates/obfs4/src/pt.rs +++ b/crates/obfs4/src/pt.rs @@ -2,7 +2,7 @@ use crate::{ constants::*, handshake::Obfs4NtorPublicKey, proto::{Obfs4Stream, IAT}, - Error, OBFS4_NAME, + Client, ClientBuilder, Error, Server, ServerBuilder, OBFS4_NAME, }; use ptrs::{args::Args, FutureResult as F}; @@ -14,6 +14,7 @@ use std::{ time::Duration, }; +use futures::TryFutureExt; use hex::FromHex; use ptrs::trace; use tokio::{ @@ -35,32 +36,32 @@ impl ptrs::PluggableTransport for Transport where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type ClientBuilder = crate::ClientBuilder; - type ServerBuilder = crate::ServerBuilder; + type ClientBuilder = ClientBuilder; + type ServerBuilder = ServerBuilder; fn name() -> String { OBFS4_NAME.into() } fn client_builder() -> >::ClientBuilder { - crate::ClientBuilder::default() + ClientBuilder::default() } fn server_builder() -> >::ServerBuilder { - crate::ServerBuilder::default() + ServerBuilder::default() } } -impl ptrs::ServerBuilder for crate::ServerBuilder +impl ptrs::ServerBuilder for ServerBuilder where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type ServerPT = crate::Server; + type ServerPT = Server; type Error = Error; type Transport = Transport; fn build(&self) -> Self::ServerPT { - crate::ServerBuilder::build(self) + ServerBuilder::build(self) } fn method_name() -> String { @@ -105,11 +106,11 @@ where } } -impl ptrs::ClientBuilder for crate::ClientBuilder +impl ptrs::ClientBuilder for ClientBuilder where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type ClientPT = crate::Client; + type ClientPT = Client; type Error = Error; type Transport = Transport; @@ -122,7 +123,7 @@ where /// **Errors** /// If a required field has not been initialized. fn build(&self) -> Self::ClientPT { - crate::ClientBuilder::build(self) + ClientBuilder::build(self) } /// Pluggable transport attempts to parse and validate options from a string, @@ -208,21 +209,21 @@ where /// Example wrapping transport that just passes the incoming connection future through /// unmodified as a proof of concept. -impl ptrs::ClientTransport for crate::Client +impl ptrs::ClientTransport for Client where InRW: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, InErr: std::error::Error + Send + Sync + 'static, { - type OutRW = Obfs4Stream; + type OutRW = Obfs4Stream; type OutErr = Error; - type Builder = crate::ClientBuilder; + type Builder = ClientBuilder; fn establish(self, input: Pin>) -> Pin> { - Box::pin(crate::Client::establish(self, input)) + Box::pin(Client::establish(self, input)) } fn wrap(self, io: InRW) -> Pin> { - Box::pin(crate::Client::wrap(self, io)) + Box::pin(Client::wrap(self, io).map_err(|e| e)) } fn method_name() -> String { @@ -230,17 +231,21 @@ where } } -impl ptrs::ServerTransport for crate::Server +impl ptrs::ServerTransport for Server where InRW: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type OutRW = Obfs4Stream; + // Our out read/write is an obfs4strea, that has an error of IoError. + // If an error occurs while revealing it will be returned as an [`Error`]. + // As input we require a Serverbuilder with an InRW type that implements + // the async read and write traits. + type OutRW = Obfs4Stream; type OutErr = Error; - type Builder = crate::ServerBuilder; + type Builder = ServerBuilder; /// Use something that can be accessed reference (Arc, Rc, etc.) fn reveal(self, io: InRW) -> Pin> { - Box::pin(crate::Server::wrap(self, io)) + Box::pin(Server::wrap(self, io)) } fn method_name() -> String { @@ -257,18 +262,16 @@ mod test { let pt_name = >::name(); assert_eq!(pt_name, Obfs4PT::NAME); - let cb_name = >::method_name(); + let cb_name = >::method_name(); assert_eq!(cb_name, Obfs4PT::NAME); - let sb_name = - as ptrs::ServerBuilder>::method_name(); + let sb_name = as ptrs::ServerBuilder>::method_name(); assert_eq!(sb_name, Obfs4PT::NAME); - let ct_name = - >::method_name(); + let ct_name = >::method_name(); assert_eq!(ct_name, Obfs4PT::NAME); - let st_name = >::method_name(); + let st_name = >::method_name(); assert_eq!(st_name, Obfs4PT::NAME); } } diff --git a/crates/obfs4/src/server.rs b/crates/obfs4/src/server.rs index d7dfd42..8762c90 100644 --- a/crates/obfs4/src/server.rs +++ b/crates/obfs4/src/server.rs @@ -4,26 +4,28 @@ use super::*; use crate::{ client::ClientBuilder, common::{ - colorize, drbg, + colorize, discard, drbg, + ntor_arti::{RelayHandshakeError, ServerHandshake}, replay_filter::{self, ReplayFilter}, x25519_elligator2::{PublicKey, StaticSecret}, HmacSha256, }, constants::*, framing::{FrameError, Marshall, Obfs4Codec, TryParse, KEY_LENGTH}, - handshake::{Obfs4NtorPublicKey, Obfs4NtorSecretKey}, - proto::{MaybeTimeout, Obfs4Stream, IAT}, - sessions::Session, + handshake::{Obfs4Keygen, Obfs4NtorPublicKey, Obfs4NtorSecretKey, SHSMaterials}, + proto::{MaybeTimeout, O4Stream, Obfs4Stream, IAT}, + sessions::{Established, Fault, Initialized, Session}, Error, Result, }; use ptrs::args::Args; +use ptrs::{debug, info, trace}; +use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::{borrow::BorrowMut, marker::PhantomData, ops::Deref, str::FromStr, sync::Arc}; use bytes::{Buf, BufMut, Bytes}; use hex::FromHex; use hmac::{Hmac, Mac}; -use ptrs::{debug, info}; use rand::prelude::*; use subtle::ConstantTimeEq; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -303,14 +305,62 @@ impl Server { Self::new_from_key(identity_keys) } - pub async fn wrap(self, stream: T) -> Result> + // ====================================================================== // + // Server Handshake // + // ====================================================================== // + + pub async fn wrap<'a, T>(self, mut stream: T) -> Result where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let session = self.new_server_session()?; let deadline = self.handshake_timeout.map(|d| Instant::now() + d); - session.handshake(&self, stream, deadline).await + let hs_materials = SHSMaterials::new( + &session.identity_keys, + session.session_id(), + session.len_seed.to_bytes(), + ); + + let mut session = session.transition(ServerHandshaking {}); + + let d_def = Instant::now() + SERVER_HANDSHAKE_TIMEOUT; + let handshake_fut = self.complete_handshake(&mut stream, hs_materials, deadline); + + let mut keygen = + match tokio::time::timeout_at(deadline.unwrap_or(d_def), handshake_fut).await { + Ok(result) => match result { + Ok(handshake) => handshake, + Err(e) => { + // non-timeout error, + let id = session.session_id(); + let _ = session.fault(ServerHandshakeFailed { + details: format!("{id} handshake failed {e}"), + }); + return Err(e); + } + }, + Err(_) => { + let id = session.session_id(); + let _ = session.fault(ServerHandshakeFailed { + details: format!("{id} timed out"), + }); + return Err(Error::HandshakeTimeout); + } + }; + + // post handshake state updates + session.set_session_id(keygen.session_id()); + let mut codec: framing::Obfs4Codec = keygen.into(); + + // mark session as Established + let session_state: ServerSession = session.transition(Established {}); + + codec.handshake_complete(); + let o4 = O4Stream::new(stream, codec, Session::Server(session_state)); + + Ok(Obfs4Stream::from_o4(o4)) + // session.handshake(&self, stream, deadline).await } // pub fn set_iat_mode(&mut self, mode: IAT) -> &Self { @@ -340,14 +390,13 @@ impl Server { } } - pub(crate) fn new_server_session( - &self, - ) -> Result> { + pub(crate) fn new_server_session(&self) -> Result> { let mut session_id = [0u8; SESSION_ID_LEN]; rand::thread_rng().fill_bytes(&mut session_id); - Ok(sessions::ServerSession { + Ok(ServerSession { // fixed by server identity_keys: self.identity_keys.clone(), + iat_mode: self.iat_mode, biased: self.biased, // generated per session @@ -355,9 +404,141 @@ impl Server { len_seed: drbg::Seed::new().unwrap(), iat_seed: drbg::Seed::new().unwrap(), - _state: sessions::Initialized {}, + _state: Initialized {}, }) } + + /// Complete the handshake with the client. This function assumes that the + /// client has already sent a message and that we do not know yet if the + /// message is valid. + async fn complete_handshake( + &self, + mut stream: T, + materials: SHSMaterials, + deadline: Option, + ) -> Result + where + T: AsyncRead + AsyncWrite + Unpin, + { + let session_id = materials.session_id.clone(); + + // wait for and attempt to consume the client hello message + let mut buf = [0_u8; MAX_HANDSHAKE_LENGTH]; + loop { + let n = stream.read(&mut buf).await?; + if n == 0 { + stream.shutdown().await?; + return Err(IoError::from(IoErrorKind::UnexpectedEof).into()); + } + trace!("{} successful read {n}B", session_id); + + match self.server(&mut |_: &()| Some(()), &[materials.clone()], &buf[..n]) { + Ok((keygen, response)) => { + stream.write_all(&response).await?; + info!("{} handshake complete", session_id); + return Ok(keygen); + } + Err(RelayHandshakeError::EAgain) => { + trace!("{} reading more", session_id); + continue; + } + Err(e) => { + trace!("{} failed to parse client handshake: {e}", session_id); + // if a deadline was set and has not passed already, discard + // from the stream until the deadline, then close. + if deadline.is_some_and(|d| d > Instant::now()) { + debug!("{} discarding due to: {e}", session_id); + discard(&mut stream, deadline.unwrap() - Instant::now()).await? + } + stream.shutdown().await?; + return Err(e.into()); + } + }; + } + } +} + +// ================================================================ // +// Server Sessions States // +// ================================================================ // + +pub(crate) struct ServerSession { + // fixed by server + pub(crate) identity_keys: Obfs4NtorSecretKey, + pub(crate) iat_mode: IAT, + pub(crate) biased: bool, + // pub(crate) server: &'a Server, + + // generated per session + pub(crate) session_id: [u8; SESSION_ID_LEN], + pub(crate) len_seed: drbg::Seed, + pub(crate) iat_seed: drbg::Seed, + + pub(crate) _state: S, +} + +pub(crate) struct ServerHandshaking {} + +#[allow(unused)] +pub(crate) struct ServerHandshakeFailed { + details: String, +} + +pub(crate) trait ServerSessionState {} +impl ServerSessionState for Initialized {} +impl ServerSessionState for ServerHandshaking {} +impl ServerSessionState for Established {} + +impl ServerSessionState for ServerHandshakeFailed {} +impl Fault for ServerHandshakeFailed {} + +impl ServerSession { + pub fn session_id(&self) -> String { + String::from("s-") + &colorize(self.session_id) + } + + pub(crate) fn set_session_id(&mut self, id: [u8; SESSION_ID_LEN]) { + debug!( + "{} -> {} server updating session id", + colorize(self.session_id), + colorize(id) + ); + self.session_id = id; + } + + /// Helper function to perform state transitions. + pub(crate) fn transition(self, _state: T) -> ServerSession { + ServerSession { + // fixed by server + identity_keys: self.identity_keys, + iat_mode: self.iat_mode, + biased: self.biased, + + // generated per session + session_id: self.session_id, + len_seed: self.len_seed, + iat_seed: self.iat_seed, + + _state, + } + } + + /// Helper function to perform state transition on error. + pub(crate) fn fault(self, f: F) -> ServerSession { + ServerSession { + // fixed by server + identity_keys: self.identity_keys, + iat_mode: self.iat_mode, + biased: self.biased, + + // generated per session + session_id: self.session_id, + len_seed: self.len_seed, + iat_seed: self.iat_seed, + + _state: f, + } + } } #[cfg(test)] diff --git a/crates/obfs4/src/sessions.rs b/crates/obfs4/src/sessions.rs index 1703aa7..09c4c31 100644 --- a/crates/obfs4/src/sessions.rs +++ b/crates/obfs4/src/sessions.rs @@ -5,26 +5,23 @@ use crate::{ common::{ colorize, discard, drbg, - ntor_arti::{ClientHandshake, RelayHandshakeError, ServerHandshake}, + ntor_arti::{ClientHandshake, RelayHandshakeError}, }, constants::*, framing, - handshake::{ - CHSMaterials, Obfs4Keygen, Obfs4NtorHandshake, Obfs4NtorPublicKey, Obfs4NtorSecretKey, - SHSMaterials, - }, + handshake::{CHSMaterials, Obfs4Keygen, Obfs4NtorHandshake, Obfs4NtorPublicKey}, proto::{O4Stream, Obfs4Stream, IAT}, - server::Server, + server::ServerSession, Error, Result, }; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use bytes::BytesMut; -use ptrs::{debug, info, trace}; +use ptrs::{debug, info}; use rand_core::RngCore; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::time::Instant; +use tokio::time::{Duration, Instant}; use tokio_util::codec::Decoder; /// Initial state for a Session, created with any params. @@ -34,7 +31,7 @@ pub(crate) struct Initialized; pub(crate) struct Established; /// The session broke due to something like a timeout, reset, lost connection, etc. -trait Fault {} +pub(crate) trait Fault {} pub(crate) enum Session { Client(ClientSession), @@ -53,7 +50,7 @@ impl Session { pub fn biased(&self) -> bool { match self { Session::Client(cs) => cs.biased, - Session::Server(ss) => ss.biased, //biased, + Session::Server(ss) => ss.biased, } } @@ -63,6 +60,21 @@ impl Session { Session::Server(ss) => ss.len_seed.clone(), } } + + pub(crate) fn get_iat_mode(&self) -> IAT { + match self { + Session::Client(cs) => cs.iat_mode, + Session::Server(ss) => ss.iat_mode, + } + } + + pub(crate) fn iat_duration_sampler(&mut self) -> fn() -> Duration { + || Duration::from_secs(1) + } + + pub(crate) fn sample_iat_length() -> usize { + 0usize + } } // ================================================================ // @@ -167,13 +179,13 @@ impl ClientSession { /// TODO: make sure failure modes align with golang obfs4 /// - FIN/RST based on buffered data. /// - etc. - pub async fn handshake( + pub async fn handshake<'a, T>( self, mut stream: T, deadline: Option, - ) -> Result> + ) -> Result where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { // set up for handshake let mut session = self.transition(ClientHandshaking {}); @@ -305,193 +317,28 @@ impl std::fmt::Debug for ClientSession { } } -// ================================================================ // -// Server Sessions States // -// ================================================================ // - -pub(crate) struct ServerSession { - // fixed by server - pub(crate) identity_keys: Obfs4NtorSecretKey, - pub(crate) biased: bool, - // pub(crate) server: &'a Server, - - // generated per session - pub(crate) session_id: [u8; SESSION_ID_LEN], - pub(crate) len_seed: drbg::Seed, - pub(crate) iat_seed: drbg::Seed, - - pub(crate) _state: S, -} - -pub(crate) struct ServerHandshaking {} - -#[allow(unused)] -pub(crate) struct ServerHandshakeFailed { - details: String, -} - -pub(crate) trait ServerSessionState {} -impl ServerSessionState for Initialized {} -impl ServerSessionState for ServerHandshaking {} -impl ServerSessionState for Established {} - -impl ServerSessionState for ServerHandshakeFailed {} -impl Fault for ServerHandshakeFailed {} - -impl ServerSession { - pub fn session_id(&self) -> String { - String::from("s-") + &colorize(self.session_id) - } - - pub(crate) fn set_session_id(&mut self, id: [u8; SESSION_ID_LEN]) { - debug!( - "{} -> {} server updating session id", - colorize(self.session_id), - colorize(id) - ); - self.session_id = id; - } - - /// Helper function to perform state transitions. - fn transition(self, _state: T) -> ServerSession { - ServerSession { - // fixed by server - identity_keys: self.identity_keys, - biased: self.biased, - - // generated per session - session_id: self.session_id, - len_seed: self.len_seed, - iat_seed: self.iat_seed, - - _state, - } - } - - /// Helper function to perform state transition on error. - fn fault(self, f: F) -> ServerSession { - ServerSession { - // fixed by server - identity_keys: self.identity_keys, - biased: self.biased, - - // generated per session - session_id: self.session_id, - len_seed: self.len_seed, - iat_seed: self.iat_seed, - - _state: f, - } - } -} - -impl ServerSession { - /// Attempt to complete the handshake with a new client connection. - pub async fn handshake( - self, - server: &Server, - mut stream: T, - deadline: Option, - ) -> Result> - where - T: AsyncRead + AsyncWrite + Unpin, - { - // set up for handshake - let mut session = self.transition(ServerHandshaking {}); - - let materials = SHSMaterials::new( - &session.identity_keys, - session.session_id(), - session.len_seed.to_bytes(), - ); - - // default deadline - let d_def = Instant::now() + SERVER_HANDSHAKE_TIMEOUT; - let handshake_fut = server.complete_handshake(&mut stream, materials, deadline); - - let mut keygen = - match tokio::time::timeout_at(deadline.unwrap_or(d_def), handshake_fut).await { - Ok(result) => match result { - Ok(handshake) => handshake, - Err(e) => { - // non-timeout error, - let id = session.session_id(); - let _ = session.fault(ServerHandshakeFailed { - details: format!("{id} handshake failed {e}"), - }); - return Err(e); - } - }, - Err(_) => { - let id = session.session_id(); - let _ = session.fault(ServerHandshakeFailed { - details: format!("{id} timed out"), - }); - return Err(Error::HandshakeTimeout); - } - }; - - // post handshake state updates - session.set_session_id(keygen.session_id()); - let mut codec: framing::Obfs4Codec = keygen.into(); - - // mark session as Established - let session_state: ServerSession = session.transition(Established {}); - - codec.handshake_complete(); - let o4 = O4Stream::new(stream, codec, Session::Server(session_state)); - - Ok(Obfs4Stream::from_o4(o4)) - } -} - -impl Server { - /// Complete the handshake with the client. This function assumes that the - /// client has already sent a message and that we do not know yet if the - /// message is valid. - async fn complete_handshake( - &self, - mut stream: T, - materials: SHSMaterials, - deadline: Option, - ) -> Result - where - T: AsyncRead + AsyncWrite + Unpin, - { - let session_id = materials.session_id.clone(); - - // wait for and attempt to consume the client hello message - let mut buf = [0_u8; MAX_HANDSHAKE_LENGTH]; - loop { - let n = stream.read(&mut buf).await?; - if n == 0 { - stream.shutdown().await?; - return Err(IoError::from(IoErrorKind::UnexpectedEof).into()); - } - trace!("{} successful read {n}B", session_id); - - match self.server(&mut |_: &()| Some(()), &[materials.clone()], &buf[..n]) { - Ok((keygen, response)) => { - stream.write_all(&response).await?; - info!("{} handshake complete", session_id); - return Ok(keygen); - } - Err(RelayHandshakeError::EAgain) => { - trace!("{} reading more", session_id); - continue; - } - Err(e) => { - trace!("{} failed to parse client handshake: {e}", session_id); - // if a deadline was set and has not passed already, discard - // from the stream until the deadline, then close. - if deadline.is_some_and(|d| d > Instant::now()) { - debug!("{} discarding due to: {e}", session_id); - discard(&mut stream, deadline.unwrap() - Instant::now()).await? - } - stream.shutdown().await?; - return Err(e.into()); - } - }; - } - } -} +// impl ServerSession { +// /// Attempt to complete the handshake with a new client connection. +// pub async fn handshake( +// self, +// server: &Server, +// mut stream: T, +// deadline: Option, +// ) -> Result +// where +// T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +// { +// // set up for handshake +// +// let materials = SHSMaterials::new( +// &session.identity_keys, +// session.session_id(), +// session.len_seed.to_bytes(), +// ); +// +// // default deadline +// let d_def = Instant::now() + SERVER_HANDSHAKE_TIMEOUT; +// let handshake_fut = server.complete_handshake(&mut stream, materials, deadline); +// +// } +// } diff --git a/crates/obfs4/src/testing.rs b/crates/obfs4/src/testing.rs index d2997c9..85736ad 100644 --- a/crates/obfs4/src/testing.rs +++ b/crates/obfs4/src/testing.rs @@ -9,19 +9,19 @@ use std::time::Duration; #[tokio::test] async fn public_handshake() -> Result<()> { init_subscriber(); - let (mut c, mut s) = tokio::io::duplex(65_536); + let (c, s) = tokio::io::duplex(65_536); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let _ = tokio::io::split(o4s_stream); }); let o4_client = client_config.build(); - let _o4c_stream = o4_client.wrap(&mut c).await?; + let _o4c_stream = o4_client.wrap(c).await?; Ok(()) } @@ -30,14 +30,14 @@ async fn public_handshake() -> Result<()> { async fn public_iface() -> Result<()> { init_subscriber(); let message = b"awoewaeojawenwaefaw lfawn;awe da;wfenalw fawf aw"; - let (mut c, mut s) = tokio::io::duplex(65_536); + let (c, s) = tokio::io::duplex(65_536); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let mut o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let mut o4s_stream = o4_server.wrap(s).await.unwrap(); // let (mut r, mut w) = tokio::io::split(o4s_stream); // tokio::io::copy(&mut r, &mut w).await.unwrap(); @@ -52,7 +52,7 @@ async fn public_iface() -> Result<()> { }); let o4_client = client_config.build(); - let mut o4c_stream = o4_client.wrap(&mut c).await?; + let mut o4c_stream = o4_client.wrap(c).await?; o4c_stream.write_all(&message[..]).await?; o4c_stream.flush().await?; @@ -76,14 +76,14 @@ async fn public_iface() -> Result<()> { async fn transfer_10k_x1() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -129,13 +129,13 @@ async fn transfer_10k_x1() -> Result<()> { async fn transfer_10k_x3() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let o4_server = Server::getrandom(); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -183,14 +183,14 @@ async fn transfer_10k_x3() -> Result<()> { async fn transfer_1M_1024x1024() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -236,14 +236,14 @@ async fn transfer_1M_1024x1024() -> Result<()> { async fn transfer_512k_x1() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 512); + let (c, s) = tokio::io::duplex(1024 * 512); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -290,14 +290,14 @@ async fn transfer_512k_x1() -> Result<()> { async fn transfer_2_x() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); });