@@ -13,6 +13,7 @@ use sqlpage::webserver::http::create_app;
1313use std:: collections:: HashMap ;
1414use std:: sync:: { Arc , Mutex } ;
1515use std:: time:: Duration ;
16+ use tokio:: sync:: Notify ;
1617use tokio_util:: sync:: { CancellationToken , DropGuard } ;
1718
1819fn base64url_encode ( data : & [ u8 ] ) -> String {
@@ -52,6 +53,7 @@ struct ProviderState<'a> {
5253 auth_codes : HashMap < String , String > , // code -> nonce
5354 jwt_customizer : Option < Box < JwtCustomizer < ' a > > > ,
5455 token_endpoint_delay : Duration ,
56+ token_endpoint_started : Option < Arc < Notify > > ,
5557}
5658
5759type ProviderStateWithLifetime < ' a > = ProviderState < ' a > ;
@@ -145,6 +147,7 @@ async fn token_endpoint(
145147 . unwrap_or_else ( || make_jwt ( & claims, & state. secret ) ) ;
146148
147149 let delay = state. token_endpoint_delay ;
150+ let started = state. token_endpoint_started . clone ( ) ;
148151 drop ( state) ;
149152
150153 let response = TokenResponse {
@@ -156,6 +159,10 @@ async fn token_endpoint(
156159
157160 let json_bytes = serde_json:: to_vec ( & response) . unwrap ( ) ;
158161 let body = futures_util:: stream:: once ( async move {
162+ // Signal that HTTP headers have been sent and the body stream started.
163+ if let Some ( started) = started {
164+ started. notify_one ( ) ;
165+ }
159166 tokio:: time:: sleep ( delay) . await ;
160167 Ok :: < web:: Bytes , actix_web:: Error > ( web:: Bytes :: from ( json_bytes) )
161168 } ) ;
@@ -196,6 +203,7 @@ impl FakeOidcProvider {
196203 auth_codes : HashMap :: new ( ) ,
197204 jwt_customizer : None ,
198205 token_endpoint_delay : Duration :: ZERO ,
206+ token_endpoint_started : None ,
199207 } ) ) ;
200208
201209 let state_for_server = Arc :: clone ( & state) ;
@@ -237,8 +245,15 @@ impl FakeOidcProvider {
237245 f ( & mut state)
238246 }
239247
240- pub fn set_token_endpoint_delay ( & self , delay : Duration ) {
241- self . with_state_mut ( |s| s. token_endpoint_delay = delay) ;
248+ /// Set a delay on the token endpoint body and return a Notify that fires
249+ /// once the endpoint has sent response headers and started the body stream.
250+ pub fn set_token_endpoint_delay ( & self , delay : Duration ) -> Arc < Notify > {
251+ let started = Arc :: new ( Notify :: new ( ) ) ;
252+ self . with_state_mut ( |s| {
253+ s. token_endpoint_delay = delay;
254+ s. token_endpoint_started = Some ( started. clone ( ) ) ;
255+ } ) ;
256+ started
242257 }
243258
244259 pub fn store_auth_code ( & self , code : String , nonce : String ) {
@@ -571,7 +586,7 @@ async fn test_slow_token_endpoint_does_not_freeze_server() {
571586 let redirect_uri = get_query_param ( & auth_url, "redirect_uri" ) ;
572587 provider. store_auth_code ( "test_auth_code" . to_string ( ) , nonce) ;
573588
574- provider. set_token_endpoint_delay ( Duration :: from_secs ( 999 ) ) ;
589+ let body_started = provider. set_token_endpoint_delay ( Duration :: from_secs ( 999 ) ) ;
575590
576591 let callback_uri = format ! (
577592 "{}?code=test_auth_code&state={}" ,
@@ -589,10 +604,15 @@ async fn test_slow_token_endpoint_does_not_freeze_server() {
589604 test:: call_service ( & app, req. to_request ( ) ) . await
590605 } ) ;
591606
592- // Let the localhost TCP round trip complete in real time (microseconds).
593- for _ in 0 ..1000 {
594- tokio:: task:: yield_now ( ) . await ;
595- }
607+ // Wait until the token endpoint has sent HTTP headers and started the body
608+ // stream. At this point headers are in the TCP buffer but the awc client
609+ // may not have read them yet. Yield to let it process the I/O events and
610+ // enter response.body().await — this is deterministic because the data is
611+ // already available in the socket buffer.
612+ body_started. notified ( ) . await ;
613+ tokio:: task:: yield_now ( ) . await ;
614+ tokio:: task:: yield_now ( ) . await ;
615+ tokio:: task:: yield_now ( ) . await ;
596616
597617 // Freeze time and advance past the body-read timeout. If one is set, the
598618 // request completes. If not, only the 999s endpoint delay would wake it.
0 commit comments