Skip to content

Commit daaeeeb

Browse files
authored
fix: proper buffer calc, fixes crash for android (#97)
1 parent a201494 commit daaeeeb

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/main/java/io/github/metarank/lightgbm4j/LGBMBooster.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -954,13 +954,18 @@ public boolean updateOneIterCustom(float[] grad, float[] hess) throws LGBMExcept
954954
* @return number of elements in the output result (size)
955955
*/
956956
private long outBufferSize(int rows, int cols, PredictionType predictionType) {
957-
long defaultSize = 2L * rows;
958957
if (PredictionType.C_API_PREDICT_CONTRIB.equals(predictionType))
959-
return defaultSize * (cols + 1);
958+
return (long) rows * (cols + 1);
960959
else if (PredictionType.C_API_PREDICT_LEAF_INDEX.equals(predictionType))
961-
return defaultSize * iterations;
962-
else // for C_API_PREDICT_NORMAL & C_API_PREDICT_RAW_SCORE
963-
return defaultSize;
960+
return (long) rows * iterations;
961+
else {
962+
try {
963+
int numClass = getNumClasses();
964+
return (long) rows * numClass;
965+
} catch (LGBMException e) {
966+
return (long) rows * 2;
967+
}
968+
}
964969
}
965970

966971
}

0 commit comments

Comments
 (0)