 use async_trait::async_trait;
 use datafusion::arrow::ipc::reader::StreamReader;
 use datafusion::common::stats::Precision;
+use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
 use std::any::Any;
 use std::collections::HashMap;
 use std::fmt::Debug;
@@ -38,12 +39,14 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::common::runtime::SpawnedTask;
 
 use datafusion::error::{DataFusionError, Result};
-use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use datafusion::physical_plan::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
 use datafusion::physical_plan::{
     ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
     PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
-use futures::{Stream, StreamExt, TryStreamExt};
+use futures::{ready, Stream, StreamExt, TryStreamExt};
 
 use crate::error::BallistaError;
 use datafusion::execution::context::TaskContext;
@@ -162,6 +165,7 @@ impl ExecutionPlan for ShuffleReaderExec {
         let max_message_size = config.ballista_grpc_client_max_message_size();
         let force_remote_read = config.ballista_shuffle_reader_force_remote_read();
         let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
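+        // Target number of rows for coalesced output batches, taken from the
+        // session configuration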
+        let batch_size = config.batch_size();
 
         if force_remote_read {
             debug!(
@@ -171,7 +175,7 @@ impl ExecutionPlan for ShuffleReaderExec {
         }
 
         log::debug!(
-            "ShuffleReaderExec::execute({task_id}) max_request_num: {max_request_num}, max_message_size: {max_message_size}"
+            "ShuffleReaderExec::execute({task_id}) max_request_num: {max_request_num}, max_message_size: {max_message_size}, batch_size: {batch_size}"
         );
         let mut partition_locations = HashMap::new();
         for p in &self.partition[partition] {
@@ -197,11 +201,22 @@ impl ExecutionPlan for ShuffleReaderExec {
             prefer_flight,
         );
 
-        let result = RecordBatchStreamAdapter::new(
-            Arc::new(self.schema.as_ref().clone()),
+        let input_stream = Box::pin(RecordBatchStreamAdapter::new(
+            self.schema.clone(),
             response_receiver.try_flatten(),
-        );
-        Ok(Box::pin(result))
+        ));
+
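+        // Wrap the input stream so that small shuffle batches are merged into
+        // batches of roughly `batch_size` rows before flowing downstream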
+        Ok(Box::pin(CoalescedShuffleReaderStream {
+            schema: self.schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(
+                self.schema.clone(),
+                batch_size,
+                None, // no fetch limit
+            ),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
+        }))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -558,6 +573,77 @@ async fn fetch_partition_object_store(
     ))
 }
 
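+/// Stream adapter that pulls batches from the shuffle reader's input stream
+/// and merges small batches into batches of approximately `batch_size` rows
+/// via DataFusion's `LimitedBatchCoalescer`, flushing any buffered rows once
+/// the upstream is exhausted.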
+struct CoalescedShuffleReaderStream {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    coalescer: LimitedBatchCoalescer,
+    completed: bool,
+    baseline_metrics: BaselineMetrics,
+}
+
+impl Stream for CoalescedShuffleReaderStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+        let _timer = elapsed_compute.timer();
+
+        loop {
+            // If a completed batch is already buffered, return it directly
+            if let Some(batch) = self.coalescer.next_completed_batch() {
+                self.baseline_metrics.record_output(batch.num_rows());
+                return Poll::Ready(Some(Ok(batch)));
+            }
+
+            // If the upstream has completed, this stream is completed too
+            if self.completed {
+                return Poll::Ready(None);
+            }
+
+            // Pull the next batch from upstream
+            match ready!(self.input.poll_next_unpin(cx)) {
+                // Upstream is exhausted: flush any remaining buffered batches
+                None => {
+                    self.completed = true;
+                    if let Err(e) = self.coalescer.finish() {
+                        return Poll::Ready(Some(Err(e)));
+                    }
+                }
+                // Upstream produced a batch: push it into the coalescer
+                Some(Ok(batch)) => {
+                    if batch.num_rows() > 0 {
+                        match self.coalescer.push_batch(batch) {
+                            // Push succeeded; keep pulling from upstream
+                            Ok(PushBatchStatus::Continue) => {
+                                continue;
+                            }
+                            // The fetch limit was reached: finish the coalescer
+                            // and mark this stream as completed
+                            Ok(PushBatchStatus::LimitReached) => {
+                                self.completed = true;
+                                if let Err(e) = self.coalescer.finish() {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            }
+                            Err(e) => return Poll::Ready(Some(Err(e))),
+                        }
+                    }
+                }
+                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for CoalescedShuffleReaderStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
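+// Note: the coalescer only controls the *target* output size. Batches that
+// are already large relative to `batch_size` may pass through unsplit, and
+// the final flushed batch may be smaller than the target.
+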
561647#[ cfg( test) ]
562648mod tests {
563649 use super :: * ;
@@ -1016,10 +1102,177 @@ mod tests {
         .unwrap()
     }
 
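+    /// Builds a test RecordBatch with `rows` rows matching `create_test_schema()`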
+    fn create_custom_test_batch(rows: usize) -> RecordBatch {
+        let schema = create_test_schema();
+
+        // 1. Create the number column (0, 1, 2, ..., rows-1)
+        let number_vec: Vec<u32> = (0..rows as u32).collect();
+        let number_array = UInt32Array::from(number_vec);
+
+        // 2. Create the string column ("s0", "s1", ..., "s{rows-1}");
+        //    just filler data, the content is not important
+        let string_vec: Vec<String> = (0..rows).map(|i| format!("s{}", i)).collect();
+        let string_array = StringArray::from(string_vec);
+
+        RecordBatch::try_new(schema, vec![Arc::new(number_array), Arc::new(string_array)])
+            .unwrap()
+    }
+
     fn create_test_schema() -> SchemaRef {
         Arc::new(Schema::new(vec![
             Field::new("number", DataType::UInt32, true),
             Field::new("str", DataType::Utf8, true),
         ]))
     }
+
+    use datafusion::physical_plan::memory::MemoryStream;
+
+    #[tokio::test]
+    async fn test_coalesce_stream_logic() -> Result<()> {
+        // 1. Create test data: 10 small batches of 3 rows each
+        let schema = create_test_schema();
+        let small_batch = create_test_batch();
+        let batches = vec![small_batch.clone(); 10];
+
+        // 2. Create a mock upstream (input) stream
+        let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // 3. Configure the coalescer with a target batch size of 10 rows
+        let target_batch_size = 10;
+
+        // 4. Manually build the CoalescedShuffleReaderStream
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        // 5. Execute the stream and collect the results
+        let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
+
+        // 6. Assertions
+        // Assert A: no rows are lost (30 rows in total)
+        let total_rows: usize = output_batches.iter().map(|b| b.num_rows()).sum();
+        assert_eq!(total_rows, 30);
+
+        // Assert B: the batch count is reduced (10 -> 3)
+        assert_eq!(output_batches.len(), 3);
+
+        // Assert C: each batch has the expected size (all should be 10 rows)
+        assert_eq!(output_batches[0].num_rows(), 10);
+        assert_eq!(output_batches[1].num_rows(), 10);
+        assert_eq!(output_batches[2].num_rows(), 10);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_coalesce_stream_remainder_flush() -> Result<()> {
+        let schema = create_test_schema();
+        // Create 10 small batches, each with 3 rows: 30 rows in total
+        let small_batch = create_test_batch();
+        let batches = vec![small_batch.clone(); 10];
+
+        let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // Target set to 100 rows. Because 30 < 100, the target can never be
+        // reached, so the 30 buffered rows must be flushed by the `finish()`
+        // mechanism at the end of the stream.
+        let target_batch_size = 100;
+
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
+
+        // Assertions
+        assert_eq!(output_batches.len(), 1); // Should produce exactly 1 batch
+        assert_eq!(output_batches[0].num_rows(), 30); // Containing all 30 rows
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_coalesce_stream_large_batch() -> Result<()> {
+        let schema = create_test_schema();
+
+        // 1. Create large batches (20 rows each)
+        let big_batch = create_custom_test_batch(20);
+        let batches = vec![big_batch.clone(); 10]; // 200 rows in total
+
+        let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // 2. Target set to a small size: 10 rows
+        let target_batch_size = 10;
+
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
+
+        // 3. Validation: the large batches should not be split but passed
+        //    through directly; the coalescer does not split a batch whose
+        //    size exceeds (max_batch_size / 2)
+        assert_eq!(output_batches.len(), 10);
+        assert_eq!(output_batches[0].num_rows(), 20);
+
+        Ok(())
+    }
+
+    use futures::stream;
+
+    #[tokio::test]
+    async fn test_coalesce_stream_error_propagation() -> Result<()> {
+        let schema = create_test_schema();
+        let small_batch = create_test_batch(); // 3 rows
+
+        // 1. Construct a stream that yields one batch and then an error
+        let batches = vec![
+            Ok(small_batch),
+            Err(DataFusionError::Execution(
+                "Network connection failed".to_string(),
+            )),
+        ];
+
+        // 2. Adapt it into a SendableRecordBatchStream
+        let stream = stream::iter(batches);
+        let input_stream =
+            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream));
+
+        // 3. Configure the coalescer
+        let target_batch_size = 10;
+
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        // 4. Execute the stream
+        let result = common::collect(Box::pin(coalesced_stream)).await;
+
+        // 5. Validation: the upstream error is propagated to the caller
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Network connection failed"));
+
+        Ok(())
+    }
 }