Avoid OOM-killing query if large result-level cache population fails for query
Currently, result-level cache population that tries to buffer a query's results can grow its ByteArrayOutputStream past the Integer.MAX_VALUE capacity of a Java array. ByteArrayOutputStream surfaces that case as an OutOfMemoryError, which is not caught and terminates the node. This change caps the buffer used for storing query results at `CacheConfig.getResultLevelCacheLimit()`.
jtuglu-netflix committed Jan 22, 2025
1 parent a964220 commit 648faf1
Showing 3 changed files with 78 additions and 20 deletions.
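For orientation before the file-by-file diff: the old code enforced the cache-size limit only after an entry had been fully serialized into an unbounded ByteArrayOutputStream, so one huge result set could trigger the fatal allocation before the check ever ran. The sketch below condenses the new population path from the changes that follow; it is a hypothetical condensation, with populateSafely and resultBytes as illustrative stand-ins for state held by ResultLevelCachingQueryRunner, and a shortened error message.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.apache.druid.client.cache.Cache;
import org.apache.druid.client.cache.CacheConfig;
import org.apache.druid.io.LimitedOutputStream;
import org.apache.druid.java.util.common.StringUtils;

// Hypothetical condensation of the new population path (not part of the commit).
static void populateSafely(CacheConfig cacheConfig, Cache cache, Cache.NamedKey key, byte[] resultBytes)
{
  final LimitedOutputStream cacheObjectStream = new LimitedOutputStream(
      new ByteArrayOutputStream(),
      cacheConfig.getResultLevelCacheLimit(),   // hard cap in bytes, enforced on every write
      limit -> StringUtils.format("resultLevelCacheLimit[%,d] exceeded", limit)
  );
  try {
    cacheObjectStream.write(resultBytes);       // throws IOException once the cap is hit
  }
  catch (IOException e) {
    return;                                     // population abandoned; the query and the node keep running
  }
  // On success, unwrap the inner ByteArrayOutputStream and store its bytes.
  final byte[] cachedResults = ((ByteArrayOutputStream) cacheObjectStream.get()).toByteArray();
  cache.put(key, cachedResults);
}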
LimitedOutputStream.java
@@ -28,7 +28,7 @@

/**
* An {@link OutputStream} that limits how many bytes can be written. Throws {@link IOException} if the limit
- * is exceeded.
+ * is exceeded. *Not* thread-safe.
*/
public class LimitedOutputStream extends OutputStream
{
@@ -88,6 +88,11 @@ public void close() throws IOException
out.close();
}

+public OutputStream get()
+{
+  return out;
+}

private void plus(final int n) throws IOException
{
written += n;
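The hunk stops just after the byte-count update in plus(int). A hypothetical sketch of the rest of that method, assuming the constructor's message function is stored in a field named exceptionMessageFn and the cap in a field named limit (neither is shown in this diff):

private void plus(final int n) throws IOException
{
  written += n;
  // Hypothetical tail (not shown in this diff): surface a catchable IOException
  // once the running count passes the cap, per the class Javadoc.
  if (written > limit) {
    throw new IOException(exceptionMessageFn.apply(limit));
  }
}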
ResultLevelCachingQueryRunner.java
@@ -30,6 +30,7 @@
import org.apache.druid.client.cache.Cache;
import org.apache.druid.client.cache.Cache.NamedKey;
import org.apache.druid.client.cache.CacheConfig;
+import org.apache.druid.io.LimitedOutputStream;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Sequence;
@@ -152,6 +153,8 @@ public void after(boolean isDone, Throwable thrown)
// The resultset identifier and its length is cached along with the resultset
resultLevelCachePopulator.populateResults();
log.debug("Cache population complete for query %s", query.getId());
+} else { // thrown == null && !resultLevelCachePopulator.isShouldPopulate()
+  log.error("Failed (gracefully) to populate result level cache for query %s", query.getId());
}
resultLevelCachePopulator.stopPopulating();
}
@@ -233,8 +236,8 @@ private ResultLevelCachePopulator createResultLevelCachePopulator(
try {
// Save the resultSetId and its length
resultLevelCachePopulator.cacheObjectStream.write(ByteBuffer.allocate(Integer.BYTES)
-    .putInt(resultSetId.length())
-    .array());
+        .putInt(resultSetId.length())
+        .array());
resultLevelCachePopulator.cacheObjectStream.write(StringUtils.toUtf8(resultSetId));
}
catch (IOException ioe) {
@@ -255,7 +258,7 @@ private class ResultLevelCachePopulator
private final Cache.NamedKey key;
private final CacheConfig cacheConfig;
@Nullable
-private ByteArrayOutputStream cacheObjectStream;
+private LimitedOutputStream cacheObjectStream;

private ResultLevelCachePopulator(
Cache cache,
@@ -270,7 +273,14 @@ private ResultLevelCachePopulator(
this.serialiers = mapper.getSerializerProviderInstance();
this.key = key;
this.cacheConfig = cacheConfig;
-this.cacheObjectStream = shouldPopulate ? new ByteArrayOutputStream() : null;
+this.cacheObjectStream = shouldPopulate ? new LimitedOutputStream(
+    new ByteArrayOutputStream(),
+    cacheConfig.getResultLevelCacheLimit(),
+    limit -> StringUtils.format(
+        "resultLevelCacheLimit[%,d] exceeded. "
+        + "Max ResultLevelCacheLimit for cache exceeded. Result caching failed.",
+        limit
+    )
+) : null;
}

boolean isShouldPopulate()
@@ -289,12 +299,8 @@ private void cacheResultEntry(
)
{
Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream");
-  int cacheLimit = cacheConfig.getResultLevelCacheLimit();
try (JsonGenerator gen = mapper.getFactory().createGenerator(cacheObjectStream)) {
JacksonUtils.writeObjectUsingSerializerProvider(gen, serialiers, cacheFn.apply(resultEntry));
-  if (cacheLimit > 0 && cacheObjectStream.size() > cacheLimit) {
-    stopPopulating();
-  }
}
catch (IOException ex) {
log.error(ex, "Failed to retrieve entry to be cached. Result Level caching will not be performed!");
@@ -304,7 +310,8 @@

public void populateResults()
{
-byte[] cachedResults = Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream").toByteArray();
+byte[] cachedResults = ((ByteArrayOutputStream) Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream")
+    .get()).toByteArray();
cache.put(key, cachedResults);
}
}
ResultLevelCachingQueryRunnerTest.java
@@ -39,6 +39,7 @@
public class ResultLevelCachingQueryRunnerTest extends QueryRunnerBasedOnClusteredClientTestBase
{
private Cache cache;
+private static final int DEFAULT_CACHE_ENTRY_MAX_SIZE = Integer.MAX_VALUE;

@Before
public void setup()
@@ -58,7 +59,7 @@ public void testNotPopulateAndNotUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
-newCacheConfig(false, false),
+newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -72,7 +73,7 @@ public void testNotPopulateAndNotUse()
Assert.assertEquals(0, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
-newCacheConfig(false, false),
+newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -93,7 +94,7 @@ public void testPopulateAndNotUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
-newCacheConfig(true, false),
+newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -107,7 +108,7 @@ public void testPopulateAndNotUse()
Assert.assertEquals(0, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
-newCacheConfig(true, false),
+newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -128,7 +129,7 @@ public void testNotPopulateAndUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
-newCacheConfig(false, false),
+newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -142,7 +143,7 @@ public void testNotPopulateAndUse()
Assert.assertEquals(0, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
-newCacheConfig(false, true),
+newCacheConfig(false, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -163,7 +164,7 @@ public void testPopulateAndUse()
prepareCluster(10);
final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
-newCacheConfig(true, true),
+newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -177,7 +178,7 @@ public void testPopulateAndUse()
Assert.assertEquals(1, cache.getStats().getNumMisses());

final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
-newCacheConfig(true, true),
+newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -192,6 +193,41 @@ public void testPopulateAndUse()
Assert.assertEquals(1, cache.getStats().getNumMisses());
}

+@Test
+public void testNoPopulateIfEntrySizeExceedsMaximum()
+{
+  prepareCluster(10);
+  final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
+  final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
+      newCacheConfig(true, true, 128),
+      query
+  );
+
+  final Sequence<Result<TimeseriesResultValue>> sequence1 = queryRunner1.run(
+      QueryPlus.wrap(query),
+      responseContext()
+  );
+  final List<Result<TimeseriesResultValue>> results1 = sequence1.toList();
+  Assert.assertEquals(0, cache.getStats().getNumHits());
+  Assert.assertEquals(0, cache.getStats().getNumEntries());
+  Assert.assertEquals(1, cache.getStats().getNumMisses());
+
+  final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
+      newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
+      query
+  );
+
+  final Sequence<Result<TimeseriesResultValue>> sequence2 = queryRunner2.run(
+      QueryPlus.wrap(query),
+      responseContext()
+  );
+  final List<Result<TimeseriesResultValue>> results2 = sequence2.toList();
+  Assert.assertEquals(results1, results2);
+  Assert.assertEquals(0, cache.getStats().getNumHits());
+  Assert.assertEquals(1, cache.getStats().getNumEntries());
+  Assert.assertEquals(2, cache.getStats().getNumMisses());
+}

@Test
public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache()
{
@@ -206,7 +242,7 @@ public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache()

final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner = createQueryRunner(
-newCacheConfig(true, false),
+newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
query
);

@@ -249,7 +285,11 @@ private <T> ResultLevelCachingQueryRunner<T> createQueryRunner(
);
}

-private CacheConfig newCacheConfig(boolean populateResultLevelCache, boolean useResultLevelCache)
+private CacheConfig newCacheConfig(
+    boolean populateResultLevelCache,
+    boolean useResultLevelCache,
+    int resultLevelCacheLimit
+)
{
return new CacheConfig()
{
@@ -264,6 +304,12 @@ public boolean isUseResultLevelCache()
{
return useResultLevelCache;
}

+@Override
+public int getResultLevelCacheLimit()
+{
+  return resultLevelCacheLimit;
+}
};
}
}
