
Commit cf58271

impl
1 parent eab22e2 commit cf58271

File tree: 1 file changed

ballista/core/src/execution_plans/shuffle_reader.rs

Lines changed: 260 additions & 7 deletions
@@ -18,6 +18,7 @@
 use async_trait::async_trait;
 use datafusion::arrow::ipc::reader::StreamReader;
 use datafusion::common::stats::Precision;
+use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
 use std::any::Any;
 use std::collections::HashMap;
 use std::fmt::Debug;
@@ -38,12 +39,14 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::common::runtime::SpawnedTask;
 
 use datafusion::error::{DataFusionError, Result};
-use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use datafusion::physical_plan::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
 use datafusion::physical_plan::{
     ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
     PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
-use futures::{Stream, StreamExt, TryStreamExt};
+use futures::{ready, Stream, StreamExt, TryStreamExt};
 
 use crate::error::BallistaError;
 use datafusion::execution::context::TaskContext;
@@ -162,6 +165,7 @@ impl ExecutionPlan for ShuffleReaderExec {
         let max_message_size = config.ballista_grpc_client_max_message_size();
         let force_remote_read = config.ballista_shuffle_reader_force_remote_read();
         let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
+        let batch_size = config.batch_size();
 
         if force_remote_read {
             debug!(
@@ -171,7 +175,7 @@ impl ExecutionPlan for ShuffleReaderExec {
         }
 
         log::debug!(
-            "ShuffleReaderExec::execute({task_id}) max_request_num: {max_request_num}, max_message_size: {max_message_size}"
+            "ShuffleReaderExec::execute({task_id}) max_request_num: {max_request_num}, max_message_size: {max_message_size}, batch_size: {batch_size}"
         );
         let mut partition_locations = HashMap::new();
         for p in &self.partition[partition] {
@@ -197,11 +201,22 @@ impl ExecutionPlan for ShuffleReaderExec {
             prefer_flight,
         );
 
-        let result = RecordBatchStreamAdapter::new(
-            Arc::new(self.schema.as_ref().clone()),
+        let input_stream = Box::pin(RecordBatchStreamAdapter::new(
+            self.schema.clone(),
             response_receiver.try_flatten(),
-        );
-        Ok(Box::pin(result))
+        ));
+
+        Ok(Box::pin(CoalescedShuffleReaderStream {
+            schema: self.schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(
+                self.schema.clone(),
+                batch_size,
+                None, // No fetch limit
+            ),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
+        }))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
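
Note: the `batch_size` read above comes, assuming the usual DataFusion config plumbing, from the session-level `datafusion.execution.batch_size` setting. A minimal sketch of configuring it on a client, using only the standard DataFusion API (not Ballista-specific code):

use datafusion::prelude::{SessionConfig, SessionContext};

// Build a context whose batch_size setting bounds the number of rows per
// coalesced output batch produced by the shuffle reader change above.
fn context_with_batch_size(rows: usize) -> SessionContext {
    let config = SessionConfig::new().with_batch_size(rows);
    SessionContext::new_with_config(config)
}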
@@ -558,6 +573,77 @@ async fn fetch_partition_object_store(
     ))
 }
 
+struct CoalescedShuffleReaderStream {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    coalescer: LimitedBatchCoalescer,
+    completed: bool,
+    baseline_metrics: BaselineMetrics,
+}
+
+impl Stream for CoalescedShuffleReaderStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+        let _timer = elapsed_compute.timer();
+
+        loop {
+            // If there is already a completed batch ready, return it directly
+            if let Some(batch) = self.coalescer.next_completed_batch() {
+                self.baseline_metrics.record_output(batch.num_rows());
+                return Poll::Ready(Some(Ok(batch)));
+            }
+
+            // If the upstream is completed, then this stream is completed too
+            if self.completed {
+                return Poll::Ready(None);
+            }
+
+            // Pull from upstream
+            match ready!(self.input.poll_next_unpin(cx)) {
+                // If upstream is completed, flush the remaining buffered batches
+                None => {
+                    self.completed = true;
+                    if let Err(e) = self.coalescer.finish() {
+                        return Poll::Ready(Some(Err(e)));
+                    }
+                }
+                // If upstream is not completed, push the batch to the coalescer
+                Some(Ok(batch)) => {
+                    if batch.num_rows() > 0 {
+                        // Try to push to the coalescer
+                        match self.coalescer.push_batch(batch) {
+                            // If the push succeeds, keep pulling
+                            Ok(PushBatchStatus::Continue) => {
+                                continue;
+                            }
+                            // If the limit is reached, finish the coalescer and mark completed
+                            Ok(PushBatchStatus::LimitReached) => {
+                                self.completed = true;
+                                if let Err(e) = self.coalescer.finish() {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            }
+                            Err(e) => return Poll::Ready(Some(Err(e))),
+                        }
+                    }
+                }
+                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for CoalescedShuffleReaderStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
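
For readers unfamiliar with the coalescer contract that drives the `poll_next` loop above, here is a hypothetical, self-contained sketch of the underlying buffering idea. It is not DataFusion's `LimitedBatchCoalescer` (which additionally handles fetch limits and large-batch passthrough); `SimpleCoalescer` and its method names are invented for illustration:

use datafusion::arrow::compute::concat_batches;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::error::ArrowError;
use datafusion::arrow::record_batch::RecordBatch;

/// Hypothetical stand-in for the coalescer used above: buffer small
/// batches until `target` rows accumulate, then emit one larger batch.
struct SimpleCoalescer {
    schema: SchemaRef,
    target: usize,
    buffered: Vec<RecordBatch>,
    buffered_rows: usize,
}

impl SimpleCoalescer {
    fn new(schema: SchemaRef, target: usize) -> Self {
        Self {
            schema,
            target,
            buffered: Vec::new(),
            buffered_rows: 0,
        }
    }

    /// Buffer `batch`; once at least `target` rows are held, return them
    /// concatenated into a single batch (the role played by `push_batch`
    /// plus `next_completed_batch` in the stream above).
    fn push(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>, ArrowError> {
        self.buffered_rows += batch.num_rows();
        self.buffered.push(batch);
        if self.buffered_rows >= self.target {
            self.flush()
        } else {
            Ok(None)
        }
    }

    /// Emit whatever is still buffered; the end-of-stream step that
    /// corresponds to `finish()` above.
    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        if self.buffered.is_empty() {
            return Ok(None);
        }
        let out = concat_batches(&self.schema, &self.buffered)?;
        self.buffered.clear();
        self.buffered_rows = 0;
        Ok(Some(out))
    }
}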
@@ -1016,10 +1102,177 @@ mod tests {
             .unwrap()
     }
 
+    fn create_custom_test_batch(rows: usize) -> RecordBatch {
+        let schema = create_test_schema();
+
+        // 1. Create the number column (0, 1, 2, ..., rows-1)
+        let number_vec: Vec<u32> = (0..rows as u32).collect();
+        let number_array = UInt32Array::from(number_vec);
+
+        // 2. Create the string column ("s0", "s1", ..., "s{rows-1}")
+        // Just filler data; the content is not important
+        let string_vec: Vec<String> = (0..rows).map(|i| format!("s{}", i)).collect();
+        let string_array = StringArray::from(string_vec);
+
+        RecordBatch::try_new(schema, vec![Arc::new(number_array), Arc::new(string_array)])
+            .unwrap()
+    }
+
     fn create_test_schema() -> SchemaRef {
         Arc::new(Schema::new(vec![
             Field::new("number", DataType::UInt32, true),
             Field::new("str", DataType::Utf8, true),
         ]))
     }
+
+    use datafusion::physical_plan::memory::MemoryStream;
+
+    #[tokio::test]
+    async fn test_coalesce_stream_logic() -> Result<()> {
+        // 1. Create test data - 10 small batches, each with 3 rows
+        let schema = create_test_schema();
+        let small_batch = create_test_batch();
+        let batches = vec![small_batch.clone(); 10];
+
+        // 2. Create a mock upstream (input) stream
+        let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // 3. Configure the coalescer: target batch size of 10 rows
+        let target_batch_size = 10;
+
+        // 4. Manually build the CoalescedShuffleReaderStream
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        // 5. Execute the stream and collect the results
+        let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
+
+        // 6. Assertions
+        // Assert A: no data is lost (30 rows in total)
+        let total_rows: usize = output_batches.iter().map(|b| b.num_rows()).sum();
+        assert_eq!(total_rows, 30);
+
+        // Assert B: the batch count is reduced (10 -> 3)
+        assert_eq!(output_batches.len(), 3);
+
+        // Assert C: each batch has the correct size (all should be 10)
+        assert_eq!(output_batches[0].num_rows(), 10);
+        assert_eq!(output_batches[1].num_rows(), 10);
+        assert_eq!(output_batches[2].num_rows(), 10);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_coalesce_stream_remainder_flush() -> Result<()> {
+        let schema = create_test_schema();
+        // Create 10 small batches, each with 3 rows. 30 rows in total.
+        let small_batch = create_test_batch();
+        let batches = vec![small_batch.clone(); 10];
+
+        let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // Target set to 100 rows.
+        // Because 30 < 100, the target can never be reached, so the stream must
+        // rely on the `finish()` mechanism to flush these 30 rows at end of stream.
+        let target_batch_size = 100;
+
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
+
+        // Assertions
+        assert_eq!(output_batches.len(), 1); // Should produce exactly 1 batch
+        assert_eq!(output_batches[0].num_rows(), 30); // Should contain all 30 rows
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_coalesce_stream_large_batch() -> Result<()> {
+        let schema = create_test_schema();
+
+        // 1. Create large batches (20 rows each)
+        let big_batch = create_custom_test_batch(20);
+        let batches = vec![big_batch.clone(); 10]; // 200 rows in total
+
+        let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // 2. Target set to a small size, 10 rows
+        let target_batch_size = 10;
+
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
+
+        // 3. Validation: large batches are passed through directly, not split.
+        // The coalescer does not split a batch whose size > (max_batch_size / 2).
+        assert_eq!(output_batches.len(), 10);
+        assert_eq!(output_batches[0].num_rows(), 20);
+
+        Ok(())
+    }
+
+    use futures::stream;
+
+    #[tokio::test]
+    async fn test_coalesce_stream_error_propagation() -> Result<()> {
+        let schema = create_test_schema();
+        let small_batch = create_test_batch(); // 3 rows
+
+        // 1. Construct batch results where the second item is an error
+        let batches = vec![
+            Ok(small_batch),
+            Err(DataFusionError::Execution(
+                "Network connection failed".to_string(),
+            )),
+        ];
+
+        // 2. Build the input stream from them
+        let stream = stream::iter(batches);
+        let input_stream =
+            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream));
+
+        // 3. Configure the coalescer
+        let target_batch_size = 10;
+
+        let coalesced_stream = CoalescedShuffleReaderStream {
+            schema: schema.clone(),
+            input: input_stream,
+            coalescer: LimitedBatchCoalescer::new(schema, target_batch_size, None),
+            completed: false,
+            baseline_metrics: BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+        };
+
+        // 4. Execute the stream
+        let result = common::collect(Box::pin(coalesced_stream)).await;
+
+        // 5. Validation
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Network connection failed"));
+
+        Ok(())
+    }
 }
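
Since every emitted batch passes through `BaselineMetrics::record_output`, a follow-up test could assert the recorded metric as well. This is a hedged sketch, not part of this commit: it reuses the helpers already in this test module (`create_test_schema`, `create_test_batch`, `MemoryStream`, `common::collect`), and the `ExecutionPlanMetricsSet::clone_inner()` / `MetricsSet::output_rows()` accessors are assumed from standard DataFusion.

#[tokio::test]
async fn test_coalesce_stream_records_output_rows() -> Result<()> {
    let schema = create_test_schema();
    let batches = vec![create_test_batch(); 10]; // 30 rows in total

    let input_stream = MemoryStream::try_new(batches, schema.clone(), None)?;
    let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;

    // Keep a handle on the metrics set so it can be inspected afterwards.
    let metrics = ExecutionPlanMetricsSet::new();
    let coalesced_stream = CoalescedShuffleReaderStream {
        schema: schema.clone(),
        input: input_stream,
        coalescer: LimitedBatchCoalescer::new(schema, 10, None),
        completed: false,
        baseline_metrics: BaselineMetrics::new(&metrics, 0),
    };

    let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
    let total_rows: usize = output_batches.iter().map(|b| b.num_rows()).sum();

    // record_output() is called once per emitted batch, so the aggregated
    // output_rows metric should equal the collected row count (assumes
    // MetricsSet::output_rows(), as exposed by standard DataFusion).
    assert_eq!(metrics.clone_inner().output_rows(), Some(total_rows));
    Ok(())
}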
