Skip to content

Commit 2564807

Browse files
committed
enable and fix full prediction calculation
1 parent 2cce6e3 commit 2564807

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

src/predict.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,6 @@ int main(int argc, const char** argv) {
220220
std::cout << fmt::format("F1: {:.3}%\n", percentage(tp, tp + (fp + fn) / 2));
221221

222222
} else {
223-
spdlog::error("Full predictions are currently not supported");
224-
exit(1);
225-
/*
226223
spdlog::info("Reading model file from '{}'", model_file);
227224
auto model = io::load_model(model_file);
228225

@@ -234,8 +231,7 @@ int main(int argc, const char** argv) {
234231
std::exit(1);
235232
}
236233
const auto& predictions = task.get_predictions();
237-
*/
238234

239-
// TODO fix handling of full predictions
235+
io::prediction::save_dense_predictions(result_file, predictions);
240236
}
241237
}

src/prediction/prediction.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ PredictionBase::PredictionBase(const DatasetBase* data,
3131
}
3232
}
3333

34-
void PredictionBase::make_thread_local_features(int num_threads) {
34+
void PredictionBase::make_thread_local_features(long num_threads) {
3535
m_ThreadLocalFeatures.resize(num_threads);
3636
}
3737

@@ -69,9 +69,12 @@ long FullPredictionTaskGenerator::num_tasks() const
6969

7070
void FullPredictionTaskGenerator::run_tasks(long begin, long end, thread_id_t thread_id)
7171
{
72-
do_prediction(begin, end, thread_id, m_Predictions.middleRows(begin, end));
72+
do_prediction(begin, end, thread_id, m_Predictions.middleRows(begin, end - begin));
7373
}
7474

75+
void FullPredictionTaskGenerator::prepare(long num_threads, long chunk_size) {
76+
make_thread_local_features(num_threads);
77+
}
7578

7679

7780
TopKPredictionTaskGenerator::TopKPredictionTaskGenerator(const DatasetBase* data, std::shared_ptr<const Model> model, long K) :

src/prediction/prediction.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ namespace dismec::prediction {
4040
const DatasetBase* m_Data; //!< Data on which the prediction is run
4141
std::shared_ptr<const Model> m_Model; //!< Model (possibly partial) for which prediction is run
4242

43-
void make_thread_local_features(int num_threads);
43+
void make_thread_local_features(long num_threads);
4444

4545
void init_thread(thread_id_t thread_id) final;
4646

@@ -70,6 +70,7 @@ namespace dismec::prediction {
7070

7171
FullPredictionTaskGenerator(const DatasetBase* data, std::shared_ptr<const Model> model);
7272

73+
void prepare(long num_threads, long chunk_size) override;
7374
void run_tasks(long begin, long end, thread_id_t thread_id) override;
7475
[[nodiscard]] long num_tasks() const override;
7576

0 commit comments

Comments
 (0)