Skip to content

Commit 4125690

Browse files
committed
Fix stackoverflow with large pages in paginator
Previously, signaling onNext() to the subscriber was done via recursion, pulling elements from an iterator over the current page returned by the service. However, this can quickly lead to a stack overflow error, since the stack grows linearly with the size of the page. - Replace sendNextElement recursion with a loop - Ensure that handleRequests does not recurse into itself
1 parent 82dd743 commit 4125690

File tree

3 files changed

+318
-44
lines changed

3 files changed

+318
-44
lines changed

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/pagination/async/ItemsSubscription.java

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
package software.amazon.awssdk.core.internal.pagination.async;
1717

1818
import java.util.Iterator;
19+
import java.util.concurrent.CompletableFuture;
20+
import java.util.concurrent.atomic.AtomicBoolean;
1921
import java.util.function.Function;
2022
import org.reactivestreams.Subscription;
2123
import software.amazon.awssdk.annotations.SdkInternalApi;
@@ -32,6 +34,8 @@
3234
public final class ItemsSubscription<ResponseT, ItemT> extends PaginationSubscription<ResponseT> {
3335
private final Function<ResponseT, Iterator<ItemT>> getIteratorFunction;
3436
private volatile Iterator<ItemT> singlePageItemsIterator;
37+
private final AtomicBoolean handlingRequests = new AtomicBoolean();
38+
private volatile boolean awaitingNewPage = false;
3539

3640
private ItemsSubscription(BuilderImpl builder) {
3741
super(builder);
@@ -47,61 +51,83 @@ public static Builder builder() {
4751

4852
@Override
4953
protected void handleRequests() {
50-
if (!hasMoreItems() && !hasNextPage()) {
51-
completeSubscription();
54+
// Prevent recursion if we already invoked handleRequests
55+
if (!handlingRequests.compareAndSet(false, true)) {
5256
return;
5357
}
5458

55-
synchronized (this) {
56-
if (outstandingRequests.get() <= 0) {
57-
stopTask();
58-
return;
59-
}
60-
}
61-
62-
if (!isTerminated()) {
63-
/**
64-
* Current page is null only the first time the method is called.
65-
* Once initialized, current page will never be null
66-
*/
67-
if (currentPage == null || (!hasMoreItems() && hasNextPage())) {
68-
fetchNextPage();
69-
70-
} else if (hasMoreItems()) {
71-
sendNextElement();
72-
73-
// All valid cases are covered above. Throw an exception if any combination is missed
74-
} else {
75-
throw new IllegalStateException("Execution should have not reached here");
59+
try {
60+
while (true) {
61+
if (!hasMoreItems() && !hasNextPage()) {
62+
completeSubscription();
63+
return;
64+
}
65+
66+
synchronized (this) {
67+
if (outstandingRequests.get() <= 0) {
68+
stopTask();
69+
return;
70+
}
71+
}
72+
73+
if (isTerminated()) {
74+
return;
75+
}
76+
77+
if (shouldFetchNextPage()) {
78+
awaitingNewPage = true;
79+
fetchNextPage().whenComplete((r, e) -> {
80+
if (e == null) {
81+
awaitingNewPage = false;
82+
handleRequests();
83+
}
84+
// note: signaling onError if e != null is taken care of by fetchNextPage(). No need to do it here.
85+
});
86+
} else if (hasMoreItems()) {
87+
synchronized (this) {
88+
if (outstandingRequests.get() <= 0) {
89+
continue;
90+
}
91+
92+
subscriber.onNext(singlePageItemsIterator.next());
93+
outstandingRequests.getAndDecrement();
94+
}
95+
} else {
96+
// Outstanding demand AND no items in current page AND waiting for next page. Just return for now, and
97+
// we'll handle demand when the new page arrives.
98+
return;
99+
}
76100
}
101+
} finally {
102+
handlingRequests.set(false);
77103
}
78104
}
79105

80-
private void fetchNextPage() {
81-
nextPageFetcher.nextPage(currentPage)
82-
.whenComplete(((response, error) -> {
83-
if (response != null) {
84-
currentPage = response;
85-
singlePageItemsIterator = getIteratorFunction.apply(response);
86-
sendNextElement();
87-
}
88-
if (error != null) {
89-
subscriber.onError(error);
90-
cleanup();
91-
}
92-
}));
106+
private CompletableFuture<ResponseT> fetchNextPage() {
107+
return nextPageFetcher.nextPage(currentPage)
108+
.whenComplete((response, error) -> {
109+
if (response != null) {
110+
currentPage = response;
111+
singlePageItemsIterator = getIteratorFunction.apply(response);
112+
} else if (error != null) {
113+
subscriber.onError(error);
114+
cleanup();
115+
}
116+
});
93117
}
94118

95-
/**
96-
* Calls onNext and calls the recursive method.
97-
*/
98-
private void sendNextElement() {
99-
if (singlePageItemsIterator.hasNext()) {
100-
subscriber.onNext(singlePageItemsIterator.next());
101-
outstandingRequests.getAndDecrement();
119+
// Conditions when to fetch the next page:
120+
// - We're NOT already waiting for a new page AND
121+
// - We still need to fetch the first page OR
122+
// - We've exhausted the current page AND there is a next page available
123+
private boolean shouldFetchNextPage() {
124+
if (awaitingNewPage) {
125+
return false;
102126
}
103127

104-
handleRequests();
128+
// Current page is null only the first time the method is called.
129+
// Once initialized, current page will never be null.
130+
return currentPage == null || (!hasMoreItems() && hasNextPage());
105131
}
106132

107133
private boolean hasMoreItems() {
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package software.amazon.awssdk.core.pagination.async;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.mockito.Mockito.mock;
5+
import static org.mockito.Mockito.when;
6+
7+
import java.util.Iterator;
8+
import java.util.concurrent.CompletableFuture;
9+
import java.util.concurrent.TimeUnit;
10+
import java.util.concurrent.atomic.AtomicLong;
11+
import java.util.function.Function;
12+
import org.junit.jupiter.api.Test;
13+
import org.junit.jupiter.api.Timeout;
14+
import software.amazon.awssdk.core.SdkResponse;
15+
16+
public class PaginatedItemsPublisherTest {
17+
@Test
18+
@Timeout(value = 1, unit = TimeUnit.MINUTES)
19+
public void subscribe_largePage_doesNotFail() throws Exception {
20+
int nItems = 100_000;
21+
22+
Function<SdkResponse, Iterator<String>> iteratorFn = resp ->
23+
new Iterator<String>() {
24+
private int count = 0;
25+
26+
@Override
27+
public boolean hasNext() {
28+
return count < nItems;
29+
}
30+
31+
@Override
32+
public String next() {
33+
++count;
34+
return "item";
35+
}
36+
};
37+
38+
AsyncPageFetcher<SdkResponse> pageFetcher = new AsyncPageFetcher<SdkResponse>() {
39+
@Override
40+
public boolean hasNextPage(SdkResponse oldPage) {
41+
return false;
42+
}
43+
44+
@Override
45+
public CompletableFuture<SdkResponse> nextPage(SdkResponse oldPage) {
46+
return CompletableFuture.completedFuture(mock(SdkResponse.class));
47+
}
48+
};
49+
50+
PaginatedItemsPublisher<SdkResponse, String> publisher = PaginatedItemsPublisher.builder()
51+
.isLastPage(false)
52+
.nextPageFetcher(pageFetcher)
53+
.iteratorFunction(iteratorFn)
54+
.build();
55+
56+
AtomicLong counter = new AtomicLong();
57+
publisher.subscribe(i -> counter.incrementAndGet()).join();
58+
assertThat(counter.get()).isEqualTo(nItems);
59+
}
60+
61+
@Test
62+
@Timeout(value = 1, unit = TimeUnit.MINUTES)
63+
public void subscribe_longStream_doesNotFail() throws Exception {
64+
int nPages = 100_000;
65+
int nItemsPerPage = 1;
66+
Function<SdkResponse, Iterator<String>> iteratorFn = resp ->
67+
new Iterator<String>() {
68+
private int count = 0;
69+
70+
@Override
71+
public boolean hasNext() {
72+
return count < nItemsPerPage;
73+
}
74+
75+
@Override
76+
public String next() {
77+
++count;
78+
return "item";
79+
}
80+
};
81+
82+
AsyncPageFetcher<TestResponse> pageFetcher = new AsyncPageFetcher<TestResponse>() {
83+
@Override
84+
public boolean hasNextPage(TestResponse oldPage) {
85+
return oldPage.pageNumber() < nPages - 1;
86+
}
87+
88+
@Override
89+
public CompletableFuture<TestResponse> nextPage(TestResponse oldPage) {
90+
int nextPageNum;
91+
if (oldPage == null) {
92+
nextPageNum = 0;
93+
} else {
94+
nextPageNum = oldPage.pageNumber() + 1;
95+
}
96+
return CompletableFuture.completedFuture(createResponse(nextPageNum));
97+
}
98+
};
99+
100+
PaginatedItemsPublisher<SdkResponse, String> publisher = PaginatedItemsPublisher.builder()
101+
.isLastPage(false)
102+
.nextPageFetcher(pageFetcher)
103+
.iteratorFunction(iteratorFn)
104+
.build();
105+
106+
AtomicLong counter = new AtomicLong();
107+
publisher.subscribe(i -> counter.incrementAndGet()).join();
108+
assertThat(counter.get()).isEqualTo(nPages * nItemsPerPage);
109+
}
110+
111+
private abstract class TestResponse extends SdkResponse {
112+
113+
protected TestResponse(Builder builder) {
114+
super(builder);
115+
}
116+
117+
abstract Integer pageNumber();
118+
}
119+
120+
private static TestResponse createResponse(Integer pageNumber) {
121+
TestResponse mock = mock(TestResponse.class);
122+
when(mock.pageNumber()).thenReturn(pageNumber);
123+
return mock;
124+
}
125+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.dynamodb;
17+
18+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
19+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
20+
import static org.assertj.core.api.Assertions.assertThat;
21+
22+
import com.fasterxml.jackson.databind.JsonNode;
23+
import com.fasterxml.jackson.databind.ObjectMapper;
24+
import com.fasterxml.jackson.databind.node.ArrayNode;
25+
import com.fasterxml.jackson.databind.node.ObjectNode;
26+
import com.github.tomakehurst.wiremock.WireMockServer;
27+
import com.github.tomakehurst.wiremock.client.WireMock;
28+
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
29+
import java.net.URI;
30+
import java.util.concurrent.atomic.AtomicLong;
31+
import org.junit.jupiter.api.AfterAll;
32+
import org.junit.jupiter.api.BeforeAll;
33+
import org.junit.jupiter.api.Test;
34+
import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider;
35+
import software.amazon.awssdk.regions.Region;
36+
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
37+
import software.amazon.awssdk.services.dynamodb.paginators.ScanIterable;
38+
import software.amazon.awssdk.services.dynamodb.paginators.ScanPublisher;
39+
40+
public class PaginatorTest {
41+
private static final WireMockServer wireMock = new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort());
42+
private static final ObjectMapper mapper = new ObjectMapper();
43+
private static DynamoDbAsyncClient ddbAsync;
44+
private static DynamoDbClient ddb;
45+
46+
@BeforeAll
47+
public static void setup() {
48+
wireMock.start();
49+
50+
ddbAsync = DynamoDbAsyncClient.builder()
51+
.region(Region.US_WEST_2)
52+
.endpointOverride(URI.create("http://localhost:" + wireMock.port()))
53+
.credentialsProvider(AnonymousCredentialsProvider.create())
54+
.build();
55+
56+
ddb = DynamoDbClient.builder()
57+
.region(Region.US_WEST_2)
58+
.endpointOverride(URI.create("http://localhost:" + wireMock.port()))
59+
.credentialsProvider(AnonymousCredentialsProvider.create())
60+
.build();
61+
}
62+
63+
@AfterAll
64+
public static void teardown() {
65+
ddb.close();
66+
ddbAsync.close();
67+
wireMock.stop();
68+
}
69+
70+
@Test
71+
public void scanPaginator_async_largePage_subscribe_succeeds() {
72+
int nItems = 10_000;
73+
wireMock.stubFor(WireMock.any(anyUrl())
74+
.willReturn(aResponse()
75+
.withStatus(200)
76+
.withJsonBody(createScanResponse(nItems))));
77+
78+
ScanPublisher publisher = ddbAsync.scanPaginator(ScanRequest.builder().build());
79+
80+
AtomicLong counter = new AtomicLong();
81+
publisher.items().subscribe(item -> counter.incrementAndGet()).join();
82+
assertThat(counter.get()).isEqualTo(nItems);
83+
}
84+
85+
@Test
86+
public void scanPaginator_sync_largePage_subscribe_succeeds() {
87+
int nItems = 10_000;
88+
wireMock.stubFor(WireMock.any(anyUrl())
89+
.willReturn(aResponse()
90+
.withStatus(200)
91+
.withJsonBody(createScanResponse(nItems))));
92+
93+
ScanIterable iterable = ddb.scanPaginator(ScanRequest.builder().build());
94+
95+
AtomicLong counter = new AtomicLong();
96+
iterable.items().forEach(item -> counter.incrementAndGet());
97+
assertThat(counter.get()).isEqualTo(nItems);
98+
}
99+
100+
private static JsonNode createScanResponse(int nItems) {
101+
ObjectNode resp = mapper.createObjectNode();
102+
resp.set("Count", mapper.valueToTree(nItems));
103+
104+
ArrayNode items = mapper.createArrayNode();
105+
106+
for (int i = 0; i < nItems; i++) {
107+
// {
108+
// "id": {
109+
// "N": 1
110+
// }
111+
// }
112+
ObjectNode item = mapper.createObjectNode();
113+
ObjectNode idNode = mapper.createObjectNode();
114+
idNode.put("N", mapper.valueToTree(i));
115+
item.set("id", idNode);
116+
items.add(item);
117+
}
118+
119+
resp.set("Items", items);
120+
121+
return resp;
122+
}
123+
}

0 commit comments

Comments
 (0)