From a6662b8399a6a6b29882ec0168e440e6cae61665 Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Sat, 27 Apr 2024 21:44:49 -0600 Subject: [PATCH 1/9] delayed sink wrapper for iat delays between writes --- crates/obfs4/Cargo.toml | 1 + crates/obfs4/README.md | 30 +++++-- crates/obfs4/src/common/delay/README.md | 49 +++++++++++ crates/obfs4/src/common/delay/mod.rs | 104 ++++++++++++++++++++++++ crates/obfs4/src/common/mod.rs | 3 +- 5 files changed, 181 insertions(+), 6 deletions(-) create mode 100644 crates/obfs4/src/common/delay/README.md create mode 100644 crates/obfs4/src/common/delay/mod.rs diff --git a/crates/obfs4/Cargo.toml b/crates/obfs4/Cargo.toml index 9b2ef33..51c42c5 100644 --- a/crates/obfs4/Cargo.toml +++ b/crates/obfs4/Cargo.toml @@ -79,6 +79,7 @@ simple_asn1 = { version="0.6.1", optional=true} tracing-subscriber = "0.3.18" hex-literal = "0.4.1" tor-basic-utils = "0.18.0" +rand_distr = "0.4.3" # o5 pqc test # pqc_kyber = {version="0.7.1", features=["kyber1024", "std"]} diff --git a/crates/obfs4/README.md b/crates/obfs4/README.md index c4675d6..1c8c389 100644 --- a/crates/obfs4/README.md +++ b/crates/obfs4/README.md @@ -65,22 +65,42 @@ tokio::spawn(async move { Server example using [ptrs](../ptrs) ```rs -use ptrs::{ServerBuilder, ServerTransport}; -... +use ptrs::{ServerBuilder as _, ServerTransport as _}; +use obfs4::Obfs4PT; + +let mut builder = Obfs4PT::server_builder(); +let server = if params.is_some() { + builder.options(¶ms.unwrap())?.build() +} else { + builder.build() +}; + +let listener = tokio::net::TcpListener::bind(listen_addrs).await?; +loop { + let (conn, _) = listener.accept()?; + let pt_conn = server.reveal(conn).await?; + + // pt_conn wraps conn and is usable as an `AsyncRead + AsyncWrite` object. + tokio::spawn( async move{ + // use the connection (e.g. to echo) + let (mut r, mut w) = tokio::io::split(pt_conn); + if let Err(e) = tokio::io::copy(&mut r, &mut w).await { + warn!("echo closed with error: {e}") + } + }); +} -// TODO fill out example ``` ### Loose Ends: - [X] server / client compatibility test go-to-rust and rust-to-go. +- [x] double check the bit randomization and clearing for high two bits in the `dalek` representative - [ ] length distribution things - [ ] iat mode handling -- [ ] double check the bit randomization and clearing for high two bits in the `dalek` representative ## Performance - comparison to golang -- comparison when kyber is enabled - NaCl encryption library(s) diff --git a/crates/obfs4/src/common/delay/README.md b/crates/obfs4/src/common/delay/README.md new file mode 100644 index 0000000..e697c17 --- /dev/null +++ b/crates/obfs4/src/common/delay/README.md @@ -0,0 +1,49 @@ +# Sink Delays + +Adding Structured Delays to rust sinks on event. + + +Example test using a sampled normal distribution for the delay after each +send (`start_send()` if not using `SinkExt`). + +```rs +#[cfg(test)] +mod testing { + use super::*; + use futures::sink::{self, SinkExt}; + use std::time::Instant; + use rand_distr::{Normal, Distribution}; + + #[tokio::test] + async fn delay_sink() { + let start = Instant::now(); + + let unfold = sink::unfold(0, |mut sum, i: i32| async move { + sum += i; + eprintln!("{} - {:?}", i, Instant::now().duration_since(start)); + Ok::<_, futures::never::Never>(sum) + }); + futures::pin_mut!(unfold); + + // let mut delayed_unfold = DelayedSink::new(unfold, || Duration::from_secs(1)); + let mut delayed_unfold = DelayedSink::new(unfold, delay_distribution); + delayed_unfold.send(5).await.unwrap(); + delayed_unfold.send(4).await.unwrap(); + delayed_unfold.send(3).await.unwrap(); + } + + fn delay_distribution() -> Duration { + let distr = Normal::new(500.0, 100.0).unwrap(); + let dur_ms = distr.sample(&mut rand::thread_rng()); + Duration::from_millis(dur_ms as u64) + } +} +``` + +--- + +But I wanna go fast! Why would I ever want this??? + +-> This lets us control the delays (or leave them out) in between sink events. +As an example, we can control the delay between network writes, which helps when +reshaping the traffic fingerprint of a proxy connection. diff --git a/crates/obfs4/src/common/delay/mod.rs b/crates/obfs4/src/common/delay/mod.rs new file mode 100644 index 0000000..e4f6367 --- /dev/null +++ b/crates/obfs4/src/common/delay/mod.rs @@ -0,0 +1,104 @@ +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use futures::{sink::Sink, Future}; +use tokio::time::{Instant, Sleep}; + +use pin_project::pin_project; + +type DurationFn = fn() -> Duration; + +#[pin_project] +pub struct DelayedSink { + // #[pin] + // sink: Si, + // #[pin] + // sleep: Sleep, + sink: Pin>, + sleep: Pin>, + delay_fn: DurationFn, + _item: PhantomData, +} + +impl> DelayedSink { + pub fn new(sink: Si, delay_fn: DurationFn) -> Self { + let delay = delay_fn(); + let sleep = tokio::time::sleep(delay); + Self { + // sink, + // sleep, + sink: Box::pin(sink), + sleep: Box::pin(sleep), + delay_fn, + _item: PhantomData {}, + } + } +} + +impl> Sink for DelayedSink { + type Error = Si::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let s = self.project(); + match (s.sink.as_mut().poll_ready(cx), s.sleep.as_mut().poll(cx)) { + (Poll::Ready(k), Poll::Ready(_)) => Poll::Ready(k), + _ => Poll::Pending, + } + } + + fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + let s = self.project(); + if let Err(e) = s.sink.as_mut().start_send(item) { + return Err(e); + } + + let delay = (*s.delay_fn)(); + + s.sleep + .as_mut() + .reset(Instant::now() + delay); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.as_mut().poll_close(cx) + } +} + +#[cfg(test)] +mod testing { + use super::*; + use futures::sink::{self, SinkExt}; + use std::time::Instant; + use rand_distr::{Normal, Distribution}; + + #[tokio::test] + async fn delay_sink() { + let start = Instant::now(); + + let unfold = sink::unfold(0, |mut sum, i: i32| async move { + sum += i; + eprintln!("{} - {:?}", i, Instant::now().duration_since(start)); + Ok::<_, futures::never::Never>(sum) + }); + futures::pin_mut!(unfold); + + // let mut delayed_unfold = DelayedSink::new(unfold, || Duration::from_secs(1)); + let mut delayed_unfold = DelayedSink::new(unfold, delay_distribution); + delayed_unfold.send(5).await.unwrap(); + delayed_unfold.send(4).await.unwrap(); + delayed_unfold.send(3).await.unwrap(); + } + + fn delay_distribution() -> Duration { + let distr = Normal::new(500.0, 100.0).unwrap(); + let dur_ms = distr.sample(&mut rand::thread_rng()); + Duration::from_millis(dur_ms as u64) + } +} diff --git a/crates/obfs4/src/common/mod.rs b/crates/obfs4/src/common/mod.rs index 3ea0ed6..eae9499 100644 --- a/crates/obfs4/src/common/mod.rs +++ b/crates/obfs4/src/common/mod.rs @@ -12,7 +12,8 @@ mod skip; pub use skip::discard; pub mod drbg; -// pub mod ntor; +pub(crate) mod delay; + pub mod ntor_arti; pub mod probdist; pub mod replay_filter; From 9361776335328d9eb0d3128c9b1ea64c7613051f Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Sun, 28 Apr 2024 11:26:44 -0600 Subject: [PATCH 2/9] making a mess to clean up for IAT durations and lengths - interim commit --- crates/obfs4/src/client.rs | 4 +- crates/obfs4/src/common/drbg.rs | 1 + crates/obfs4/src/proto.rs | 297 +++++++++++++++++++++++--------- crates/obfs4/src/pt.rs | 4 +- crates/obfs4/src/server.rs | 3 +- crates/obfs4/src/sessions.rs | 29 +++- 6 files changed, 245 insertions(+), 93 deletions(-) diff --git a/crates/obfs4/src/client.rs b/crates/obfs4/src/client.rs index c61c5a4..6be9f0a 100644 --- a/crates/obfs4/src/client.rs +++ b/crates/obfs4/src/client.rs @@ -140,7 +140,7 @@ impl Client { /// On a failed handshake the client will read for the remainder of the /// handshake timeout and then close the connection. - pub async fn wrap<'a, T>(self, mut stream: T) -> Result> + pub async fn wrap<'a, T>(self, mut stream: T) -> Result where T: AsyncRead + AsyncWrite + Unpin + 'a, { @@ -156,7 +156,7 @@ impl Client { pub async fn establish<'a, T, E>( self, mut stream_fut: Pin>, - ) -> Result> + ) -> Result where T: AsyncRead + AsyncWrite + Unpin + 'a, E: std::error::Error + Send + Sync + 'static, diff --git a/crates/obfs4/src/common/drbg.rs b/crates/obfs4/src/common/drbg.rs index 92c98fa..d38fe44 100644 --- a/crates/obfs4/src/common/drbg.rs +++ b/crates/obfs4/src/common/drbg.rs @@ -207,6 +207,7 @@ impl RngCore for Drbg { } } + #[cfg(test)] mod test { use super::*; diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index 822a2f8..4ef1ea4 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -1,16 +1,15 @@ use crate::{ common::{ - drbg, - probdist::{self, WeightedDist}, + delay, drbg, probdist::{self, WeightedDist} }, constants::*, - framing, + framing::{self, Message}, sessions::Session, Error, Result, }; use bytes::{Buf, BytesMut}; -use futures::{Sink, Stream}; +use futures::{sink::Sink, stream::Stream}; use pin_project::pin_project; use ptrs::trace; use sha2::{Digest, Sha256}; @@ -85,20 +84,14 @@ impl MaybeTimeout { } #[pin_project] -pub struct Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +pub struct Obfs4Stream { // s: Arc>>, #[pin] - s: O4Stream, + s: O4Stream, } -impl Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn from_o4(o4: O4Stream) -> Self { +impl Obfs4Stream { + pub(crate) fn from_o4(o4: O4Stream) -> Self { Obfs4Stream { // s: Arc::new(Mutex::new(o4)), s: o4, @@ -107,12 +100,10 @@ where } #[pin_project] -pub(crate) struct O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +pub(crate) struct O4Stream { #[pin] - pub stream: Framed, + // pub stream: Framed, + pub stream: Box>, pub length_dist: probdist::WeightedDist, pub iat_dist: probdist::WeightedDist, @@ -120,17 +111,23 @@ where pub session: Session, } -impl O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ - pub(crate) fn new( +impl O4Stream { + pub(crate) fn new( // inner: &'a mut dyn Stream<'a>, inner: T, codec: framing::Obfs4Codec, session: Session, - ) -> O4Stream { - let stream = Framed::new(inner, codec); + ) -> O4Stream + where + T: AsyncRead + AsyncWrite + Unpin, + { + let stream: Box> = match session.get_iat_mode() { + IAT::Off => Box::new(Framed::new(inner, codec)), + IAT::Enabled | IAT::Paranoid => { + let f = Framed::new(inner, codec); + Box::new(delay::DelayedSink::new(f, session.iat_duration_sampler())) + } + }; let len_seed = session.len_seed(); let mut hasher = Sha256::new(); @@ -169,54 +166,10 @@ where _ => Ok(()), } } - - /*// TODO Apply pad_burst logic and IAT policy to packet assembly (probably as part of AsyncRead / AsyncWrite impl) - /// Attempts to pad a burst of data so that the last packet is of the length - /// `to_pad_to`. This can involve creating multiple packets, making this - /// slightly complex. - /// - /// TODO: document logic more clearly - pub(crate) fn pad_burst(&self, buf: &mut BytesMut, to_pad_to: usize) -> Result<()> { - let tail_len = buf.len() % framing::MAX_SEGMENT_LENGTH; - - let pad_len: usize = if to_pad_to >= tail_len { - to_pad_to - tail_len - } else { - (framing::MAX_SEGMENT_LENGTH - tail_len) + to_pad_to - }; - - if pad_len > HEADER_LENGTH { - // pad_len > 19 - Ok(framing::build_and_marshall( - buf, - MessageTypes::Payload.into(), - vec![], - pad_len - HEADER_LENGTH, - )?) - } else if pad_len > 0 { - framing::build_and_marshall( - buf, - MessageTypes::Payload.into(), - vec![], - framing::MAX_MESSAGE_PAYLOAD_LENGTH, - )?; - // } else { - Ok(framing::build_and_marshall( - buf, - MessageTypes::Payload.into(), - vec![], - pad_len, - )?) - } else { - Ok(()) - } - } */ } -impl AsyncWrite for O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ + +impl AsyncWrite for O4Stream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -284,10 +237,7 @@ where } } -impl AsyncRead for O4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncRead for O4Stream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -334,10 +284,7 @@ where } } -impl AsyncWrite for Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncWrite for Obfs4Stream { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -358,10 +305,7 @@ where } } -impl AsyncRead for Obfs4Stream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncRead for Obfs4Stream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -371,3 +315,188 @@ where this.s.poll_read(cx, buf) } } + +impl Sink for O4Stream { + type Error = (); + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!(); + } + + fn start_send(self: Pin<&mut Self>, item: Messages) -> StdResult<(), Self::Error> { + todo!(); + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!(); + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!(); + } +} + + +// TODO Apply pad_burst logic and IAT policy to Message assembly (probably as part of AsyncRead / AsyncWrite impl) +/// Attempts to pad a burst of data so that the last [`Message`] is of the length +/// `to_pad_to`. This can involve creating multiple packets, making this +/// slightly complex. +/// +/// TODO: document logic more clearly +pub(crate) fn pad_burst(buf: &mut BytesMut, to_pad_to: usize) -> Result<()> { + let tail_len = buf.len() % framing::MAX_SEGMENT_LENGTH; + + let pad_len: usize = if to_pad_to >= tail_len { + to_pad_to - tail_len + } else { + (framing::MAX_SEGMENT_LENGTH - tail_len) + to_pad_to + }; + + if pad_len > HEADER_LENGTH { + // pad_len > 19 + Ok(framing::build_and_marshall( + buf, + framing::MessageTypes::Payload.into(), + vec![], + pad_len - HEADER_LENGTH, + )?) + } else if pad_len > 0 { + framing::build_and_marshall( + buf, + framing::MessageTypes::Payload.into(), + vec![], + framing::MAX_MESSAGE_PAYLOAD_LENGTH, + )?; + // } else { + Ok(framing::build_and_marshall( + buf, + framing::MessageTypes::Payload.into(), + vec![], + pad_len, + )?) + } else { + Ok(()) + } +} + +/* +/// +/// Off: +/// pad burst = send max-frame-length frames while available, pad the last with +/// send with no delay +/// [ msg ] +/// [ max-pkt ][ max-pkt ][ max-pkt ][ max-pkt ][pkt]{pad} +/// Enabled: +/// pad burst = send max-frame-length frames while available, pad the last with +/// send with sampled delay +/// [ msg ] +/// [ max-pkt ]... [ max-pkt ]. [ max-pkt ].. [ max-pkt ].... [pkt]{pad} +/// Paranoid: +/// ?? +/// send with sampled delay +/// [ msg ] +/// [ max-pkt ]... [ max-pkt ]. [ max-pkt ].. [ max-pkt ].... [pkt]{pad} +fn split_and_pad(iat: IAT) { + // Send maximum sized frames. while they are available + let payload_chunks = b.chunks(MAX_MESSAGE_PAYLOAD_LENGTH); + + match iat { + IAT::Off => {} + IAT::Enabled => {} + IAT::Paranoid => {} + } +} + + + if conn.iatMode != iatParanoid { + // For non-paranoid IAT, pad once per burst. Paranoid IAT handles + // things differently. + if err = conn.padBurst(&frameBuf, conn.lenDist.Sample()); err != nil { + return 0, err + } + } + + // Write the pending data onto the network. Partial writes are fatal, + // because the frame encoder state is advanced, and the code doesn't keep + // frameBuf around. In theory, write timeouts and whatnot could be + // supported if this wasn't the case, but that complicates the code. + if conn.iatMode != iatNone { + var iatFrame [framing.MaximumSegmentLength]byte + for frameBuf.Len() > 0 { + iatWrLen := 0 + + switch conn.iatMode { + case iatEnabled: + // Standard (ScrambleSuit-style) IAT obfuscation optimizes for + // bulk transport and will write ~MTU sized frames when + // possible. + iatWrLen, err = frameBuf.Read(iatFrame[:]) + + case iatParanoid: + // Paranoid IAT obfuscation throws performance out of the + // window and will sample the length distribution every time a + // write is scheduled. + targetLen := conn.lenDist.Sample() + if frameBuf.Len() < targetLen { + // There's not enough data buffered for the target write, + // so padding must be inserted. + if err = conn.padBurst(&frameBuf, targetLen); err != nil { + return 0, err + } + if frameBuf.Len() != targetLen { + // Ugh, padding came out to a value that required more + // than one frame, this is relatively unlikely so just + // resample since there's enough data to ensure that + // the next sample will be written. + continue + } + } + iatWrLen, err = frameBuf.Read(iatFrame[:targetLen]) + } + if err != nil { + return 0, err + } else if iatWrLen == 0 { + panic(fmt.Sprintf("BUG: Write(), iat length was 0")) + } + + // Calculate the delay. The delay resolution is 100 usec, leading + // to a maximum delay of 10 msec. + iatDelta := time.Duration(conn.iatDist.Sample() * 100) + + // Write then sleep. + _, err = conn.Conn.Write(iatFrame[:iatWrLen]) + if err != nil { + return 0, err + } + time.Sleep(iatDelta * time.Microsecond) + } + } else { + _, err = conn.Conn.Write(frameBuf.Bytes()) + } + + return +} + +/* + chopBuf := bytes.NewBuffer(b) + var payload [maxPacketPayloadLength]byte + var frameBuf bytes.Buffer + + + // Chop the pending data into payload frames. + for chopBuf.Len() > 0 { + rdLen := 0 + rdLen, err = chopBuf.Read(payload[:]) + if err != nil { + return 0, err + } else if rdLen == 0 { + panic(fmt.Sprintf("BUG: Write(), chopping length was 0")) + } + n += rdLen + + err = conn.makePacket(&frameBuf, packetTypePayload, payload[:rdLen], 0) + if err != nil { + return 0, err + } + } +*/ +*/ diff --git a/crates/obfs4/src/pt.rs b/crates/obfs4/src/pt.rs index 7645213..b1be3c3 100644 --- a/crates/obfs4/src/pt.rs +++ b/crates/obfs4/src/pt.rs @@ -213,7 +213,7 @@ where InRW: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, InErr: std::error::Error + Send + Sync + 'static, { - type OutRW = Obfs4Stream; + type OutRW = Obfs4Stream; type OutErr = Error; type Builder = crate::ClientBuilder; @@ -234,7 +234,7 @@ impl ptrs::ServerTransport for crate::Server where InRW: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type OutRW = Obfs4Stream; + type OutRW = Obfs4Stream; type OutErr = Error; type Builder = crate::ServerBuilder; diff --git a/crates/obfs4/src/server.rs b/crates/obfs4/src/server.rs index 0c2f3c5..070e5eb 100644 --- a/crates/obfs4/src/server.rs +++ b/crates/obfs4/src/server.rs @@ -304,7 +304,7 @@ impl Server { Self::new_from_key(identity_keys) } - pub async fn wrap(self, stream: T) -> Result> + pub async fn wrap(self, stream: T) -> Result where T: AsyncRead + AsyncWrite + Unpin, { @@ -349,6 +349,7 @@ impl Server { Ok(sessions::ServerSession { // fixed by server identity_keys: self.identity_keys.clone(), + iat_mode: self.iat_mode, biased: self.biased, // generated per session diff --git a/crates/obfs4/src/sessions.rs b/crates/obfs4/src/sessions.rs index 1703aa7..4a1cdcc 100644 --- a/crates/obfs4/src/sessions.rs +++ b/crates/obfs4/src/sessions.rs @@ -24,7 +24,7 @@ use bytes::BytesMut; use ptrs::{debug, info, trace}; use rand_core::RngCore; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::time::Instant; +use tokio::time::{Instant, Duration}; use tokio_util::codec::Decoder; /// Initial state for a Session, created with any params. @@ -53,7 +53,7 @@ impl Session { pub fn biased(&self) -> bool { match self { Session::Client(cs) => cs.biased, - Session::Server(ss) => ss.biased, //biased, + Session::Server(ss) => ss.biased, } } @@ -63,6 +63,24 @@ impl Session { Session::Server(ss) => ss.len_seed.clone(), } } + + pub(crate) fn get_iat_mode(&self)-> IAT { + match self { + Session::Client(cs) => cs.iat_mode, + Session::Server(ss) => ss.iat_mode, + } + } + + pub(crate) fn iat_duration_sampler(&mut self) -> fn() ->Duration { + || { + Duration::from_secs(1) + } + } + + + pub(crate) fn sample_iat_length() -> usize { + 0usize + } } // ================================================================ // @@ -171,7 +189,7 @@ impl ClientSession { self, mut stream: T, deadline: Option, - ) -> Result> + ) -> Result where T: AsyncRead + AsyncWrite + Unpin, { @@ -312,6 +330,7 @@ impl std::fmt::Debug for ClientSession { pub(crate) struct ServerSession { // fixed by server pub(crate) identity_keys: Obfs4NtorSecretKey, + pub(crate) iat_mode: IAT, pub(crate) biased: bool, // pub(crate) server: &'a Server, @@ -357,6 +376,7 @@ impl ServerSession { ServerSession { // fixed by server identity_keys: self.identity_keys, + iat_mode: self.iat_mode, biased: self.biased, // generated per session @@ -373,6 +393,7 @@ impl ServerSession { ServerSession { // fixed by server identity_keys: self.identity_keys, + iat_mode: self.iat_mode, biased: self.biased, // generated per session @@ -392,7 +413,7 @@ impl ServerSession { server: &Server, mut stream: T, deadline: Option, - ) -> Result> + ) -> Result where T: AsyncRead + AsyncWrite + Unpin, { From d6761bdfd47e8d96ee93712524a6cefc175ace67 Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Sat, 4 May 2024 15:40:50 -0600 Subject: [PATCH 3/9] shifting further towards sink/stream from asyncread/write is hard --- crates/obfs4/src/proto.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index 4ef1ea4..df9f0be 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -3,12 +3,12 @@ use crate::{ delay, drbg, probdist::{self, WeightedDist} }, constants::*, - framing::{self, Message}, + framing, sessions::Session, Error, Result, }; -use bytes::{Buf, BytesMut}; +use bytes::{Buf, BytesMut, Bytes}; use futures::{sink::Sink, stream::Stream}; use pin_project::pin_project; use ptrs::trace; @@ -103,7 +103,7 @@ impl Obfs4Stream { pub(crate) struct O4Stream { #[pin] // pub stream: Framed, - pub stream: Box>, + pub stream: Box + Send + Unpin>, pub length_dist: probdist::WeightedDist, pub iat_dist: probdist::WeightedDist, @@ -121,7 +121,7 @@ impl O4Stream { where T: AsyncRead + AsyncWrite + Unpin, { - let stream: Box> = match session.get_iat_mode() { + let stream: Box+Send+Unpin> = match session.get_iat_mode() { IAT::Off => Box::new(Framed::new(inner, codec)), IAT::Enabled | IAT::Paranoid => { let f = Framed::new(inner, codec); @@ -296,7 +296,7 @@ impl AsyncWrite for Obfs4Stream { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - this.s.poll_flush(cx) + Sink::poll_flush(this.s, cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -318,19 +318,19 @@ impl AsyncRead for Obfs4Stream { impl Sink for O4Stream { type Error = (); - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { todo!(); } - fn start_send(self: Pin<&mut Self>, item: Messages) -> StdResult<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, _item: Messages) -> StdResult<(), Self::Error> { todo!(); } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { todo!(); } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { todo!(); } } From 2ccf0e17a8d03af79180d8d88d83ebca4accfd93 Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Mon, 27 May 2024 23:09:02 -0600 Subject: [PATCH 4/9] progress? sinks and streams are weird --- crates/obfs4/src/client.rs | 4 +- crates/obfs4/src/common/delay/mod.rs | 8 +-- crates/obfs4/src/proto.rs | 76 ++++++++++++++++++---------- crates/obfs4/src/server.rs | 2 +- crates/obfs4/src/sessions.rs | 4 +- 5 files changed, 59 insertions(+), 35 deletions(-) diff --git a/crates/obfs4/src/client.rs b/crates/obfs4/src/client.rs index 6be9f0a..a13cb11 100644 --- a/crates/obfs4/src/client.rs +++ b/crates/obfs4/src/client.rs @@ -142,7 +142,7 @@ impl Client { /// handshake timeout and then close the connection. pub async fn wrap<'a, T>(self, mut stream: T) -> Result where - T: AsyncRead + AsyncWrite + Unpin + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'a, { let session = sessions::new_client_session(self.station_pubkey, self.iat_mode); @@ -158,7 +158,7 @@ impl Client { mut stream_fut: Pin>, ) -> Result where - T: AsyncRead + AsyncWrite + Unpin + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'a, E: std::error::Error + Send + Sync + 'static, { let stream = stream_fut.await.map_err(|e| Error::Other(Box::new(e)))?; diff --git a/crates/obfs4/src/common/delay/mod.rs b/crates/obfs4/src/common/delay/mod.rs index e4f6367..8774d3e 100644 --- a/crates/obfs4/src/common/delay/mod.rs +++ b/crates/obfs4/src/common/delay/mod.rs @@ -56,9 +56,11 @@ impl> Sink for DelayedSink { let delay = (*s.delay_fn)(); - s.sleep - .as_mut() - .reset(Instant::now() + delay); + if delay.is_zero() { + s.sleep + .as_mut() + .reset(Instant::now() + delay); + } Ok(()) } diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index df9f0be..55bdb94 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -8,8 +8,8 @@ use crate::{ Error, Result, }; -use bytes::{Buf, BytesMut, Bytes}; -use futures::{sink::Sink, stream::Stream}; +use bytes::{Buf, BytesMut}; +use futures::{sink::Sink, stream::{Stream, StreamExt}}; use pin_project::pin_project; use ptrs::trace; use sha2::{Digest, Sha256}; @@ -35,6 +35,8 @@ pub enum IAT { Paranoid, } +pub trait Transport: Sink + Stream + Unpin + Send {} + #[derive(Debug, Clone)] pub(crate) enum MaybeTimeout { Default_, @@ -91,7 +93,7 @@ pub struct Obfs4Stream { } impl Obfs4Stream { - pub(crate) fn from_o4(o4: O4Stream) -> Self { + pub(crate) fn from_o4(o4: O4Stream<>) -> Self { Obfs4Stream { // s: Arc::new(Mutex::new(o4)), s: o4, @@ -100,10 +102,13 @@ impl Obfs4Stream { } #[pin_project] -pub(crate) struct O4Stream { +pub(crate) struct O4Stream{ #[pin] // pub stream: Framed, - pub stream: Box + Send + Unpin>, + // pub stream: Box>, + pub stream: Box + Send + Unpin>, + #[pin] + pub sink: Box + Send + Unpin>, pub length_dist: probdist::WeightedDist, pub iat_dist: probdist::WeightedDist, @@ -116,18 +121,21 @@ impl O4Stream { // inner: &'a mut dyn Stream<'a>, inner: T, codec: framing::Obfs4Codec, - session: Session, + mut session: Session, ) -> O4Stream where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + Send, { - let stream: Box+Send+Unpin> = match session.get_iat_mode() { - IAT::Off => Box::new(Framed::new(inner, codec)), - IAT::Enabled | IAT::Paranoid => { - let f = Framed::new(inner, codec); - Box::new(delay::DelayedSink::new(f, session.iat_duration_sampler())) - } + let delay_fn = match session.get_iat_mode() { + IAT::Off => || Duration::ZERO, + IAT::Enabled | IAT::Paranoid => session.iat_duration_sampler(), }; + let (sink, stream) = Framed::new(inner, codec).split(); + let sink = delay::DelayedSink::new(sink, delay_fn); + + let sink: Box + Send + Unpin> = Box::new(sink); + let stream: Box + Send + Unpin> = Box::new(stream); + let len_seed = session.len_seed(); let mut hasher = Sha256::new(); @@ -150,6 +158,7 @@ impl O4Stream { ); Self { + sink, stream, session, length_dist, @@ -179,8 +188,9 @@ impl AsyncWrite for O4Stream { let mut this = self.as_mut().project(); // determine if the stream is ready to send an event? - if futures::Sink::<&[u8]>::poll_ready(this.stream.as_mut(), cx) == Poll::Pending { - return Poll::Pending; + match futures::Sink::::poll_ready(this.sink.as_mut(), cx) { + Poll::Pending => return Poll::Pending, + _ => {} } // while we have bytes in the buffer write MAX_MESSAGE_PAYLOAD_LENGTH @@ -202,8 +212,9 @@ impl AsyncWrite for O4Stream { out_buf.clear(); // determine if the stream is ready to send more data. if not back off - if futures::Sink::<&[u8]>::poll_ready(this.stream.as_mut(), cx) == Poll::Pending { - return Poll::Ready(Ok(len_sent)); + match futures::Sink::::poll_ready(this.sink.as_mut(), cx) { + Poll::Pending => return Poll::Ready(Ok(len_sent)), + _ => {} } } @@ -211,7 +222,7 @@ impl AsyncWrite for O4Stream { let mut out_buf = BytesMut::new(); payload.marshall(&mut out_buf)?; - this.stream.as_mut().start_send(out_buf)?; + this.sink.as_mut().start_send(out_buf)?; Poll::Ready(Ok(msg_len)) } @@ -219,7 +230,7 @@ impl AsyncWrite for O4Stream { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("{} flushing", self.session.id()); let mut this = self.project(); - match futures::Sink::<&[u8]>::poll_flush(this.stream.as_mut(), cx) { + match futures::Sink::::poll_flush(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Pending => Poll::Pending, @@ -229,7 +240,7 @@ impl AsyncWrite for O4Stream { fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { trace!("{} shutting down", self.session.id()); let mut this = self.project(); - match futures::Sink::<&[u8]>::poll_close(this.stream.as_mut(), cx) { + match futures::Sink::::poll_close(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Pending => Poll::Pending, @@ -260,10 +271,7 @@ impl AsyncRead for O4Stream { return Poll::Ready(Ok(())); } - match res.unwrap() { - Ok(m) => m, - Err(e) => Err(e)?, - } + res.unwrap() } } }; @@ -316,13 +324,13 @@ impl AsyncRead for Obfs4Stream { } } -impl Sink for O4Stream { - type Error = (); +impl Sink for O4Stream { + type Error = IoError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { todo!(); } - fn start_send(self: Pin<&mut Self>, _item: Messages) -> StdResult<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, _item: BytesMut) -> StdResult<(), Self::Error> { todo!(); } @@ -335,6 +343,20 @@ impl Sink for O4Stream { } } +impl Stream for O4Stream { + type Item = Messages; + + // Required method + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_> + ) -> Poll> { + todo!(); + } +} + +impl Transport for O4Stream {} + // TODO Apply pad_burst logic and IAT policy to Message assembly (probably as part of AsyncRead / AsyncWrite impl) /// Attempts to pad a burst of data so that the last [`Message`] is of the length diff --git a/crates/obfs4/src/server.rs b/crates/obfs4/src/server.rs index 070e5eb..fa35639 100644 --- a/crates/obfs4/src/server.rs +++ b/crates/obfs4/src/server.rs @@ -306,7 +306,7 @@ impl Server { pub async fn wrap(self, stream: T) -> Result where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + Send, { let session = self.new_server_session()?; let deadline = self.handshake_timeout.map(|d| Instant::now() + d); diff --git a/crates/obfs4/src/sessions.rs b/crates/obfs4/src/sessions.rs index 4a1cdcc..596a4d7 100644 --- a/crates/obfs4/src/sessions.rs +++ b/crates/obfs4/src/sessions.rs @@ -191,7 +191,7 @@ impl ClientSession { deadline: Option, ) -> Result where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + Send, { // set up for handshake let mut session = self.transition(ClientHandshaking {}); @@ -415,7 +415,7 @@ impl ServerSession { deadline: Option, ) -> Result where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + Send, { // set up for handshake let mut session = self.transition(ServerHandshaking {}); From 61e66a0f9356e80e1f941b96a9924f8b255566c6 Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Mon, 1 Jul 2024 08:06:22 -0600 Subject: [PATCH 5/9] progress commit --- crates/obfs4/src/common/delay/mod.rs | 6 +- crates/obfs4/src/common/drbg.rs | 1 - crates/obfs4/src/common/mod.rs | 2 +- crates/obfs4/src/proto.rs | 99 +++++++++++++++++----------- crates/obfs4/src/sessions.rs | 17 ++--- 5 files changed, 70 insertions(+), 55 deletions(-) diff --git a/crates/obfs4/src/common/delay/mod.rs b/crates/obfs4/src/common/delay/mod.rs index 8774d3e..5667392 100644 --- a/crates/obfs4/src/common/delay/mod.rs +++ b/crates/obfs4/src/common/delay/mod.rs @@ -57,9 +57,7 @@ impl> Sink for DelayedSink { let delay = (*s.delay_fn)(); if delay.is_zero() { - s.sleep - .as_mut() - .reset(Instant::now() + delay); + s.sleep.as_mut().reset(Instant::now() + delay); } Ok(()) } @@ -77,8 +75,8 @@ impl> Sink for DelayedSink { mod testing { use super::*; use futures::sink::{self, SinkExt}; + use rand_distr::{Distribution, Normal}; use std::time::Instant; - use rand_distr::{Normal, Distribution}; #[tokio::test] async fn delay_sink() { diff --git a/crates/obfs4/src/common/drbg.rs b/crates/obfs4/src/common/drbg.rs index d38fe44..92c98fa 100644 --- a/crates/obfs4/src/common/drbg.rs +++ b/crates/obfs4/src/common/drbg.rs @@ -207,7 +207,6 @@ impl RngCore for Drbg { } } - #[cfg(test)] mod test { use super::*; diff --git a/crates/obfs4/src/common/mod.rs b/crates/obfs4/src/common/mod.rs index ea4b3f9..ca19515 100644 --- a/crates/obfs4/src/common/mod.rs +++ b/crates/obfs4/src/common/mod.rs @@ -10,8 +10,8 @@ pub(crate) mod kdf; mod skip; pub use skip::discard; -pub mod drbg; pub(crate) mod delay; +pub mod drbg; pub mod ntor_arti; pub mod probdist; diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index 55bdb94..612db7f 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -1,6 +1,8 @@ use crate::{ common::{ - delay, drbg, probdist::{self, WeightedDist} + delay::{self, DelayedSink}, + drbg, + probdist::{self, WeightedDist}, }, constants::*, framing, @@ -9,7 +11,10 @@ use crate::{ }; use bytes::{Buf, BytesMut}; -use futures::{sink::Sink, stream::{Stream, StreamExt}}; +use futures::{ + sink::Sink, + stream::{Stream, StreamExt}, +}; use pin_project::pin_project; use ptrs::trace; use sha2::{Digest, Sha256}; @@ -35,8 +40,6 @@ pub enum IAT { Paranoid, } -pub trait Transport: Sink + Stream + Unpin + Send {} - #[derive(Debug, Clone)] pub(crate) enum MaybeTimeout { Default_, @@ -45,6 +48,9 @@ pub(crate) enum MaybeTimeout { Unset, } +type MsgStream = Box> + Send + Unpin>; +type BytesSink = Box + Send + Unpin>; + impl std::str::FromStr for IAT { type Err = Error; fn from_str(s: &str) -> StdResult { @@ -86,14 +92,20 @@ impl MaybeTimeout { } #[pin_project] -pub struct Obfs4Stream { +pub struct Obfs4Stream { // s: Arc>>, #[pin] - s: O4Stream, + s: O4Stream, } -impl Obfs4Stream { - pub(crate) fn from_o4(o4: O4Stream<>) -> Self { +impl Obfs4Stream +where + E: From, +{ + pub(crate) fn from_o4(o4: O4Stream) -> Self + where + E: From, + { Obfs4Stream { // s: Arc::new(Mutex::new(o4)), s: o4, @@ -102,13 +114,13 @@ impl Obfs4Stream { } #[pin_project] -pub(crate) struct O4Stream{ +pub(crate) struct O4Stream { #[pin] // pub stream: Framed, // pub stream: Box>, - pub stream: Box + Send + Unpin>, + pub stream: MsgStream, #[pin] - pub sink: Box + Send + Unpin>, + pub sink: BytesSink, pub length_dist: probdist::WeightedDist, pub iat_dist: probdist::WeightedDist, @@ -116,13 +128,16 @@ pub(crate) struct O4Stream{ pub session: Session, } -impl O4Stream { +impl O4Stream +where + E: From, +{ pub(crate) fn new( // inner: &'a mut dyn Stream<'a>, inner: T, codec: framing::Obfs4Codec, mut session: Session, - ) -> O4Stream + ) -> O4Stream where T: AsyncRead + AsyncWrite + Unpin + Send, { @@ -131,10 +146,10 @@ impl O4Stream { IAT::Enabled | IAT::Paranoid => session.iat_duration_sampler(), }; let (sink, stream) = Framed::new(inner, codec).split(); - let sink = delay::DelayedSink::new(sink, delay_fn); + let sink = &delay::DelayedSink::new(sink, delay_fn); - let sink: Box + Send + Unpin> = Box::new(sink); - let stream: Box + Send + Unpin> = Box::new(stream); + let sink: BytesSink = Box::new(sink); + let stream: MsgStream = Box::new(stream); let len_seed = session.len_seed(); @@ -177,8 +192,10 @@ impl O4Stream { } } - -impl AsyncWrite for O4Stream { +impl AsyncWrite for O4Stream +where + E: From, +{ fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -197,8 +214,8 @@ impl AsyncWrite for O4Stream { // chunks until we have less than that amount left. // TODO: asyncwrite - apply length_dist instead of just full payloads let mut len_sent: usize = 0; - let mut out_buf = BytesMut::with_capacity(framing::MAX_MESSAGE_PAYLOAD_LENGTH); while msg_len - len_sent > framing::MAX_MESSAGE_PAYLOAD_LENGTH { + let mut out_buf = BytesMut::with_capacity(framing::MAX_MESSAGE_PAYLOAD_LENGTH); // package one chunk of the mesage as a payload let payload = framing::Messages::Payload( buf[len_sent..len_sent + framing::MAX_MESSAGE_PAYLOAD_LENGTH].to_vec(), @@ -206,10 +223,9 @@ impl AsyncWrite for O4Stream { // send the marshalled payload payload.marshall(&mut out_buf)?; - this.stream.as_mut().start_send(&mut out_buf)?; + this.sink.as_mut().start_send(out_buf)?; len_sent += framing::MAX_MESSAGE_PAYLOAD_LENGTH; - out_buf.clear(); // determine if the stream is ready to send more data. if not back off match futures::Sink::::poll_ready(this.sink.as_mut(), cx) { @@ -222,7 +238,7 @@ impl AsyncWrite for O4Stream { let mut out_buf = BytesMut::new(); payload.marshall(&mut out_buf)?; - this.sink.as_mut().start_send(out_buf)?; + this.sink.as_mut().start_send(out_buf).into()?; Poll::Ready(Ok(msg_len)) } @@ -232,7 +248,7 @@ impl AsyncWrite for O4Stream { let mut this = self.project(); match futures::Sink::::poll_flush(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, } } @@ -242,13 +258,16 @@ impl AsyncWrite for O4Stream { let mut this = self.project(); match futures::Sink::::poll_close(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, } } } -impl AsyncRead for O4Stream { +impl AsyncRead for O4Stream +where + E: From, +{ fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -271,7 +290,10 @@ impl AsyncRead for O4Stream { return Poll::Ready(Ok(())); } - res.unwrap() + match res.unwrap() { + Ok(m) => m, + Err(e) => Err(e)?, + } } } }; @@ -313,7 +335,10 @@ impl AsyncWrite for Obfs4Stream { } } -impl AsyncRead for Obfs4Stream { +impl AsyncRead for Obfs4Stream +where + E: From, +{ fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -324,8 +349,11 @@ impl AsyncRead for Obfs4Stream { } } -impl Sink for O4Stream { - type Error = IoError; +impl Sink for O4Stream +where + E: From, +{ + type Error = E; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { todo!(); } @@ -343,21 +371,18 @@ impl Sink for O4Stream { } } -impl Stream for O4Stream { +impl Stream for O4Stream +where + E: From, +{ type Item = Messages; // Required method - fn poll_next( - self: Pin<&mut Self>, - _cx: &mut Context<'_> - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { todo!(); } } -impl Transport for O4Stream {} - - // TODO Apply pad_burst logic and IAT policy to Message assembly (probably as part of AsyncRead / AsyncWrite impl) /// Attempts to pad a burst of data so that the last [`Message`] is of the length /// `to_pad_to`. This can involve creating multiple packets, making this diff --git a/crates/obfs4/src/sessions.rs b/crates/obfs4/src/sessions.rs index 596a4d7..181ff49 100644 --- a/crates/obfs4/src/sessions.rs +++ b/crates/obfs4/src/sessions.rs @@ -24,7 +24,7 @@ use bytes::BytesMut; use ptrs::{debug, info, trace}; use rand_core::RngCore; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::time::{Instant, Duration}; +use tokio::time::{Duration, Instant}; use tokio_util::codec::Decoder; /// Initial state for a Session, created with any params. @@ -64,20 +64,17 @@ impl Session { } } - pub(crate) fn get_iat_mode(&self)-> IAT { + pub(crate) fn get_iat_mode(&self) -> IAT { match self { Session::Client(cs) => cs.iat_mode, Session::Server(ss) => ss.iat_mode, } } - pub(crate) fn iat_duration_sampler(&mut self) -> fn() ->Duration { - || { - Duration::from_secs(1) - } + pub(crate) fn iat_duration_sampler(&mut self) -> fn() -> Duration { + || Duration::from_secs(1) } - pub(crate) fn sample_iat_length() -> usize { 0usize } @@ -185,11 +182,7 @@ impl ClientSession { /// TODO: make sure failure modes align with golang obfs4 /// - FIN/RST based on buffered data. /// - etc. - pub async fn handshake( - self, - mut stream: T, - deadline: Option, - ) -> Result + pub async fn handshake(self, mut stream: T, deadline: Option) -> Result where T: AsyncRead + AsyncWrite + Unpin + Send, { From be3d507e29c207c98c12aea290c7cc46d5c13419 Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Wed, 3 Jul 2024 13:16:10 -0600 Subject: [PATCH 6/9] progress commit - getting closer --- crates/obfs4/src/common/delay/mod.rs | 15 ++- crates/obfs4/src/framing/codecs.rs | 90 ++++++++++++------ crates/obfs4/src/framing/testing.rs | 22 +++-- crates/obfs4/src/proto.rs | 134 ++++++++++++--------------- crates/obfs4/src/pt.rs | 48 +++++----- crates/obfs4/src/server.rs | 1 + 6 files changed, 178 insertions(+), 132 deletions(-) diff --git a/crates/obfs4/src/common/delay/mod.rs b/crates/obfs4/src/common/delay/mod.rs index 5667392..b66acc1 100644 --- a/crates/obfs4/src/common/delay/mod.rs +++ b/crates/obfs4/src/common/delay/mod.rs @@ -11,7 +11,7 @@ use pin_project::pin_project; type DurationFn = fn() -> Duration; #[pin_project] -pub struct DelayedSink { +pub struct DelayedSink { // #[pin] // sink: Si, // #[pin] @@ -20,9 +20,10 @@ pub struct DelayedSink { sleep: Pin>, delay_fn: DurationFn, _item: PhantomData, + _error: PhantomData, } -impl> DelayedSink { +impl> DelayedSink { pub fn new(sink: Si, delay_fn: DurationFn) -> Self { let delay = delay_fn(); let sleep = tokio::time::sleep(delay); @@ -33,11 +34,15 @@ impl> DelayedSink { sleep: Box::pin(sleep), delay_fn, _item: PhantomData {}, + _error: PhantomData {}, } } } -impl> Sink for DelayedSink { +impl> Sink for DelayedSink +where + J: Into, +{ type Error = Si::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -48,9 +53,9 @@ impl> Sink for DelayedSink { } } - fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, item: J) -> Result<(), Self::Error> { let s = self.project(); - if let Err(e) = s.sink.as_mut().start_send(item) { + if let Err(e) = s.sink.as_mut().start_send(item.into()) { return Err(e); } diff --git a/crates/obfs4/src/framing/codecs.rs b/crates/obfs4/src/framing/codecs.rs index f37e9c5..cafdd19 100644 --- a/crates/obfs4/src/framing/codecs.rs +++ b/crates/obfs4/src/framing/codecs.rs @@ -2,6 +2,7 @@ use crate::{ common::drbg::{self, Drbg, Seed}, constants::MESSAGE_OVERHEAD, framing::{FrameError, Messages}, + Error, }; use bytes::{Buf, BufMut, BytesMut}; @@ -69,6 +70,23 @@ impl EncryptingCodec { pub(crate) fn handshake_complete(&mut self) { self.handshake_complete = true; } + + pub(crate) fn to_parts(self) -> (EncryptingEncoder, EncryptingDecoder) { + (self.encoder, self.decoder) + } + + pub(crate) fn from_parts( + e: EncryptingEncoder, + d: EncryptingDecoder, + hs_complete: bool, + ) -> Self { + Self { + // key, + encoder: e, + decoder: d, + handshake_complete: hs_complete, + } + } } ///Decoder is a frame decoder instance. @@ -106,8 +124,19 @@ impl EncryptingDecoder { impl Decoder for EncryptingCodec { type Item = Messages; - type Error = FrameError; + type Error = Error; + + fn decode( + &mut self, + src: &mut BytesMut, + ) -> std::result::Result, Self::Error> { + self.decoder.decode(src) + } +} +impl Decoder for EncryptingDecoder { + type Item = Messages; + type Error = Error; // Decode decodes a stream of data and returns the length if any. ErrAgain is // a temporary failure, all other errors MUST be treated as fatal and the // session aborted. @@ -118,28 +147,28 @@ impl Decoder for EncryptingCodec { trace!( "decoding src:{}B {} {}", src.remaining(), - self.decoder.next_length, - self.decoder.next_length_invalid + self.next_length, + self.next_length_invalid ); // A length of 0 indicates that we do not know the expected size of // the next frame. we use this to store the length of a packet when we // receive the length at the beginning, but not the whole packet, since // future reads may not have the who packet (including length) available - if self.decoder.next_length == 0 { + if self.next_length == 0 { // Attempt to pull out the next frame length if LENGTH_LENGTH > src.remaining() { return Ok(None); } // derive the nonce that the peer would have used - self.decoder.next_nonce = self.decoder.nonce.next()?; + self.next_nonce = self.nonce.next()?; // Remove the field length from the buffer // let mut len_buf: [u8; LENGTH_LENGTH] = src[..LENGTH_LENGTH].try_into().unwrap(); let mut length = src.get_u16(); // De-obfuscate the length field - let length_mask = self.decoder.drbg.length_mask(); + let length_mask = self.drbg.length_mask(); trace!( "decoding {length:04x}^{length_mask:04x} {:04x}B", length ^ length_mask @@ -158,35 +187,35 @@ impl Decoder for EncryptingCodec { // paper. let invalid_length = length; - self.decoder.next_length_invalid = true; + self.next_length_invalid = true; length = rand::thread_rng().gen::() % (MAX_FRAME_LENGTH - MIN_FRAME_LENGTH) as u16 + MIN_FRAME_LENGTH as u16; error!( "invalid length {invalid_length} {length} {}", - self.decoder.next_length_invalid + self.next_length_invalid ); } - self.decoder.next_length = length; + self.next_length = length; } - let next_len = self.decoder.next_length as usize; + let next_len = self.next_length as usize; if next_len > src.len() { // The full string has not yet arrived. // // We reserve more space in the buffer. This is not strictly // necessary, but is a good idea performance-wise. - if !self.decoder.next_length_invalid { + if !self.next_length_invalid { src.reserve(next_len - src.len()); } trace!( "next_len > src.len --> reading more {} {}", - self.decoder.next_length, - self.decoder.next_length_invalid + self.next_length, + self.next_length_invalid ); // We inform the Framed that we need more bytes to form the next @@ -198,25 +227,25 @@ impl Decoder for EncryptingCodec { let data = src.get(..next_len).unwrap().to_vec(); // Unseal the frame - let key = GenericArray::from_slice(&self.decoder.key); + let key = GenericArray::from_slice(&self.key); let cipher = XSalsa20Poly1305::new(key); - let nonce = GenericArray::from_slice(&self.decoder.next_nonce); // unique per message + let nonce = GenericArray::from_slice(&self.next_nonce); // unique per message let res = cipher.decrypt(nonce, data.as_ref()); if res.is_err() { let e = res.unwrap_err(); trace!("failed to decrypt result: {e}"); - return Err(e.into()); + return Err(Error::Obfs4Framing(FrameError::from(e))); } - let plaintext = res?; + let plaintext = res.map_err(|e| Error::Obfs4Framing(FrameError::from(e)))?; if plaintext.len() < MESSAGE_OVERHEAD { - return Err(FrameError::InvalidMessage); + return Err(Error::Obfs4Framing(FrameError::InvalidMessage)); } // Clean up and prepare for the next frame // // we read a whole frame, we no longer know the size of the next pkt - self.decoder.next_length = 0; + self.next_length = 0; src.advance(next_len); debug!("decoding {next_len}B src:{}B", src.remaining()); @@ -224,7 +253,7 @@ impl Decoder for EncryptingCodec { Ok(Messages::Padding(_)) => Ok(None), Ok(m) => Ok(Some(m)), Err(FrameError::UnknownMessageType(_)) => Ok(None), - Err(e) => Err(e), + Err(e) => Err(Error::Obfs4Framing(e)), } } } @@ -255,8 +284,15 @@ impl EncryptingEncoder { } impl Encoder for EncryptingCodec { - type Error = FrameError; + type Error = Error; + + fn encode(&mut self, plaintext: T, dst: &mut BytesMut) -> std::result::Result<(), Self::Error> { + self.encoder.encode(plaintext, dst) + } +} +impl Encoder for EncryptingEncoder { + type Error = Error; /// Encode encodes a single frame worth of payload and returns. Plaintext /// should either be a handshake message OR a buffer containing one or more /// [`Message`]s already properly marshalled. The proided plaintext can @@ -272,7 +308,7 @@ impl Encoder for EncryptingCodec { // Don't send a frame if it is longer than the other end will accept. if plaintext.remaining() > MAX_FRAME_PAYLOAD_LENGTH { - return Err(FrameError::InvalidPayloadLength(plaintext.remaining())); + return Err(FrameError::InvalidPayloadLength(plaintext.remaining()).into()); } let mut plaintext_frame = BytesMut::new(); @@ -280,18 +316,20 @@ impl Encoder for EncryptingCodec { plaintext_frame.put(plaintext); // Generate a new nonce - let nonce_bytes = self.encoder.nonce.next()?; + let nonce_bytes = self.nonce.next()?; // Encrypt and MAC payload - let key = GenericArray::from_slice(&self.encoder.key); + let key = GenericArray::from_slice(&self.key); let cipher = XSalsa20Poly1305::new(key); let nonce = GenericArray::from_slice(&nonce_bytes); // unique per message - let ciphertext = cipher.encrypt(nonce, plaintext_frame.as_ref())?; + let ciphertext = cipher + .encrypt(nonce, plaintext_frame.as_ref()) + .map_err(|e| Error::Obfs4Framing(FrameError::Crypto(e)))?; // Obfuscate the length let mut length = ciphertext.len() as u16; - let length_mask: u16 = self.encoder.drbg.length_mask(); + let length_mask: u16 = self.drbg.length_mask(); debug!( "encoding➡️ {length}B, {length:04x}^{length_mask:04x} {:04x}", length ^ length_mask diff --git a/crates/obfs4/src/framing/testing.rs b/crates/obfs4/src/framing/testing.rs index 516fc9c..9f163e4 100644 --- a/crates/obfs4/src/framing/testing.rs +++ b/crates/obfs4/src/framing/testing.rs @@ -10,7 +10,7 @@ /// use super::*; use crate::test_utils::init_subscriber; -use crate::Result; +use crate::{Result, Error}; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; @@ -67,11 +67,21 @@ async fn oversized_flow() -> Result<()> { let mut src = Bytes::from(oversized_messsage); let res = codec.encode(&mut src, &mut b); - assert_eq!( - res.unwrap_err(), - FrameError::InvalidPayloadLength(frame_len) - ); - Ok(()) + let e = res.unwrap_err(); + assert!(matches!( + e, + Error::Obfs4Framing(FrameError::InvalidPayloadLength(_)) + )); + match e { + Error::Obfs4Framing(FrameError::InvalidPayloadLength(f)) => { + if f == frame_len { + Ok(()) + } else { + panic!("expected frame_length {}, got {}", frame_len, f); + } + } + _ => panic!("expected InvalidPayloadLength, got {}", e), + } } #[tokio::test] diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index 612db7f..df986db 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -1,11 +1,10 @@ use crate::{ common::{ - delay::{self, DelayedSink}, - drbg, + delay, drbg, probdist::{self, WeightedDist}, }, constants::*, - framing, + framing::{self, FrameError, Messages, Obfs4Codec}, sessions::Session, Error, Result, }; @@ -13,14 +12,14 @@ use crate::{ use bytes::{Buf, BytesMut}; use futures::{ sink::Sink, - stream::{Stream, StreamExt}, + stream::Stream, // StreamExt}, }; use pin_project::pin_project; use ptrs::trace; use sha2::{Digest, Sha256}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::time::{Duration, Instant}; -use tokio_util::codec::Framed; +use tokio_util::codec::{FramedRead, FramedWrite}; use std::{ io::Error as IoError, @@ -29,8 +28,6 @@ use std::{ task::{Context, Poll}, }; -use super::framing::{FrameError, Messages}; - #[allow(dead_code, unused)] #[derive(Default, Debug, Clone, Copy, PartialEq)] pub enum IAT { @@ -48,8 +45,8 @@ pub(crate) enum MaybeTimeout { Unset, } -type MsgStream = Box> + Send + Unpin>; -type BytesSink = Box + Send + Unpin>; +type MsgStream = Box> + Send + Unpin>; +type BytesSink = Box + Send + Unpin>; impl std::str::FromStr for IAT { type Err = Error; @@ -91,21 +88,20 @@ impl MaybeTimeout { } } +pub trait SinkStream { + type Item; + type Error; +} + #[pin_project] -pub struct Obfs4Stream { +pub struct Obfs4Stream { // s: Arc>>, #[pin] - s: O4Stream, + s: O4Stream, } -impl Obfs4Stream -where - E: From, -{ - pub(crate) fn from_o4(o4: O4Stream) -> Self - where - E: From, - { +impl Obfs4Stream { + pub(crate) fn from_o4(o4: O4Stream) -> Self { Obfs4Stream { // s: Arc::new(Mutex::new(o4)), s: o4, @@ -114,13 +110,13 @@ where } #[pin_project] -pub(crate) struct O4Stream { +pub(crate) struct O4Stream { #[pin] // pub stream: Framed, // pub stream: Box>, pub stream: MsgStream, #[pin] - pub sink: BytesSink, + pub sink: BytesSink, pub length_dist: probdist::WeightedDist, pub iat_dist: probdist::WeightedDist, @@ -128,28 +124,32 @@ pub(crate) struct O4Stream { pub session: Session, } -impl O4Stream -where - E: From, -{ - pub(crate) fn new( +impl O4Stream { + pub(crate) fn new<'a, T>( // inner: &'a mut dyn Stream<'a>, inner: T, - codec: framing::Obfs4Codec, + mut codec: Obfs4Codec, mut session: Session, - ) -> O4Stream + ) -> Self where - T: AsyncRead + AsyncWrite + Unpin + Send, + T: AsyncRead + AsyncWrite + Unpin + Send + 'a, { let delay_fn = match session.get_iat_mode() { IAT::Off => || Duration::ZERO, IAT::Enabled | IAT::Paranoid => session.iat_duration_sampler(), }; - let (sink, stream) = Framed::new(inner, codec).split(); - let sink = &delay::DelayedSink::new(sink, delay_fn); - let sink: BytesSink = Box::new(sink); - let stream: MsgStream = Box::new(stream); + let (r, w) = tokio::io::split(inner); + let (e, d) = codec.to_parts(); + let encoding_sink = FramedWrite::new(w, e); + let sink = Box::new(&mut delay::DelayedSink::< + FramedWrite, EncryptingEncoder>, + Item, + Error, + >::new(encoding_sink, delay_fn)); + + let decoded_stream = FramedRead::new(r, d); + let stream: MsgStream = Box::new(decoded_stream); let len_seed = session.len_seed(); @@ -180,22 +180,9 @@ where iat_dist, } } - - pub(crate) fn try_handle_non_payload_message(&mut self, msg: framing::Messages) -> Result<()> { - match msg { - Messages::Payload(_) => Err(FrameError::InvalidMessage.into()), - Messages::Padding(_) => Ok(()), - - // TODO: Handle other Messages - _ => Ok(()), - } - } } -impl AsyncWrite for O4Stream -where - E: From, -{ +impl AsyncWrite for O4Stream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -217,7 +204,7 @@ where while msg_len - len_sent > framing::MAX_MESSAGE_PAYLOAD_LENGTH { let mut out_buf = BytesMut::with_capacity(framing::MAX_MESSAGE_PAYLOAD_LENGTH); // package one chunk of the mesage as a payload - let payload = framing::Messages::Payload( + let payload = Messages::Payload( buf[len_sent..len_sent + framing::MAX_MESSAGE_PAYLOAD_LENGTH].to_vec(), ); @@ -238,7 +225,7 @@ where let mut out_buf = BytesMut::new(); payload.marshall(&mut out_buf)?; - this.sink.as_mut().start_send(out_buf).into()?; + this.sink.as_mut().start_send(out_buf)?; Poll::Ready(Ok(msg_len)) } @@ -248,7 +235,7 @@ where let mut this = self.project(); match futures::Sink::::poll_flush(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Pending => Poll::Pending, } } @@ -258,16 +245,13 @@ where let mut this = self.project(); match futures::Sink::::poll_close(this.sink.as_mut(), cx) { Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), Poll::Pending => Poll::Pending, } } } -impl AsyncRead for O4Stream -where - E: From, -{ +impl AsyncRead for O4Stream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -277,9 +261,9 @@ where // the network. Not all data received is guaranteed to be usable payload, // so do this in a loop until we would block on a read or an error occurs. loop { + let mut this = self.as_mut().project(); let msg = { // mutable borrow of self is dropped at the end of this block - let mut this = self.as_mut().project(); match this.stream.as_mut().poll_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(res) => { @@ -298,7 +282,7 @@ where } }; - if let framing::Messages::Payload(message) = msg { + if let Messages::Payload(message) = msg { buf.put_slice(&message); return Poll::Ready(Ok(())); } @@ -306,9 +290,12 @@ where continue; } - match self.as_mut().try_handle_non_payload_message(msg) { - Ok(_) => continue, - Err(e) => return Poll::Ready(Err(e.into())), + match msg { + Messages::Payload(_) => return Poll::Ready(Err(FrameError::InvalidMessage.into())), + Messages::Padding(_) => return Poll::Ready(Ok(())), + + // TODO: Handle other Messages + _ => return Poll::Ready(Ok(())), } } } @@ -326,7 +313,11 @@ impl AsyncWrite for Obfs4Stream { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - Sink::poll_flush(this.s, cx) + match Sink::poll_flush(this.s, cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -335,10 +326,7 @@ impl AsyncWrite for Obfs4Stream { } } -impl AsyncRead for Obfs4Stream -where - E: From, -{ +impl AsyncRead for Obfs4Stream { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -349,11 +337,8 @@ where } } -impl Sink for O4Stream -where - E: From, -{ - type Error = E; +impl Sink for O4Stream { + type Error = Error; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { todo!(); } @@ -371,10 +356,7 @@ where } } -impl Stream for O4Stream -where - E: From, -{ +impl Stream for O4Stream { type Item = Messages; // Required method @@ -382,6 +364,8 @@ where todo!(); } } +// +// ======================================================================== // TODO Apply pad_burst logic and IAT policy to Message assembly (probably as part of AsyncRead / AsyncWrite impl) /// Attempts to pad a burst of data so that the last [`Message`] is of the length @@ -425,6 +409,8 @@ pub(crate) fn pad_burst(buf: &mut BytesMut, to_pad_to: usize) -> Result<()> { } } +// ======================================================================== + /* /// /// Off: diff --git a/crates/obfs4/src/pt.rs b/crates/obfs4/src/pt.rs index b1be3c3..cfc6b65 100644 --- a/crates/obfs4/src/pt.rs +++ b/crates/obfs4/src/pt.rs @@ -3,6 +3,7 @@ use crate::{ handshake::Obfs4NtorPublicKey, proto::{Obfs4Stream, IAT}, Error, OBFS4_NAME, + Client, ClientBuilder, Server, ServerBuilder, }; use ptrs::{args::Args, FutureResult as F}; @@ -14,6 +15,7 @@ use std::{ time::Duration, }; +use futures::TryFutureExt; use hex::FromHex; use ptrs::trace; use tokio::{ @@ -35,32 +37,32 @@ impl ptrs::PluggableTransport for Transport where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type ClientBuilder = crate::ClientBuilder; - type ServerBuilder = crate::ServerBuilder; + type ClientBuilder = ClientBuilder; + type ServerBuilder = ServerBuilder; fn name() -> String { OBFS4_NAME.into() } fn client_builder() -> >::ClientBuilder { - crate::ClientBuilder::default() + ClientBuilder::default() } fn server_builder() -> >::ServerBuilder { - crate::ServerBuilder::default() + ServerBuilder::default() } } -impl ptrs::ServerBuilder for crate::ServerBuilder +impl ptrs::ServerBuilder for ServerBuilder where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type ServerPT = crate::Server; + type ServerPT = Server; type Error = Error; type Transport = Transport; fn build(&self) -> Self::ServerPT { - crate::ServerBuilder::build(self) + ServerBuilder::build(self) } fn method_name() -> String { @@ -105,11 +107,11 @@ where } } -impl ptrs::ClientBuilder for crate::ClientBuilder +impl ptrs::ClientBuilder for ClientBuilder where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - type ClientPT = crate::Client; + type ClientPT = Client; type Error = Error; type Transport = Transport; @@ -122,7 +124,7 @@ where /// **Errors** /// If a required field has not been initialized. fn build(&self) -> Self::ClientPT { - crate::ClientBuilder::build(self) + ClientBuilder::build(self) } /// Pluggable transport attempts to parse and validate options from a string, @@ -208,21 +210,21 @@ where /// Example wrapping transport that just passes the incoming connection future through /// unmodified as a proof of concept. -impl ptrs::ClientTransport for crate::Client +impl ptrs::ClientTransport for Client where InRW: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, InErr: std::error::Error + Send + Sync + 'static, { type OutRW = Obfs4Stream; type OutErr = Error; - type Builder = crate::ClientBuilder; + type Builder = ClientBuilder; fn establish(self, input: Pin>) -> Pin> { - Box::pin(crate::Client::establish(self, input)) + Box::pin(Client::establish(self, input)) } fn wrap(self, io: InRW) -> Pin> { - Box::pin(crate::Client::wrap(self, io)) + Box::pin(Client::wrap(self, io).map_err(|e| e.into())) } fn method_name() -> String { @@ -230,17 +232,21 @@ where } } -impl ptrs::ServerTransport for crate::Server +impl ptrs::ServerTransport for Server where InRW: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { + // Our out read/write is an obfs4strea, that has an error of IoError. + // If an error occurs while revealing it will be returned as an [`Error`]. + // As input we require a Serverbuilder with an InRW type that implements + // the async read and write traits. type OutRW = Obfs4Stream; type OutErr = Error; - type Builder = crate::ServerBuilder; + type Builder = ServerBuilder; /// Use something that can be accessed reference (Arc, Rc, etc.) fn reveal(self, io: InRW) -> Pin> { - Box::pin(crate::Server::wrap(self, io)) + Box::pin(Server::wrap(self, io)) } fn method_name() -> String { @@ -257,18 +263,18 @@ mod test { let pt_name = >::name(); assert_eq!(pt_name, Obfs4PT::NAME); - let cb_name = >::method_name(); + let cb_name = >::method_name(); assert_eq!(cb_name, Obfs4PT::NAME); let sb_name = - as ptrs::ServerBuilder>::method_name(); + as ptrs::ServerBuilder>::method_name(); assert_eq!(sb_name, Obfs4PT::NAME); let ct_name = - >::method_name(); + >::method_name(); assert_eq!(ct_name, Obfs4PT::NAME); - let st_name = >::method_name(); + let st_name = >::method_name(); assert_eq!(st_name, Obfs4PT::NAME); } } diff --git a/crates/obfs4/src/server.rs b/crates/obfs4/src/server.rs index 76853a1..284cec5 100644 --- a/crates/obfs4/src/server.rs +++ b/crates/obfs4/src/server.rs @@ -19,6 +19,7 @@ use crate::{ use ptrs::args::Args; use std::{borrow::BorrowMut, marker::PhantomData, ops::Deref, str::FromStr, sync::Arc}; +use std::io::Error as IoError; use bytes::{Buf, BufMut, Bytes}; use hex::FromHex; From d71de9fed8d251d35eee516689cf6dec3cfb8204 Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Wed, 3 Jul 2024 22:25:34 -0600 Subject: [PATCH 7/9] progress commit - again, getting closer - dreaded static lifetimes --- crates/obfs4/src/framing/codecs.rs | 2 +- crates/obfs4/src/framing/mod.rs | 2 +- crates/obfs4/src/proto.rs | 12 +- crates/obfs4/src/server.rs | 208 ++++++++++++++++++++++++-- crates/obfs4/src/sessions.rs | 232 ++++------------------------- 5 files changed, 236 insertions(+), 220 deletions(-) diff --git a/crates/obfs4/src/framing/codecs.rs b/crates/obfs4/src/framing/codecs.rs index cafdd19..261fca0 100644 --- a/crates/obfs4/src/framing/codecs.rs +++ b/crates/obfs4/src/framing/codecs.rs @@ -259,7 +259,7 @@ impl Decoder for EncryptingDecoder { } /// Encoder is a frame encoder instance. -struct EncryptingEncoder { +pub(crate) struct EncryptingEncoder { key: [u8; KEY_LENGTH], nonce: NonceBox, drbg: Drbg, diff --git a/crates/obfs4/src/framing/mod.rs b/crates/obfs4/src/framing/mod.rs index 1157683..b744050 100644 --- a/crates/obfs4/src/framing/mod.rs +++ b/crates/obfs4/src/framing/mod.rs @@ -49,7 +49,7 @@ pub use messages_base::*; mod messages_v1; pub use messages_v1::{MessageTypes, Messages}; -mod codecs; +pub(crate) mod codecs; pub use codecs::EncryptingCodec as Obfs4Codec; pub(crate) mod handshake; diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index df986db..9af475f 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -4,7 +4,7 @@ use crate::{ probdist::{self, WeightedDist}, }, constants::*, - framing::{self, FrameError, Messages, Obfs4Codec}, + framing::{self, codecs::EncryptingEncoder, FrameError, Messages, Obfs4Codec}, sessions::Session, Error, Result, }; @@ -128,7 +128,7 @@ impl O4Stream { pub(crate) fn new<'a, T>( // inner: &'a mut dyn Stream<'a>, inner: T, - mut codec: Obfs4Codec, + codec: Obfs4Codec, mut session: Session, ) -> Self where @@ -142,14 +142,14 @@ impl O4Stream { let (r, w) = tokio::io::split(inner); let (e, d) = codec.to_parts(); let encoding_sink = FramedWrite::new(w, e); - let sink = Box::new(&mut delay::DelayedSink::< + let sink = Box::new(delay::DelayedSink::< FramedWrite, EncryptingEncoder>, - Item, + BytesMut, Error, >::new(encoding_sink, delay_fn)); - let decoded_stream = FramedRead::new(r, d); - let stream: MsgStream = Box::new(decoded_stream); + let decoding_stream = FramedRead::new(r, d); + let stream: MsgStream = Box::new(decoding_stream); let len_seed = session.len_seed(); diff --git a/crates/obfs4/src/server.rs b/crates/obfs4/src/server.rs index 284cec5..8f9a171 100644 --- a/crates/obfs4/src/server.rs +++ b/crates/obfs4/src/server.rs @@ -4,27 +4,28 @@ use super::*; use crate::{ client::ClientBuilder, common::{ - colorize, drbg, + colorize, drbg, discard, replay_filter::{self, ReplayFilter}, x25519_elligator2::{PublicKey, StaticSecret}, HmacSha256, + ntor_arti::{ServerHandshake, RelayHandshakeError}, }, constants::*, framing::{FrameError, Marshall, Obfs4Codec, TryParse, KEY_LENGTH}, - handshake::{Obfs4NtorPublicKey, Obfs4NtorSecretKey}, - proto::{MaybeTimeout, Obfs4Stream, IAT}, - sessions::Session, + handshake::{Obfs4NtorPublicKey, Obfs4NtorSecretKey, Obfs4Keygen, SHSMaterials}, + proto::{MaybeTimeout, O4Stream, Obfs4Stream, IAT}, + sessions::{Session, Initialized, Established, Fault}, Error, Result, }; use ptrs::args::Args; +use ptrs::{debug, info, trace}; use std::{borrow::BorrowMut, marker::PhantomData, ops::Deref, str::FromStr, sync::Arc}; -use std::io::Error as IoError; +use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use bytes::{Buf, BufMut, Bytes}; use hex::FromHex; use hmac::{Hmac, Mac}; -use ptrs::{debug, info}; use rand::prelude::*; use subtle::ConstantTimeEq; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -304,14 +305,63 @@ impl Server { Self::new_from_key(identity_keys) } - pub async fn wrap(self, stream: T) -> Result + // ====================================================================== // + // Server Handshake // + // ====================================================================== // + + + pub async fn wrap<'a, T>(self, mut stream: T) -> Result where - T: AsyncRead + AsyncWrite + Unpin + Send, + T: AsyncRead + AsyncWrite + Unpin + Send + 'a, { let session = self.new_server_session()?; let deadline = self.handshake_timeout.map(|d| Instant::now() + d); - session.handshake(&self, stream, deadline).await + let hs_materials = SHSMaterials::new( + &session.identity_keys, + session.session_id(), + session.len_seed.to_bytes(), + ); + + let mut session = session.transition(ServerHandshaking {}); + + let d_def = Instant::now() + SERVER_HANDSHAKE_TIMEOUT; + let handshake_fut = self.complete_handshake(&mut stream, hs_materials, deadline); + + let mut keygen = + match tokio::time::timeout_at(deadline.unwrap_or(d_def), handshake_fut).await { + Ok(result) => match result { + Ok(handshake) => handshake, + Err(e) => { + // non-timeout error, + let id = session.session_id(); + let _ = session.fault(ServerHandshakeFailed { + details: format!("{id} handshake failed {e}"), + }); + return Err(e); + } + }, + Err(_) => { + let id = session.session_id(); + let _ = session.fault(ServerHandshakeFailed { + details: format!("{id} timed out"), + }); + return Err(Error::HandshakeTimeout); + } + }; + + // post handshake state updates + session.set_session_id(keygen.session_id()); + let mut codec: framing::Obfs4Codec = keygen.into(); + + // mark session as Established + let session_state: ServerSession = session.transition(Established {}); + + codec.handshake_complete(); + let o4 = O4Stream::new(stream, codec, Session::Server(session_state)); + + Ok(Obfs4Stream::from_o4(o4)) + // session.handshake(&self, stream, deadline).await } // pub fn set_iat_mode(&mut self, mode: IAT) -> &Self { @@ -343,10 +393,10 @@ impl Server { pub(crate) fn new_server_session( &self, - ) -> Result> { + ) -> Result> { let mut session_id = [0u8; SESSION_ID_LEN]; rand::thread_rng().fill_bytes(&mut session_id); - Ok(sessions::ServerSession { + Ok(ServerSession { // fixed by server identity_keys: self.identity_keys.clone(), iat_mode: self.iat_mode, @@ -357,11 +407,145 @@ impl Server { len_seed: drbg::Seed::new().unwrap(), iat_seed: drbg::Seed::new().unwrap(), - _state: sessions::Initialized {}, + _state: Initialized {}, }) } + + /// Complete the handshake with the client. This function assumes that the + /// client has already sent a message and that we do not know yet if the + /// message is valid. + async fn complete_handshake( + &self, + mut stream: T, + materials: SHSMaterials, + deadline: Option, + ) -> Result + where + T: AsyncRead + AsyncWrite + Unpin, + { + let session_id = materials.session_id.clone(); + + // wait for and attempt to consume the client hello message + let mut buf = [0_u8; MAX_HANDSHAKE_LENGTH]; + loop { + let n = stream.read(&mut buf).await?; + if n == 0 { + stream.shutdown().await?; + return Err(IoError::from(IoErrorKind::UnexpectedEof).into()); + } + trace!("{} successful read {n}B", session_id); + + match self.server(&mut |_: &()| Some(()), &[materials.clone()], &buf[..n]) { + Ok((keygen, response)) => { + stream.write_all(&response).await?; + info!("{} handshake complete", session_id); + return Ok(keygen); + } + Err(RelayHandshakeError::EAgain) => { + trace!("{} reading more", session_id); + continue; + } + Err(e) => { + trace!("{} failed to parse client handshake: {e}", session_id); + // if a deadline was set and has not passed already, discard + // from the stream until the deadline, then close. + if deadline.is_some_and(|d| d > Instant::now()) { + debug!("{} discarding due to: {e}", session_id); + discard(&mut stream, deadline.unwrap() - Instant::now()).await? + } + stream.shutdown().await?; + return Err(e.into()); + } + }; + } + } + +} + +// ================================================================ // +// Server Sessions States // +// ================================================================ // + +pub(crate) struct ServerSession { + // fixed by server + pub(crate) identity_keys: Obfs4NtorSecretKey, + pub(crate) iat_mode: IAT, + pub(crate) biased: bool, + // pub(crate) server: &'a Server, + + // generated per session + pub(crate) session_id: [u8; SESSION_ID_LEN], + pub(crate) len_seed: drbg::Seed, + pub(crate) iat_seed: drbg::Seed, + + pub(crate) _state: S, +} + +pub(crate) struct ServerHandshaking {} + +#[allow(unused)] +pub(crate) struct ServerHandshakeFailed { + details: String, +} + +pub(crate) trait ServerSessionState {} +impl ServerSessionState for Initialized {} +impl ServerSessionState for ServerHandshaking {} +impl ServerSessionState for Established {} + +impl ServerSessionState for ServerHandshakeFailed {} +impl Fault for ServerHandshakeFailed {} + +impl ServerSession { + pub fn session_id(&self) -> String { + String::from("s-") + &colorize(self.session_id) + } + + pub(crate) fn set_session_id(&mut self, id: [u8; SESSION_ID_LEN]) { + debug!( + "{} -> {} server updating session id", + colorize(self.session_id), + colorize(id) + ); + self.session_id = id; + } + + /// Helper function to perform state transitions. + pub(crate) fn transition(self, _state: T) -> ServerSession { + ServerSession { + // fixed by server + identity_keys: self.identity_keys, + iat_mode: self.iat_mode, + biased: self.biased, + + // generated per session + session_id: self.session_id, + len_seed: self.len_seed, + iat_seed: self.iat_seed, + + _state, + } + } + + /// Helper function to perform state transition on error. + pub(crate) fn fault(self, f: F) -> ServerSession { + ServerSession { + // fixed by server + identity_keys: self.identity_keys, + iat_mode: self.iat_mode, + biased: self.biased, + + // generated per session + session_id: self.session_id, + len_seed: self.len_seed, + iat_seed: self.iat_seed, + + _state: f, + } + } } + #[cfg(test)] mod tests { use crate::dev; diff --git a/crates/obfs4/src/sessions.rs b/crates/obfs4/src/sessions.rs index 181ff49..ad68a2b 100644 --- a/crates/obfs4/src/sessions.rs +++ b/crates/obfs4/src/sessions.rs @@ -5,23 +5,22 @@ use crate::{ common::{ colorize, discard, drbg, - ntor_arti::{ClientHandshake, RelayHandshakeError, ServerHandshake}, + ntor_arti::{ClientHandshake, RelayHandshakeError}, }, constants::*, framing, handshake::{ - CHSMaterials, Obfs4Keygen, Obfs4NtorHandshake, Obfs4NtorPublicKey, Obfs4NtorSecretKey, - SHSMaterials, + CHSMaterials, Obfs4Keygen, Obfs4NtorHandshake, Obfs4NtorPublicKey, }, proto::{O4Stream, Obfs4Stream, IAT}, - server::Server, + server::ServerSession, Error, Result, }; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use bytes::BytesMut; -use ptrs::{debug, info, trace}; +use ptrs::{debug, info}; use rand_core::RngCore; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::time::{Duration, Instant}; @@ -34,7 +33,7 @@ pub(crate) struct Initialized; pub(crate) struct Established; /// The session broke due to something like a timeout, reset, lost connection, etc. -trait Fault {} +pub(crate) trait Fault {} pub(crate) enum Session { Client(ClientSession), @@ -182,9 +181,9 @@ impl ClientSession { /// TODO: make sure failure modes align with golang obfs4 /// - FIN/RST based on buffered data. /// - etc. - pub async fn handshake(self, mut stream: T, deadline: Option) -> Result + pub async fn handshake<'a, T>(self, mut stream: T, deadline: Option) -> Result where - T: AsyncRead + AsyncWrite + Unpin + Send, + T: AsyncRead + AsyncWrite + Unpin + Send + 'a, { // set up for handshake let mut session = self.transition(ClientHandshaking {}); @@ -316,196 +315,29 @@ impl std::fmt::Debug for ClientSession { } } -// ================================================================ // -// Server Sessions States // -// ================================================================ // - -pub(crate) struct ServerSession { - // fixed by server - pub(crate) identity_keys: Obfs4NtorSecretKey, - pub(crate) iat_mode: IAT, - pub(crate) biased: bool, - // pub(crate) server: &'a Server, - - // generated per session - pub(crate) session_id: [u8; SESSION_ID_LEN], - pub(crate) len_seed: drbg::Seed, - pub(crate) iat_seed: drbg::Seed, - - pub(crate) _state: S, -} - -pub(crate) struct ServerHandshaking {} - -#[allow(unused)] -pub(crate) struct ServerHandshakeFailed { - details: String, -} - -pub(crate) trait ServerSessionState {} -impl ServerSessionState for Initialized {} -impl ServerSessionState for ServerHandshaking {} -impl ServerSessionState for Established {} - -impl ServerSessionState for ServerHandshakeFailed {} -impl Fault for ServerHandshakeFailed {} - -impl ServerSession { - pub fn session_id(&self) -> String { - String::from("s-") + &colorize(self.session_id) - } - - pub(crate) fn set_session_id(&mut self, id: [u8; SESSION_ID_LEN]) { - debug!( - "{} -> {} server updating session id", - colorize(self.session_id), - colorize(id) - ); - self.session_id = id; - } - - /// Helper function to perform state transitions. - fn transition(self, _state: T) -> ServerSession { - ServerSession { - // fixed by server - identity_keys: self.identity_keys, - iat_mode: self.iat_mode, - biased: self.biased, - - // generated per session - session_id: self.session_id, - len_seed: self.len_seed, - iat_seed: self.iat_seed, - - _state, - } - } - - /// Helper function to perform state transition on error. - fn fault(self, f: F) -> ServerSession { - ServerSession { - // fixed by server - identity_keys: self.identity_keys, - iat_mode: self.iat_mode, - biased: self.biased, - - // generated per session - session_id: self.session_id, - len_seed: self.len_seed, - iat_seed: self.iat_seed, - - _state: f, - } - } -} - -impl ServerSession { - /// Attempt to complete the handshake with a new client connection. - pub async fn handshake( - self, - server: &Server, - mut stream: T, - deadline: Option, - ) -> Result - where - T: AsyncRead + AsyncWrite + Unpin + Send, - { - // set up for handshake - let mut session = self.transition(ServerHandshaking {}); - - let materials = SHSMaterials::new( - &session.identity_keys, - session.session_id(), - session.len_seed.to_bytes(), - ); - - // default deadline - let d_def = Instant::now() + SERVER_HANDSHAKE_TIMEOUT; - let handshake_fut = server.complete_handshake(&mut stream, materials, deadline); +// impl ServerSession { +// /// Attempt to complete the handshake with a new client connection. +// pub async fn handshake( +// self, +// server: &Server, +// mut stream: T, +// deadline: Option, +// ) -> Result +// where +// T: AsyncRead + AsyncWrite + Unpin + Send + 'static, +// { +// // set up for handshake +// +// let materials = SHSMaterials::new( +// &session.identity_keys, +// session.session_id(), +// session.len_seed.to_bytes(), +// ); +// +// // default deadline +// let d_def = Instant::now() + SERVER_HANDSHAKE_TIMEOUT; +// let handshake_fut = server.complete_handshake(&mut stream, materials, deadline); +// +// } +// } - let mut keygen = - match tokio::time::timeout_at(deadline.unwrap_or(d_def), handshake_fut).await { - Ok(result) => match result { - Ok(handshake) => handshake, - Err(e) => { - // non-timeout error, - let id = session.session_id(); - let _ = session.fault(ServerHandshakeFailed { - details: format!("{id} handshake failed {e}"), - }); - return Err(e); - } - }, - Err(_) => { - let id = session.session_id(); - let _ = session.fault(ServerHandshakeFailed { - details: format!("{id} timed out"), - }); - return Err(Error::HandshakeTimeout); - } - }; - - // post handshake state updates - session.set_session_id(keygen.session_id()); - let mut codec: framing::Obfs4Codec = keygen.into(); - - // mark session as Established - let session_state: ServerSession = session.transition(Established {}); - - codec.handshake_complete(); - let o4 = O4Stream::new(stream, codec, Session::Server(session_state)); - - Ok(Obfs4Stream::from_o4(o4)) - } -} - -impl Server { - /// Complete the handshake with the client. This function assumes that the - /// client has already sent a message and that we do not know yet if the - /// message is valid. - async fn complete_handshake( - &self, - mut stream: T, - materials: SHSMaterials, - deadline: Option, - ) -> Result - where - T: AsyncRead + AsyncWrite + Unpin, - { - let session_id = materials.session_id.clone(); - - // wait for and attempt to consume the client hello message - let mut buf = [0_u8; MAX_HANDSHAKE_LENGTH]; - loop { - let n = stream.read(&mut buf).await?; - if n == 0 { - stream.shutdown().await?; - return Err(IoError::from(IoErrorKind::UnexpectedEof).into()); - } - trace!("{} successful read {n}B", session_id); - - match self.server(&mut |_: &()| Some(()), &[materials.clone()], &buf[..n]) { - Ok((keygen, response)) => { - stream.write_all(&response).await?; - info!("{} handshake complete", session_id); - return Ok(keygen); - } - Err(RelayHandshakeError::EAgain) => { - trace!("{} reading more", session_id); - continue; - } - Err(e) => { - trace!("{} failed to parse client handshake: {e}", session_id); - // if a deadline was set and has not passed already, discard - // from the stream until the deadline, then close. - if deadline.is_some_and(|d| d > Instant::now()) { - debug!("{} discarding due to: {e}", session_id); - discard(&mut stream, deadline.unwrap() - Instant::now()).await? - } - stream.shutdown().await?; - return Err(e.into()); - } - }; - } - } -} From d0b946fdeaa0a3429e2f8b1dcfcc54022378307d Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Sat, 6 Jul 2024 08:05:51 -0600 Subject: [PATCH 8/9] compiling with newly added static lifetimes :/ - missing Sink/Stream Impls --- crates/obfs4/src/client.rs | 4 ++-- crates/obfs4/src/framing/codecs.rs | 3 ++- crates/obfs4/src/proto.rs | 13 ++++-------- crates/obfs4/src/server.rs | 2 +- crates/obfs4/src/sessions.rs | 8 ++++---- crates/obfs4/src/testing.rs | 32 +++++++++++++++--------------- 6 files changed, 29 insertions(+), 33 deletions(-) diff --git a/crates/obfs4/src/client.rs b/crates/obfs4/src/client.rs index a13cb11..75203d1 100644 --- a/crates/obfs4/src/client.rs +++ b/crates/obfs4/src/client.rs @@ -142,7 +142,7 @@ impl Client { /// handshake timeout and then close the connection. pub async fn wrap<'a, T>(self, mut stream: T) -> Result where - T: AsyncRead + AsyncWrite + Unpin + Send + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let session = sessions::new_client_session(self.station_pubkey, self.iat_mode); @@ -158,7 +158,7 @@ impl Client { mut stream_fut: Pin>, ) -> Result where - T: AsyncRead + AsyncWrite + Unpin + Send + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, E: std::error::Error + Send + Sync + 'static, { let stream = stream_fut.await.map_err(|e| Error::Other(Box::new(e)))?; diff --git a/crates/obfs4/src/framing/codecs.rs b/crates/obfs4/src/framing/codecs.rs index 261fca0..6d0244f 100644 --- a/crates/obfs4/src/framing/codecs.rs +++ b/crates/obfs4/src/framing/codecs.rs @@ -75,6 +75,7 @@ impl EncryptingCodec { (self.encoder, self.decoder) } + #[allow(unused)] pub(crate) fn from_parts( e: EncryptingEncoder, d: EncryptingDecoder, @@ -90,7 +91,7 @@ impl EncryptingCodec { } ///Decoder is a frame decoder instance. -struct EncryptingDecoder { +pub(crate) struct EncryptingDecoder { key: [u8; KEY_LENGTH], nonce: NonceBox, drbg: Drbg, diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index 9af475f..ab45588 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -45,9 +45,6 @@ pub(crate) enum MaybeTimeout { Unset, } -type MsgStream = Box> + Send + Unpin>; -type BytesSink = Box + Send + Unpin>; - impl std::str::FromStr for IAT { type Err = Error; fn from_str(s: &str) -> StdResult { @@ -88,10 +85,8 @@ impl MaybeTimeout { } } -pub trait SinkStream { - type Item; - type Error; -} +type MsgStream = Box> + Send + Unpin>; +type BytesSink = Box + Send + Unpin>; #[pin_project] pub struct Obfs4Stream { @@ -125,14 +120,14 @@ pub(crate) struct O4Stream { } impl O4Stream { - pub(crate) fn new<'a, T>( + pub(crate) fn new( // inner: &'a mut dyn Stream<'a>, inner: T, codec: Obfs4Codec, mut session: Session, ) -> Self where - T: AsyncRead + AsyncWrite + Unpin + Send + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let delay_fn = match session.get_iat_mode() { IAT::Off => || Duration::ZERO, diff --git a/crates/obfs4/src/server.rs b/crates/obfs4/src/server.rs index 8f9a171..f8e70ca 100644 --- a/crates/obfs4/src/server.rs +++ b/crates/obfs4/src/server.rs @@ -312,7 +312,7 @@ impl Server { pub async fn wrap<'a, T>(self, mut stream: T) -> Result where - T: AsyncRead + AsyncWrite + Unpin + Send + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let session = self.new_server_session()?; let deadline = self.handshake_timeout.map(|d| Instant::now() + d); diff --git a/crates/obfs4/src/sessions.rs b/crates/obfs4/src/sessions.rs index ad68a2b..a348a71 100644 --- a/crates/obfs4/src/sessions.rs +++ b/crates/obfs4/src/sessions.rs @@ -183,7 +183,7 @@ impl ClientSession { /// - etc. pub async fn handshake<'a, T>(self, mut stream: T, deadline: Option) -> Result where - T: AsyncRead + AsyncWrite + Unpin + Send + 'a, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { // set up for handshake let mut session = self.transition(ClientHandshaking {}); @@ -327,17 +327,17 @@ impl std::fmt::Debug for ClientSession { // T: AsyncRead + AsyncWrite + Unpin + Send + 'static, // { // // set up for handshake -// +// // let materials = SHSMaterials::new( // &session.identity_keys, // session.session_id(), // session.len_seed.to_bytes(), // ); -// +// // // default deadline // let d_def = Instant::now() + SERVER_HANDSHAKE_TIMEOUT; // let handshake_fut = server.complete_handshake(&mut stream, materials, deadline); -// +// // } // } diff --git a/crates/obfs4/src/testing.rs b/crates/obfs4/src/testing.rs index d2997c9..85736ad 100644 --- a/crates/obfs4/src/testing.rs +++ b/crates/obfs4/src/testing.rs @@ -9,19 +9,19 @@ use std::time::Duration; #[tokio::test] async fn public_handshake() -> Result<()> { init_subscriber(); - let (mut c, mut s) = tokio::io::duplex(65_536); + let (c, s) = tokio::io::duplex(65_536); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let _ = tokio::io::split(o4s_stream); }); let o4_client = client_config.build(); - let _o4c_stream = o4_client.wrap(&mut c).await?; + let _o4c_stream = o4_client.wrap(c).await?; Ok(()) } @@ -30,14 +30,14 @@ async fn public_handshake() -> Result<()> { async fn public_iface() -> Result<()> { init_subscriber(); let message = b"awoewaeojawenwaefaw lfawn;awe da;wfenalw fawf aw"; - let (mut c, mut s) = tokio::io::duplex(65_536); + let (c, s) = tokio::io::duplex(65_536); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let mut o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let mut o4s_stream = o4_server.wrap(s).await.unwrap(); // let (mut r, mut w) = tokio::io::split(o4s_stream); // tokio::io::copy(&mut r, &mut w).await.unwrap(); @@ -52,7 +52,7 @@ async fn public_iface() -> Result<()> { }); let o4_client = client_config.build(); - let mut o4c_stream = o4_client.wrap(&mut c).await?; + let mut o4c_stream = o4_client.wrap(c).await?; o4c_stream.write_all(&message[..]).await?; o4c_stream.flush().await?; @@ -76,14 +76,14 @@ async fn public_iface() -> Result<()> { async fn transfer_10k_x1() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -129,13 +129,13 @@ async fn transfer_10k_x1() -> Result<()> { async fn transfer_10k_x3() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let o4_server = Server::getrandom(); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -183,14 +183,14 @@ async fn transfer_10k_x3() -> Result<()> { async fn transfer_1M_1024x1024() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -236,14 +236,14 @@ async fn transfer_1M_1024x1024() -> Result<()> { async fn transfer_512k_x1() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 512); + let (c, s) = tokio::io::duplex(1024 * 512); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); @@ -290,14 +290,14 @@ async fn transfer_512k_x1() -> Result<()> { async fn transfer_2_x() -> Result<()> { init_subscriber(); - let (c, mut s) = tokio::io::duplex(1024 * 1000); + let (c, s) = tokio::io::duplex(1024 * 1000); let mut rng = rand::thread_rng(); let o4_server = Server::new_from_random(&mut rng); let client_config = o4_server.client_params(); tokio::spawn(async move { - let o4s_stream = o4_server.wrap(&mut s).await.unwrap(); + let o4s_stream = o4_server.wrap(s).await.unwrap(); let (mut r, mut w) = tokio::io::split(o4s_stream); tokio::io::copy(&mut r, &mut w).await.unwrap(); }); From 51b50bd899bebfecbfbc05a7ed7c0a13dc5037aa Mon Sep 17 00:00:00 2001 From: jmwample <8297368+jmwample@users.noreply.github.com> Date: Sat, 6 Jul 2024 15:07:24 -0600 Subject: [PATCH 9/9] build, some tests still failing --- crates/obfs4/src/common/delay/mod.rs | 4 +-- crates/obfs4/src/framing/codecs.rs | 2 +- crates/obfs4/src/framing/testing.rs | 2 +- crates/obfs4/src/proto.rs | 44 +++++++++++++++------------- crates/obfs4/src/pt.rs | 11 +++---- crates/obfs4/src/server.rs | 17 ++++------- crates/obfs4/src/sessions.rs | 11 +++---- 7 files changed, 43 insertions(+), 48 deletions(-) diff --git a/crates/obfs4/src/common/delay/mod.rs b/crates/obfs4/src/common/delay/mod.rs index b66acc1..6ad6b5a 100644 --- a/crates/obfs4/src/common/delay/mod.rs +++ b/crates/obfs4/src/common/delay/mod.rs @@ -55,9 +55,7 @@ where fn start_send(self: Pin<&mut Self>, item: J) -> Result<(), Self::Error> { let s = self.project(); - if let Err(e) = s.sink.as_mut().start_send(item.into()) { - return Err(e); - } + s.sink.as_mut().start_send(item.into())?; let delay = (*s.delay_fn)(); diff --git a/crates/obfs4/src/framing/codecs.rs b/crates/obfs4/src/framing/codecs.rs index 6d0244f..788fee1 100644 --- a/crates/obfs4/src/framing/codecs.rs +++ b/crates/obfs4/src/framing/codecs.rs @@ -71,7 +71,7 @@ impl EncryptingCodec { self.handshake_complete = true; } - pub(crate) fn to_parts(self) -> (EncryptingEncoder, EncryptingDecoder) { + pub(crate) fn into_parts(self) -> (EncryptingEncoder, EncryptingDecoder) { (self.encoder, self.decoder) } diff --git a/crates/obfs4/src/framing/testing.rs b/crates/obfs4/src/framing/testing.rs index 9f163e4..9bbb8fd 100644 --- a/crates/obfs4/src/framing/testing.rs +++ b/crates/obfs4/src/framing/testing.rs @@ -10,7 +10,7 @@ /// use super::*; use crate::test_utils::init_subscriber; -use crate::{Result, Error}; +use crate::{Error, Result}; use bytes::{Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; diff --git a/crates/obfs4/src/proto.rs b/crates/obfs4/src/proto.rs index ab45588..98bb806 100644 --- a/crates/obfs4/src/proto.rs +++ b/crates/obfs4/src/proto.rs @@ -11,8 +11,8 @@ use crate::{ use bytes::{Buf, BytesMut}; use futures::{ - sink::Sink, - stream::Stream, // StreamExt}, + sink::{Sink, SinkExt}, + stream::{Stream, StreamExt}, }; use pin_project::pin_project; use ptrs::trace; @@ -135,7 +135,7 @@ impl O4Stream { }; let (r, w) = tokio::io::split(inner); - let (e, d) = codec.to_parts(); + let (e, d) = codec.into_parts(); let encoding_sink = FramedWrite::new(w, e); let sink = Box::new(delay::DelayedSink::< FramedWrite, EncryptingEncoder>, @@ -187,9 +187,8 @@ impl AsyncWrite for O4Stream { let mut this = self.as_mut().project(); // determine if the stream is ready to send an event? - match futures::Sink::::poll_ready(this.sink.as_mut(), cx) { - Poll::Pending => return Poll::Pending, - _ => {} + if futures::Sink::::poll_ready(this.sink.as_mut(), cx).is_pending() { + return Poll::Pending } // while we have bytes in the buffer write MAX_MESSAGE_PAYLOAD_LENGTH @@ -210,9 +209,8 @@ impl AsyncWrite for O4Stream { len_sent += framing::MAX_MESSAGE_PAYLOAD_LENGTH; // determine if the stream is ready to send more data. if not back off - match futures::Sink::::poll_ready(this.sink.as_mut(), cx) { - Poll::Pending => return Poll::Ready(Ok(len_sent)), - _ => {} + if futures::Sink::::poll_ready(this.sink.as_mut(), cx).is_pending() { + return Poll::Ready(Ok(len_sent)) } } @@ -332,31 +330,37 @@ impl AsyncRead for Obfs4Stream { } } +// impl AsRef> for O4Stream { +// fn as_ref(&self) -> &dyn Sink { +// self.sink.as_ref() +// } +// } + impl Sink for O4Stream { type Error = Error; - fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - todo!(); + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_ready_unpin(cx) } - fn start_send(self: Pin<&mut Self>, _item: BytesMut) -> StdResult<(), Self::Error> { - todo!(); + fn start_send(self: Pin<&mut Self>, item: BytesMut) -> StdResult<(), Self::Error> { + self.project().sink.start_send_unpin(item) } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - todo!(); + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_flush_unpin(cx) } - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - todo!(); + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().sink.poll_close_unpin(cx) } } impl Stream for O4Stream { - type Item = Messages; + type Item = Result; // Required method - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - todo!(); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_next_unpin(cx) } } // diff --git a/crates/obfs4/src/pt.rs b/crates/obfs4/src/pt.rs index cfc6b65..7b2d181 100644 --- a/crates/obfs4/src/pt.rs +++ b/crates/obfs4/src/pt.rs @@ -2,8 +2,7 @@ use crate::{ constants::*, handshake::Obfs4NtorPublicKey, proto::{Obfs4Stream, IAT}, - Error, OBFS4_NAME, - Client, ClientBuilder, Server, ServerBuilder, + Client, ClientBuilder, Error, Server, ServerBuilder, OBFS4_NAME, }; use ptrs::{args::Args, FutureResult as F}; @@ -224,7 +223,7 @@ where } fn wrap(self, io: InRW) -> Pin> { - Box::pin(Client::wrap(self, io).map_err(|e| e.into())) + Box::pin(Client::wrap(self, io).map_err(|e| e)) } fn method_name() -> String { @@ -266,12 +265,10 @@ mod test { let cb_name = >::method_name(); assert_eq!(cb_name, Obfs4PT::NAME); - let sb_name = - as ptrs::ServerBuilder>::method_name(); + let sb_name = as ptrs::ServerBuilder>::method_name(); assert_eq!(sb_name, Obfs4PT::NAME); - let ct_name = - >::method_name(); + let ct_name = >::method_name(); assert_eq!(ct_name, Obfs4PT::NAME); let st_name = >::method_name(); diff --git a/crates/obfs4/src/server.rs b/crates/obfs4/src/server.rs index f8e70ca..8762c90 100644 --- a/crates/obfs4/src/server.rs +++ b/crates/obfs4/src/server.rs @@ -4,24 +4,24 @@ use super::*; use crate::{ client::ClientBuilder, common::{ - colorize, drbg, discard, + colorize, discard, drbg, + ntor_arti::{RelayHandshakeError, ServerHandshake}, replay_filter::{self, ReplayFilter}, x25519_elligator2::{PublicKey, StaticSecret}, HmacSha256, - ntor_arti::{ServerHandshake, RelayHandshakeError}, }, constants::*, framing::{FrameError, Marshall, Obfs4Codec, TryParse, KEY_LENGTH}, - handshake::{Obfs4NtorPublicKey, Obfs4NtorSecretKey, Obfs4Keygen, SHSMaterials}, + handshake::{Obfs4Keygen, Obfs4NtorPublicKey, Obfs4NtorSecretKey, SHSMaterials}, proto::{MaybeTimeout, O4Stream, Obfs4Stream, IAT}, - sessions::{Session, Initialized, Established, Fault}, + sessions::{Established, Fault, Initialized, Session}, Error, Result, }; use ptrs::args::Args; use ptrs::{debug, info, trace}; -use std::{borrow::BorrowMut, marker::PhantomData, ops::Deref, str::FromStr, sync::Arc}; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; +use std::{borrow::BorrowMut, marker::PhantomData, ops::Deref, str::FromStr, sync::Arc}; use bytes::{Buf, BufMut, Bytes}; use hex::FromHex; @@ -309,7 +309,6 @@ impl Server { // Server Handshake // // ====================================================================== // - pub async fn wrap<'a, T>(self, mut stream: T) -> Result where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -391,9 +390,7 @@ impl Server { } } - pub(crate) fn new_server_session( - &self, - ) -> Result> { + pub(crate) fn new_server_session(&self) -> Result> { let mut session_id = [0u8; SESSION_ID_LEN]; rand::thread_rng().fill_bytes(&mut session_id); Ok(ServerSession { @@ -459,7 +456,6 @@ impl Server { }; } } - } // ================================================================ // @@ -545,7 +541,6 @@ impl ServerSession { } } - #[cfg(test)] mod tests { use crate::dev; diff --git a/crates/obfs4/src/sessions.rs b/crates/obfs4/src/sessions.rs index a348a71..09c4c31 100644 --- a/crates/obfs4/src/sessions.rs +++ b/crates/obfs4/src/sessions.rs @@ -9,9 +9,7 @@ use crate::{ }, constants::*, framing, - handshake::{ - CHSMaterials, Obfs4Keygen, Obfs4NtorHandshake, Obfs4NtorPublicKey, - }, + handshake::{CHSMaterials, Obfs4Keygen, Obfs4NtorHandshake, Obfs4NtorPublicKey}, proto::{O4Stream, Obfs4Stream, IAT}, server::ServerSession, Error, Result, @@ -181,7 +179,11 @@ impl ClientSession { /// TODO: make sure failure modes align with golang obfs4 /// - FIN/RST based on buffered data. /// - etc. - pub async fn handshake<'a, T>(self, mut stream: T, deadline: Option) -> Result + pub async fn handshake<'a, T>( + self, + mut stream: T, + deadline: Option, + ) -> Result where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -340,4 +342,3 @@ impl std::fmt::Debug for ClientSession { // // } // } -