@@ -42,6 +42,7 @@
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -123,7 +124,8 @@ private Aggregation handleSinglePercentile(
final List<Aggregation> existingAggregations
)
{
- final double percentile = ((Number) RexLiteral.value(percentileArg)).doubleValue();
+ final Object value = RexLiteral.value(percentileArg);
+ final double percentile = DruidSqlParserUtils.getNumericLiteral(value, "SPECTATOR_PERCENTILE", "percentile").doubleValue();

final String histogramName = StringUtils.format("%s:agg", name);

@@ -29,6 +29,7 @@
import org.apache.druid.data.input.impl.StringDimensionSchema;
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.error.DruidException;
import org.apache.druid.initialization.DruidModule;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
@@ -57,6 +58,7 @@
import org.apache.druid.sql.calcite.util.SqlTestFramework.StandardComponentSupplier;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.Assert;
import org.junit.jupiter.api.Test;

import java.util.Collections;
@@ -551,4 +553,21 @@ public void testSpectatorFunctionsOnNullHistogram()
ImmutableList.of(new Object[]{null, null, null})
);
}

@Test
public void testSpectatorPercentileWithStringLiteral()
{
// verify invalid queries return 400 (user error)
final String query = "SELECT SPECTATOR_PERCENTILE(histogram_metric, '99.99') FROM foo";

try {
testQuery(query, ImmutableList.of(), ImmutableList.of());
Assert.fail("Expected DruidException but query succeeded");
}
catch (DruidException e) {
Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona());
Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory());
Assert.assertTrue(e.getMessage().contains("must be a numeric literal"));
}
}
}
@@ -40,6 +40,7 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
@@ -92,7 +93,7 @@ public Aggregation toDruidAggregation(
return null;
}

- logK = ((Number) RexLiteral.value(logKarg)).intValue();
+ logK = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(logKarg), "APPROX_COUNT_DISTINCT_DS_HLL", "logK").intValue();
} else {
logK = HllSketchAggregatorFactory.DEFAULT_LG_K;
}
@@ -40,6 +40,7 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -98,7 +99,7 @@ public Aggregation toDruidAggregation(
return null;
}

- final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
+ final float probability = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(probabilityArg), "APPROX_QUANTILE_DS", "probability").floatValue();
final int k;

if (aggregateCall.getArgList().size() >= 3) {
@@ -109,7 +110,7 @@
return null;
}

- k = ((Number) RexLiteral.value(resolutionArg)).intValue();
+ k = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(resolutionArg), "APPROX_QUANTILE_DS", "resolution").intValue();
} else {
k = DoublesSketchAggregatorFactory.DEFAULT_K;
}
@@ -38,6 +38,7 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
@@ -90,7 +91,7 @@ public Aggregation toDruidAggregation(
return null;
}

- sketchSize = ((Number) RexLiteral.value(sketchSizeArg)).intValue();
+ sketchSize = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(sketchSizeArg), "APPROX_COUNT_DISTINCT_DS_THETA", "size").intValue();
} else {
sketchSize = SketchAggregatorFactory.DEFAULT_MAX_SKETCH_SIZE;
}
@@ -22,6 +22,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.error.DruidException;
import org.apache.druid.initialization.DruidModule;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
@@ -71,6 +72,7 @@
import org.apache.druid.sql.calcite.util.TestDataBuilder;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.Assert;
import org.junit.jupiter.api.Test;

import java.util.Collections;
@@ -1111,6 +1113,40 @@ public void testSuccessWithSmallMaxStreamLength()
);
}

@Test
public void testApproxQuantileWithStringLiteral()
{
// verify invalid queries return 400 (user error)
final String query = "SELECT APPROX_QUANTILE_DS(m1, '0.99') FROM foo";

try {
testQuery(query, ImmutableList.of(), ImmutableList.of());
Assert.fail("Expected DruidException but query succeeded");
}
catch (DruidException e) {
Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona());
Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory());
Assert.assertTrue(e.getMessage().contains("Cannot apply 'APPROX_QUANTILE_DS'"));
}
}

@Test
public void testApproxQuantileWithStringResolution()
{
// verify invalid queries return 400 (user error)
final String query = "SELECT APPROX_QUANTILE_DS(m1, 0.99, '128') FROM foo";

try {
testQuery(query, ImmutableList.of(), ImmutableList.of());
Assert.fail("Expected DruidException but query succeeded");
}
catch (DruidException e) {
Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona());
Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory());
Assert.assertTrue(e.getMessage().contains("Cannot apply 'APPROX_QUANTILE_DS'"));
}
}

private static PostAggregator makeFieldAccessPostAgg(String name)
{
return new FieldAccessPostAggregator(name, name);
@@ -39,6 +39,7 @@
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@@ -88,7 +89,7 @@ public Aggregation toDruidAggregation(
return null;
}

- final int maxNumEntries = ((Number) RexLiteral.value(maxNumEntriesOperand)).intValue();
+ final int maxNumEntries = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxNumEntriesOperand), "BLOOM_FILTER", "maxNumEntries").intValue();

// Look for existing matching aggregatorFactory.
for (final Aggregation existing : existingAggregations) {
@@ -43,6 +43,7 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -91,7 +92,7 @@ public Aggregation toDruidAggregation(
return null;
}

- final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();
+ final float probability = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(probabilityArg), "APPROX_QUANTILE", "probability").floatValue();
final int resolution;

if (aggregateCall.getArgList().size() >= 3) {
@@ -102,7 +103,7 @@
return null;
}

- resolution = ((Number) RexLiteral.value(resolutionArg)).intValue();
+ resolution = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(resolutionArg), "APPROX_QUANTILE", "resolution").intValue();
} else {
resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;
}
@@ -43,6 +43,7 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@@ -83,7 +84,7 @@ public Aggregation toDruidAggregation(
// maxBytes must be a literal
return null;
}
- maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
+ maxSizeBytes = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxBytes), "ARRAY_CONCAT_AGG", "maxBytes").intValue();
}
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0));
final ExprMacroTable macroTable = plannerContext.getPlannerToolbox().exprMacroTable();
@@ -45,6 +45,7 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@@ -86,7 +87,7 @@ public Aggregation toDruidAggregation(
// maxBytes must be a literal
return null;
}
- maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
+ maxSizeBytes = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxBytes), "ARRAY_AGG", "maxBytes").intValue();
}
final DruidExpression arg = Expressions.toDruidExpression(plannerContext, inputAccessor.getInputRowSignature(), arguments.get(0));
if (arg == null) {
@@ -50,6 +50,7 @@
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
@@ -130,7 +131,7 @@ public Aggregation toDruidAggregation(
// maxBytes must be a literal
return null;
}
- maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
+ maxSizeBytes = DruidSqlParserUtils.getNumericLiteral(RexLiteral.value(maxBytes), "STRING_AGG", "maxBytes").intValue();
}

final DruidExpression arg = arguments.get(0);
@@ -665,4 +665,51 @@ public static DruidException problemParsing(String message)
{
return InvalidSqlInput.exception(message);
}

/**
* Creates a DruidException for invalid SQL function parameter types.
*
* @param functionName the SQL function name (e.g., "SPECTATOR_PERCENTILE")
* @param parameterName the parameter name
* @param expectedType the expected type
* @param actualValue the value that was provided, used to determine the actual type reported in the error message
* @return DruidException with INVALID_INPUT category and USER persona
*/
public static DruidException invalidParameterTypeException(
String functionName,
String parameterName,
String expectedType,
@Nullable Object actualValue
)
{
final String actualType = actualValue == null ? "NULL" : actualValue.getClass().getSimpleName();
return InvalidSqlInput.exception(
"%s parameter `%s` must be a %s literal, got %s",
functionName,
parameterName,
expectedType,
actualType
);
}

/**
* Validates that a value extracted from RexLiteral.value() is numeric and returns it, or throws the DruidException produced by invalidParameterTypeException if it is not.
*
* @param value the value extracted from RexLiteral.value()
* @param functionName the SQL function name
* @param parameterName the parameter name
* @return the value as a Number
* @throws DruidException if value is not a Number
*/
public static Number getNumericLiteral(
@Nullable Object value,
String functionName,
String parameterName
)
{
if (!(value instanceof Number)) {
throw invalidParameterTypeException(functionName, parameterName, "numeric", value);
}
return (Number) value;
}
}
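
For reference, here is a minimal usage sketch (not part of the diff) of the two helpers added above. It assumes the Druid SQL planner and error modules are on the classpath; the wrapper class GetNumericLiteralSketch and the printed-output comments are illustrative only.

import org.apache.druid.error.DruidException;
import org.apache.druid.sql.calcite.parser.DruidSqlParserUtils;

public class GetNumericLiteralSketch
{
  public static void main(String[] args)
  {
    // A numeric literal value passes straight through as a Number.
    final Number resolution = DruidSqlParserUtils.getNumericLiteral(128L, "APPROX_QUANTILE_DS", "resolution");
    System.out.println(resolution.intValue()); // 128

    // A non-numeric value (e.g. the String behind a quoted SQL literal) now surfaces as a
    // user-facing DruidException instead of a ClassCastException.
    try {
      DruidSqlParserUtils.getNumericLiteral("0.99", "APPROX_QUANTILE_DS", "probability");
    }
    catch (DruidException e) {
      System.out.println(e.getTargetPersona()); // USER
      System.out.println(e.getCategory());      // INVALID_INPUT
      System.out.println(e.getMessage());       // "APPROX_QUANTILE_DS parameter `probability` must be a numeric literal, got String"
    }
  }
}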
@@ -14377,6 +14377,57 @@ public void testStringAggMaxBytes()
);
}

@Test
public void testStringAggWithStringMaxBytes()
{
// verify invalid queries return 400 (user error)
final String query = "SELECT STRING_AGG(dim1, ',', 'abc') FROM foo";

try {
testQuery(query, ImmutableList.of(), ImmutableList.of());
Assert.fail("Expected DruidException but query succeeded");
}
catch (DruidException e) {
Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona());
Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory());
Assert.assertTrue(e.getMessage().contains("parameter `maxBytes` must be a numeric literal"));
}
}

@Test
public void testArrayAggWithStringMaxBytes()
{
// verify invalid queries return 400 (user error)
final String query = "SELECT ARRAY_AGG(dim1, 'abc') FROM foo";

try {
testQuery(query, ImmutableList.of(), ImmutableList.of());
Assert.fail("Expected DruidException but query succeeded");
}
catch (DruidException e) {
Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona());
Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory());
Assert.assertTrue(e.getMessage().contains("parameter `maxBytes` must be a numeric literal"));
}
}

@Test
public void testArrayConcatAggWithStringMaxBytes()
{
// verify invalid queries return 400 (user error)
final String query = "SELECT ARRAY_CONCAT_AGG(MV_TO_ARRAY(dim3), 'abc') FROM foo";

try {
testQuery(query, ImmutableList.of(), ImmutableList.of());
Assert.fail("Expected DruidException but query succeeded");
}
catch (DruidException e) {
Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona());
Assert.assertEquals(DruidException.Category.INVALID_INPUT, e.getCategory());
Assert.assertTrue(e.getMessage().contains("parameter `maxBytes` must be a numeric literal"));
}
}

/**
* see {@link TestDataBuilder#RAW_ROWS1_WITH_NUMERIC_DIMS}
* for the input data source of this test