457 Index_ test_nrow = 0;
459 test_nrow = inputs.front().test_nrow;
460 for (
const auto& in : inputs) {
461 if (!sanisizer::is_equal(in.test_nrow, test_nrow)) {
462 throw std::runtime_error(
"inconsistent number of rows in the test dataset across entries of 'inputs'");
468 const auto nrefs = inputs.size();
469 std::vector<std::vector<Index_> > remap_intersection_to_test_index;
470 for (I<
decltype(nrefs)> r = 0; r < nrefs; ++r) {
471 if (inputs[r].intersection.has_value()) {
472 sanisizer::resize(remap_intersection_to_test_index, nrefs);
480 std::vector<Index_> universe;
481 auto remap_test_to_universe = sanisizer::create<std::vector<Index_> >(test_nrow, test_nrow);
483 auto present = sanisizer::create<std::vector<char> >(test_nrow);
484 auto count_refs = sanisizer::create<std::vector<I<
decltype(nrefs)> > >(test_nrow);
485 universe.reserve(test_nrow);
487 for (I<
decltype(nrefs)> r = 0; r < nrefs; ++r) {
488 const auto& markers = inputs[r].markers;
489 const auto& inter = inputs[r].intersection;
491 if (inter.has_value()) {
492 auto& cur_test_remap = remap_intersection_to_test_index[r];
493 sanisizer::resize(cur_test_remap, inputs[r].ref->nrow(), test_nrow);
494 for (
const auto& pp : *inter) {
495 cur_test_remap[pp.second] = pp.first;
496 count_refs[pp.first] += 1;
499 for (
const auto& labmrk : markers) {
500 for (
const auto y : labmrk) {
501 const auto ty = cur_test_remap[y];
502 if (ty != test_nrow && !present[ty]) {
504 universe.push_back(ty);
510 for (
const auto& labmrk : markers) {
511 for (
const auto y : labmrk) {
514 universe.push_back(y);
519 for (
auto& x : count_refs) {
525 std::sort(universe.begin(), universe.end());
526 const auto num_universe = universe.size();
527 I<
decltype(num_universe)> keep = 0;
528 for (I<
decltype(num_universe)> u = 0; u < num_universe; ++u) {
529 const auto marker = universe[u];
530 if (count_refs[marker] == nrefs) {
531 universe[keep] = marker;
532 remap_test_to_universe[marker] = keep;
536 universe.resize(keep);
537 universe.shrink_to_fit();
541 auto references = sanisizer::create<std::vector<IntegratedReference<Index_> > >(nrefs);
542 for (I<
decltype(nrefs)> r = 0; r < nrefs; ++r) {
543 const auto& curinput = inputs[r];
544 const auto& currefmarkers = curinput.markers;
545 const auto nlabels = currefmarkers.size();
546 auto& currefout = references[r];
548 const bool is_sparse = curinput.ref->is_sparse();
550 currefout.sparse.emplace(sanisizer::as_size_type<I<
decltype(*(currefout.sparse))> >(nlabels));
552 currefout.dense.emplace(sanisizer::as_size_type<I<
decltype(*(currefout.dense))> >(nlabels));
555 auto get_markers = [&](I<
decltype(nlabels)> l) -> std::vector<Index_>& {
557 return (*(currefout.sparse))[l].markers;
559 return (*(currefout.dense))[l].markers;
563 if (curinput.intersection.has_value()) {
564 auto& cur_test_remap = remap_intersection_to_test_index[r];
565 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
566 const auto& curlabmarkers = currefmarkers[l];
567 auto& markers = get_markers(l);
568 markers.reserve(curlabmarkers.size());
569 for (
const auto y : curlabmarkers) {
570 const auto ty = cur_test_remap[y];
571 if (ty != test_nrow) {
572 const auto universe_index = remap_test_to_universe[ty];
573 if (universe_index != test_nrow) {
574 markers.push_back(universe_index);
581 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
582 const auto& curlabmarkers = currefmarkers[l];
583 auto& markers = get_markers(l);
584 markers.reserve(curlabmarkers.size());
585 for (
const auto y : curlabmarkers) {
586 const auto universe_index = remap_test_to_universe[y];
587 if (universe_index != test_nrow) {
588 markers.push_back(universe_index);
597 remap_intersection_to_test_index.clear();
600 for (I<
decltype(nrefs)> r = 0; r < nrefs; ++r) {
601 const auto& curinput = inputs[r];
602 auto& currefout = references[r];
604 const Index_ NC = curinput.ref->ncol();
606 throw std::runtime_error(
"reference dataset must have at least one column");
608 std::vector<Index_> positions;
609 sanisizer::reserve(positions, NC);
611 const auto nlabels = sanisizer::sum<std::size_t>(*std::max_element(curinput.labels, curinput.labels + NC), 1);
612 auto samples_per_label = sanisizer::create<std::vector<Index_> >(nlabels);
613 for (Index_ c = 0; c < NC; ++c) {
614 auto& pos = samples_per_label[curinput.labels[c]];
615 positions.push_back(pos);
619 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
620 if (samples_per_label[l] == 0) {
621 throw std::runtime_error(
"no profiles available for label " + std::to_string(l) +
" in reference " + std::to_string(r));
625 if (!sanisizer::is_equal(curinput.markers.size(), nlabels)) {
626 throw std::runtime_error(
"'markers' length should be equal to the number of unique labels");
629 if (curinput.ref->is_sparse()) {
630 auto negative_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
631 auto positive_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
632 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
633 const auto num_samples = samples_per_label[l];
634 sanisizer::resize(negative_ranked[l], num_samples);
635 sanisizer::resize(positive_ranked[l], num_samples);
638 if (curinput.intersection) {
639 train_integrated_per_reference_intersect<true>(curinput, remap_test_to_universe, test_nrow, options, positions, negative_ranked, positive_ranked);
641 train_integrated_per_reference_simple<true, Value_>(curinput, universe, remap_test_to_universe, options, positions, negative_ranked, positive_ranked);
644 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
645 auto& curlabout = (*(currefout.sparse))[l];
646 const auto num_samples = samples_per_label[l];
647 curlabout.num_samples = num_samples;
649 I<
decltype(curlabout.negative_ranked.size())> num_neg = 0;
650 for (
const auto& x : negative_ranked[l]) {
651 num_neg = sanisizer::sum<I<
decltype(num_neg)> >(num_neg, x.size());
654 I<
decltype(curlabout.positive_ranked.size())> num_pos = 0;
655 for (
const auto& x : positive_ranked[l]) {
656 num_pos = sanisizer::sum<I<
decltype(num_pos)> >(num_pos, x.size());
659 curlabout.negative_ranked.reserve(num_neg);
660 curlabout.negative_indptrs.reserve(sanisizer::sum<I<
decltype(curlabout.negative_indptrs.size())> >(num_samples, 1));
661 curlabout.negative_indptrs.push_back(0);
662 for (
const auto& x : negative_ranked[l]) {
663 curlabout.negative_ranked.insert(curlabout.negative_ranked.end(), x.begin(), x.end());
664 curlabout.negative_indptrs.push_back(curlabout.negative_ranked.size());
667 curlabout.positive_ranked.reserve(num_pos);
668 curlabout.positive_indptrs.reserve(sanisizer::sum<I<
decltype(curlabout.positive_indptrs.size())> >(num_samples, 1));
669 curlabout.positive_indptrs.push_back(0);
670 for (
const auto& x : positive_ranked[l]) {
671 curlabout.positive_ranked.insert(curlabout.positive_ranked.end(), x.begin(), x.end());
672 curlabout.positive_indptrs.push_back(curlabout.positive_ranked.size());
677 auto out_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
678 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
679 const auto num_samples = samples_per_label[l];
680 sanisizer::resize(out_ranked[l], num_samples);
683 if (curinput.intersection) {
684 train_integrated_per_reference_intersect<false>(curinput, remap_test_to_universe, test_nrow, options, positions, out_ranked,
true);
686 train_integrated_per_reference_simple<false, Value_>(curinput, universe, remap_test_to_universe, options, positions, out_ranked,
true);
689 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
690 auto& curlabout = (*(currefout.dense))[l];
691 curlabout.num_samples = samples_per_label[l];
692 curlabout.all_ranked.reserve(sanisizer::product<I<
decltype(curlabout.all_ranked.size())> >(universe.size(), curlabout.num_samples));
693 for (
const auto& x : out_ranked[l]) {
694 curlabout.all_ranked.insert(curlabout.all_ranked.end(), x.begin(), x.end());