1#ifndef SINGLEPP_TRAIN_INTEGRATED_HPP
2#define SINGLEPP_TRAIN_INTEGRATED_HPP
6#include "scaled_ranks.hpp"
// Per-reference input bundle consumed by train_integrated().
// Only some members are visible in this listing; the struct header is not shown.
31template<
typename Value_,
typename Index_,
typename Label_>
// Reference matrix; training dereferences it as `*(input.ref)`.
// NOTE(review): raw pointer, presumably non-owning - the matrix must outlive this object; confirm.
36 const tatami::Matrix<Value_, Index_>* ref;
// Set from `&(trained.subset())` by the prepare_integrated_input() overloads.
// train_integrated() uses it to translate marker indices into test rows (`test_subset[y]`).
39 const std::vector<Index_>* test_subset;
// May be null: train_integrated() checks `if (in.intersection)` before dereferencing it.
// When set, it is iterated as pairs whose `.first` indexes test rows and `.second` indexes reference rows.
41 std::shared_ptr<const Intersection<Index_> > intersection;
// prepare_integrated_input() overload taking only the reference matrix, labels and trained classifier
// (parts of the signature and body are not visible in this listing).
// The output stores the ADDRESSES of `trained.markers()` and `trained.subset()`, so `trained`
// must stay alive for as long as the returned input is used.
65template<
typename Value_,
typename Index_,
typename Label_,
typename Float_>
67 const tatami::Matrix<Value_, Index_>& ref,
73 output.labels = labels;
75 output.ref_markers = &(trained.
markers());
76 output.test_subset = &(trained.
subset());
// The test row count is taken from the reference itself; no intersection is assigned in the visible lines.
// NOTE(review): presumably this overload is for a test dataset with the same rows as `ref` - confirm.
78 output.test_nrow = ref.nrow();
// prepare_integrated_input() overload that records an existing intersection object
// (parts of the signature and body are not visible in this listing).
// As in the other overloads, only the addresses of `trained.markers()` and `trained.subset()` are stored,
// so `trained` must outlive the returned input.
105template<
typename Index_,
typename Value_,
typename Label_,
typename Float_>
109 const tatami::Matrix<Value_, Index_>& ref,
110 const Label_* labels,
115 output.labels = labels;
117 output.ref_markers = &(trained.
markers());
118 output.test_subset = &(trained.
subset());
120 output.test_nrow = test_nrow;
// shared_ptr aliasing constructor with an empty owner: the result points at `intersection`
// but does not manage its lifetime and never deletes it.
// NOTE(review): `intersection` is presumably a reference parameter; whoever owns it must keep it alive
// while the returned input is in use - confirm against the full signature.
121 output.intersection = std::shared_ptr<const Intersection<Index_> >(std::shared_ptr<Intersection<Index_> >{}, &intersection);
// prepare_integrated_input() overload that builds the test/reference intersection itself from row identifiers
// (parts of the signature and body are not visible in this listing).
// Only the addresses of `trained.markers()` and `trained.subset()` are stored, so `trained`
// must outlive the returned input.
151template<
typename Index_,
typename Id_,
typename Value_,
typename Label_,
typename Float_>
155 const tatami::Matrix<Value_, Index_>& ref,
157 const Label_* labels,
162 output.labels = labels;
164 output.ref_markers = &(trained.
markers());
165 output.test_subset = &(trained.
subset());
167 output.test_nrow = test_nrow;
// Match test rows to reference rows via their identifiers.
168 auto intersection =
intersect_genes(test_nrow, test_id, ref.nrow(), ref_id);
// The output owns the freshly computed intersection. make_shared replaces the raw `new`:
// no naked owning pointer, and a single allocation for object plus control block.
// The shared_ptr<Intersection> converts implicitly to shared_ptr<const Intersection>.
169 output.intersection = std::make_shared<Intersection<Index_> >(std::move(intersection));
// Per-reference training result held by the integrated classifier.
// Some members (e.g. the `num_samples` field assigned by train_integrated()) are not visible in this listing.
176template<
typename Index_>
177struct IntegratedReference {
// Per-label data for a reference stored in dense form.
178 struct DensePerLabel {
// Universe indices of this label's markers (train_integrated() pushes `remap_test_to_universe[...]` values).
180 std::vector<Index_> markers;
// Per-sample ranked vectors for this label, appended one after another.
181 RankedVector<Index_, Index_> all_ranked;
// Per-label data for a reference stored in sparse form.
184 struct SparsePerLabel {
// Universe indices of this label's markers, built the same way as for DensePerLabel.
186 std::vector<Index_> markers;
// Per-sample ranked vectors, appended one after another, kept separately for the two
// ranges on either side of the zero block found by find_zero_ranges().
187 RankedVector<Index_, Index_> negative_ranked, positive_ranked;
// Offsets into the corresponding *_ranked vector: a leading 0 followed by the running size after each sample,
// so sample s occupies [indptrs[s], indptrs[s + 1]).
188 std::vector<std::size_t> negative_indptrs, positive_indptrs;
// Accessors test `dense.has_value()` and otherwise read `sparse`,
// so exactly one of the two is expected to be populated.
191 std::optional<std::vector<DensePerLabel> > dense;
192 std::optional<std::vector<SparsePerLabel> > sparse;
// Integrated classifier: owns the gene universe and one IntegratedReference per input reference.
// Only fragments of the class are visible in this listing (several signatures and closing braces are missing).
202template<
typename Index_>
// Constructor initializer list: both vectors are moved in, not copied.
210 my_universe(std::move(universe)),
211 my_references(std::move(references))
// Read-only access to the per-reference training results.
214 const auto& references()
const {
215 return my_references;
223 std::vector<Index_> my_universe;
224 std::vector<IntegratedReference<Index_> > my_references;
// Body of num_references(): one entry per reference.
231 return my_references.size();
246 const std::vector<Index_>&
subset()
const {
// Body of num_labels(r): the per-label vector lives in `dense` or `sparse`, whichever is populated.
// No bounds check on `r` is visible here.
255 const auto& ref = my_references[r];
256 if (ref.dense.has_value()) {
257 return ref.dense->size();
259 return ref.sparse->size();
// Body of num_profiles(r): total number of samples across all labels of reference `r`.
// The per-label data lives in `dense` or `sparse`, whichever is populated.
268 std::size_t num_prof = 0;
269 const auto& ref = my_references[r];
270 if (ref.dense.has_value()) {
271 for (
const auto& lab : *(ref.dense)) {
// sanisizer::sum() is given the running total and returns the new one, so assign it.
// The previous `+=` added the running total to itself and overcounted for two or more labels.
272 num_prof = sanisizer::sum<std::size_t>(num_prof, lab.num_samples);
275 for (
const auto& lab : *(ref.sparse)) {
// Same accumulation for the sparse representation.
276 num_prof = sanisizer::sum<std::size_t>(num_prof, lab.num_samples);
// Ranks every column of one reference over the rows listed in `universe` and stores the simplified
// ranks per label and per sample. train_integrated() pairs this with the *_intersect variant,
// which is the one called when `curinput.intersection` is set.
// - positions[c]: slot of column c inside its label's vector in out_ranked / other_ranked.
// - out_ranked: dense -> all ranks; sparse -> the range before the zero block (caller passes `negative_ranked`).
// - other_ranked: sparse -> the range after the zero block (caller passes `positive_ranked`);
//   dense -> a `bool` placeholder that is never read.
// Parts of the signature and several closing lines are not visible in this listing.
297template<
bool ref_sparse_,
typename Value_,
typename Index_,
typename Label_>
298void train_integrated_per_reference_simple(
300 const std::vector<Index_>& universe,
301 const std::vector<Index_>& remap_test_to_universe,
303 const std::vector<Index_>& positions,
304 std::vector<std::vector<RankedVector<Index_, Index_> > >& out_ranked,
305 typename std::conditional<ref_sparse_, std::vector<std::vector<RankedVector<Index_, Index_> > >&,
bool>::type other_ranked
307 const auto& ref = *(input.ref);
308 const auto NC = ref.ncol();
309 const auto num_universe = universe.size();
// Each worker handles the contiguous column range [start, start + len) with its own buffers.
311 tatami::parallelize([&](
int, Index_ start, Index_ len) {
312 auto vbuffer = sanisizer::create<std::vector<Value_> >(num_universe);
// The index buffer is only allocated for sparse extraction.
313 auto ibuffer = [&](){
314 if constexpr(ref_sparse_) {
315 return sanisizer::create<std::vector<Index_> >(num_universe);
321 RankedVector<Value_, Index_> tmp_ranked;
322 tmp_ranked.reserve(num_universe);
// Aliasing constructor with an empty owner: a non-owning view of `universe` for the extractor.
326 tatami::VectorPtr<Index_> universe_ptr(tatami::VectorPtr<Index_>{}, &universe);
327 auto ext = tatami::consecutive_extractor<ref_sparse_>(ref,
false, start, len, std::move(universe_ptr));
329 for (Index_ c = start, end = start + len; c < end; ++c) {
332 if constexpr(ref_sparse_) {
333 auto info = ext->fetch(vbuffer.data(), ibuffer.data());
334 for (I<
decltype(info.number)> i = 0; i < info.number; ++i) {
// Sparse indices are row indices; translate them into universe positions.
335 const auto remapped = remap_test_to_universe[info.index[i]];
336 assert(sanisizer::is_less_than(remapped, num_universe));
337 tmp_ranked.emplace_back(info.value[i], remapped);
340 auto ptr = ext->fetch(vbuffer.data());
341 for (I<
decltype(num_universe)> i = 0; i < num_universe; ++i) {
// Dense extraction was restricted to `universe`, so element i is stored under universe position i.
342 tmp_ranked.emplace_back(ptr[i], i);
// Sort before the ranks are simplified.
346 std::sort(tmp_ranked.begin(), tmp_ranked.end());
348 if constexpr(ref_sparse_) {
349 const auto tStart = tmp_ranked.begin(), tEnd = tmp_ranked.end();
// NOTE(review): find_zero_ranges() presumably returns the bounds of the zero block in the sorted data,
// leaving [tStart, first) as negatives and [second, tEnd) as positives - confirm against its definition.
350 auto zero_ranges = find_zero_ranges<Value_, Index_>(tStart, tEnd);
351 simplify_ranks<Value_, Index_>(tStart, zero_ranges.first, out_ranked[input.labels[c]][positions[c]]);
352 simplify_ranks<Value_, Index_>(zero_ranges.second, tEnd, other_ranked[input.labels[c]][positions[c]]);
354 simplify_ranks(tmp_ranked, out_ranked[input.labels[c]][positions[c]]);
// Ranks every column of one reference whose rows are matched to the test dataset through
// `input.intersection`, and stores the simplified ranks per label and per sample.
// - remap_test_to_universe[t]: universe position of test row t, or `test_nrow` if t is not in the universe.
// - positions[c]: slot of column c inside its label's vector in out_ranked / other_ranked.
// - out_ranked: dense -> all ranks; sparse -> the range before the zero block (caller passes `negative_ranked`).
// - other_ranked: sparse -> the range after the zero block (caller passes `positive_ranked`);
//   dense -> a `bool` placeholder that is never read.
// Several closing lines are not visible in this listing.
360template<
bool ref_sparse_,
typename Value_,
typename Index_,
typename Label_>
361void train_integrated_per_reference_intersect(
// Template arguments follow the struct's declared order <Value_, Index_, Label_>. The previous
// <Value_, Label_, Index_> made deduction of Index_ conflict with the std::vector<Index_> parameters
// below whenever Label_ and Index_ were different types.
362 const TrainIntegratedInput<Value_, Index_, Label_>& input,
363 const std::vector<Index_>& remap_test_to_universe,
364 const Index_ test_nrow,
365 const TrainIntegratedOptions& options,
366 const std::vector<Index_>& positions,
367 std::vector<std::vector<RankedVector<Index_, Index_> > >& out_ranked,
368 typename std::conditional<ref_sparse_, std::vector<std::vector<RankedVector<Index_, Index_> > >&,
bool>::type other_ranked
370 const auto& ref = *(input.ref);
371 const auto NC = ref.ncol();
// Reference rows to extract: those in the intersection whose test row belongs to the universe.
373 std::vector<Index_> ref_subset;
374 sanisizer::reserve(ref_subset, input.intersection->size());
// Reference row -> universe position, pre-filled with the `test_nrow` sentinel.
375 auto remap_ref_subset_to_universe = sanisizer::create<std::vector<Index_> >(ref.nrow(), test_nrow);
376 for (
const auto& pair : *(input.intersection)) {
377 const auto rdex = remap_test_to_universe[pair.first];
378 if (rdex != test_nrow) {
379 ref_subset.push_back(pair.second);
380 remap_ref_subset_to_universe[pair.second] = rdex;
// NOTE(review): sorted ascending before use as an extraction subset; presumably required by tatami - confirm.
383 std::sort(ref_subset.begin(), ref_subset.end());
// Dense only: universe position of each extracted row, parallel to the sorted ref_subset.
// For sparse references this is a `bool` placeholder.
385 typename std::conditional<ref_sparse_, bool, std::vector<Index_> >::type remap_dense_to_universe;
386 if constexpr(!ref_sparse_) {
387 remap_dense_to_universe.reserve(ref_subset.size());
388 for (
auto r : ref_subset) {
389 remap_dense_to_universe.push_back(remap_ref_subset_to_universe[r]);
// Each worker handles the contiguous column range [start, start + len) with its own buffers.
393 tatami::parallelize([&](
int, Index_ start, Index_ len) {
394 const auto ref_subset_size = ref_subset.size();
395 auto vbuffer = sanisizer::create<std::vector<Value_> >(ref_subset_size);
// The index buffer is only allocated for sparse extraction.
396 auto ibuffer = [&]() {
397 if constexpr(ref_sparse_) {
398 return sanisizer::create<std::vector<Index_> >(ref_subset_size);
404 RankedVector<Value_, Index_> tmp_ranked;
405 tmp_ranked.reserve(ref_subset_size);
// Aliasing constructor with an empty owner: a non-owning view of `ref_subset` for the extractor.
406 tatami::VectorPtr<Index_> to_extract_ptr(tatami::VectorPtr<Index_>{}, &ref_subset);
407 auto ext = tatami::consecutive_extractor<ref_sparse_>(ref,
false, start, len, std::move(to_extract_ptr));
409 for (Index_ c = start, end = start + len; c < end; ++c) {
412 if constexpr(ref_sparse_) {
413 auto info = ext->fetch(vbuffer.data(), ibuffer.data());
414 for (I<
decltype(info.number)> i = 0; i < info.number; ++i) {
// Sparse indices are reference rows; translate them into universe positions.
415 tmp_ranked.emplace_back(info.value[i], remap_ref_subset_to_universe[info.index[i]]);
418 auto ptr = ext->fetch(vbuffer.data());
419 for (I<
decltype(ref_subset_size)> i = 0; i < ref_subset_size; ++i) {
// Dense element i corresponds to ref_subset[i]; use its precomputed universe position.
420 tmp_ranked.emplace_back(ptr[i], remap_dense_to_universe[i]);
// Sort before the ranks are simplified.
424 std::sort(tmp_ranked.begin(), tmp_ranked.end());
426 if constexpr(ref_sparse_) {
427 const auto tStart = tmp_ranked.begin(), tEnd = tmp_ranked.end();
// NOTE(review): find_zero_ranges() presumably returns the bounds of the zero block in the sorted data,
// leaving [tStart, first) as negatives and [second, tEnd) as positives - confirm against its definition.
428 auto zero_ranges = find_zero_ranges<Value_, Index_>(tStart, tEnd);
429 simplify_ranks<Value_, Index_>(tStart, zero_ranges.first, out_ranked[input.labels[c]][positions[c]]);
430 simplify_ranks<Value_, Index_>(zero_ranges.second, tEnd, other_ranked[input.labels[c]][positions[c]]);
432 simplify_ranks(tmp_ranked, out_ranked[input.labels[c]][positions[c]]);
435 }, NC, options.num_threads);
// train_integrated(): builds the integrated classifier from a set of per-reference inputs.
// Visible steps: (1) check that all inputs agree on the test row count, (2) collect the marker universe
// and keep only genes available in every reference, (3) per reference, store each label's markers as
// universe indices and the per-sample ranked expression (dense or sparse layout).
// The signature, several branch headers and closing lines are not visible in this listing.
451template<
typename Value_,
typename Index_,
typename Label_>
453 std::vector<Index_> universe;
454 const auto nrefs = inputs.size();
455 auto references = sanisizer::create<std::vector<IntegratedReference<Index_> > >(nrefs);
458 Index_ test_nrow = 0;
// Every input must describe the same test dataset; the first entry is the reference point.
460 test_nrow = inputs.front().test_nrow;
461 for (
const auto& in : inputs) {
462 if (!sanisizer::is_equal(in.test_nrow, test_nrow)) {
463 throw std::runtime_error(
"inconsistent number of rows in the test dataset across entries of 'inputs'");
// Test row -> universe position. Pre-filled with `test_nrow`, which later acts as the
// "not in the universe" sentinel (see the `!= test_nrow` checks).
470 auto remap_test_to_universe = sanisizer::create<std::vector<Index_> >(test_nrow, test_nrow);
472 auto present = sanisizer::create<std::vector<char> >(test_nrow);
// Number of references in which each test row is available.
473 auto count_refs = sanisizer::create<std::vector<I<
decltype(nrefs)> > >(test_nrow);
474 universe.reserve(test_nrow);
476 for (
const auto& in : inputs) {
477 const auto& markers = *(in.ref_markers);
478 const auto& test_subset = *(in.test_subset);
// Markers index into `test_subset`; translate each to a test row and add it to the universe.
// NOTE(review): the lines using `present` are not visible; presumably they skip rows already added - confirm.
480 for (
const auto& labmrk : markers) {
481 for (
const auto& mrk : labmrk) {
482 for (
const auto y : mrk) {
483 const auto ty = test_subset[y];
486 universe.push_back(ty);
// With an intersection, only the test rows it lists count as available in this reference.
492 if (in.intersection) {
493 for (
const auto& pp : *(in.intersection)) {
494 count_refs[pp.first] += 1;
// NOTE(review): loop body not visible; presumably every row is counted when there is no intersection - confirm.
497 for (
auto& x : count_refs) {
503 std::sort(universe.begin(), universe.end());
504 const auto num_universe = universe.size();
// Compact the universe in place, keeping only rows available in all references.
// `keep` is the write position and becomes the row's universe index.
505 I<
decltype(num_universe)> keep = 0;
506 for (I<
decltype(num_universe)> u = 0; u < num_universe; ++u) {
507 const auto marker = universe[u];
508 if (count_refs[marker] == nrefs) {
509 universe[keep] = marker;
510 remap_test_to_universe[marker] = keep;
514 universe.resize(keep);
515 universe.shrink_to_fit();
// Scratch space reused across labels to deduplicate each label's marker genes.
518 auto is_active = sanisizer::create<std::vector<char> >(test_nrow);
519 std::vector<Index_> active_genes;
520 active_genes.reserve(test_nrow);
522 for (I<
decltype(nrefs)> r = 0; r < nrefs; ++r) {
523 const auto& curinput = inputs[r];
524 const auto& currefmarkers = *(curinput.ref_markers);
525 const auto& test_subset = *(curinput.test_subset);
526 const auto nlabels = currefmarkers.size();
527 auto& currefout = references[r];
529 const Index_ NC = curinput.ref->ncol();
530 const bool is_sparse = curinput.ref->is_sparse();
// Allocate one per-label entry in the sparse or dense container.
// NOTE(review): the branch header is not visible; presumably it selects on `is_sparse` - confirm.
532 currefout.sparse.emplace(sanisizer::as_size_type<I<
decltype(*(currefout.sparse))> >(nlabels));
534 currefout.dense.emplace(sanisizer::as_size_type<I<
decltype(*(currefout.dense))> >(nlabels));
538 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
// Collect the distinct test rows used as markers by label l.
539 active_genes.clear();
540 for (
const auto& labmark : currefmarkers[l]) {
541 for (
const auto y : labmark) {
542 const auto ty = test_subset[y];
543 if (!is_active[ty]) {
544 is_active[ty] =
true;
545 active_genes.push_back(ty);
550 std::vector<Index_> markers;
551 markers.reserve(active_genes.size());
// Keep only markers that survived into the universe, stored as universe indices.
// `is_active` is reset here so it is clean for the next label.
553 for (
const auto a : active_genes) {
554 const auto universe_index = remap_test_to_universe[a];
555 if (universe_index != test_nrow) {
556 markers.push_back(universe_index);
558 is_active[a] =
false;
562 (*(currefout.sparse))[l].markers.swap(markers);
564 (*(currefout.dense))[l].markers.swap(markers);
// positions[c] is the slot of column c within its label, taken from the per-label counter
// before it advances. NOTE(review): the increment of `pos` is not visible in this listing.
569 std::vector<Index_> positions;
570 positions.reserve(NC);
571 auto samples_per_label = sanisizer::create<std::vector<Index_> >(nlabels);
572 for (Index_ c = 0; c < NC; ++c) {
573 auto& pos = samples_per_label[curinput.labels[c]];
574 positions.push_back(pos);
// NOTE(review): queries is_sparse() again although `is_sparse` was computed above.
578 if (curinput.ref->is_sparse()) {
// Sparse layout: per label, per sample, separate ranked vectors for the two non-zero ranges.
579 auto negative_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
580 auto positive_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
581 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
582 const auto num_samples = samples_per_label[l];
583 sanisizer::resize(negative_ranked[l], num_samples);
584 sanisizer::resize(positive_ranked[l], num_samples);
587 if (curinput.intersection) {
588 train_integrated_per_reference_intersect<true>(curinput, remap_test_to_universe, test_nrow, options, positions, negative_ranked, positive_ranked);
590 train_integrated_per_reference_simple<true, Value_>(curinput, universe, remap_test_to_universe, options, positions, negative_ranked, positive_ranked);
// Flatten the per-sample vectors of each label into one contiguous vector plus offsets.
593 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
594 auto& curlabout = (*(currefout.sparse))[l];
595 const auto num_samples = samples_per_label[l];
596 curlabout.num_samples = num_samples;
// Total sizes are computed first so that each flattened vector is reserved once.
598 I<
decltype(curlabout.negative_ranked.size())> num_neg = 0;
599 for (
const auto& x : negative_ranked[l]) {
600 num_neg = sanisizer::sum<I<
decltype(num_neg)> >(num_neg, x.size());
603 I<
decltype(curlabout.positive_ranked.size())> num_pos = 0;
604 for (
const auto& x : positive_ranked[l]) {
605 num_pos = sanisizer::sum<I<
decltype(num_pos)> >(num_pos, x.size());
// indptrs hold num_samples + 1 offsets: a leading 0, then the running size after each sample.
608 curlabout.negative_ranked.reserve(num_neg);
609 curlabout.negative_indptrs.reserve(sanisizer::sum<I<
decltype(curlabout.negative_indptrs.size())> >(num_samples, 1));
610 curlabout.negative_indptrs.push_back(0);
611 for (
const auto& x : negative_ranked[l]) {
612 curlabout.negative_ranked.insert(curlabout.negative_ranked.end(), x.begin(), x.end());
613 curlabout.negative_indptrs.push_back(curlabout.negative_ranked.size());
616 curlabout.positive_ranked.reserve(num_pos);
617 curlabout.positive_indptrs.reserve(sanisizer::sum<I<
decltype(curlabout.positive_indptrs.size())> >(num_samples, 1));
618 curlabout.positive_indptrs.push_back(0);
619 for (
const auto& x : positive_ranked[l]) {
620 curlabout.positive_ranked.insert(curlabout.positive_ranked.end(), x.begin(), x.end());
621 curlabout.positive_indptrs.push_back(curlabout.positive_ranked.size());
// Dense layout: per label, per sample, one ranked vector.
626 auto out_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
627 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
628 const auto num_samples = samples_per_label[l];
629 sanisizer::resize(out_ranked[l], num_samples);
// The trailing `true` only fills the unused `bool` placeholder parameter of the dense instantiation.
632 if (curinput.intersection) {
633 train_integrated_per_reference_intersect<false>(curinput, remap_test_to_universe, test_nrow, options, positions, out_ranked,
true);
635 train_integrated_per_reference_simple<false, Value_>(curinput, universe, remap_test_to_universe, options, positions, out_ranked,
true);
// Flatten the per-sample vectors; capacity is reserved as universe size times sample count.
638 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
639 auto& curlabout = (*(currefout.dense))[l];
640 curlabout.num_samples = samples_per_label[l];
641 curlabout.all_ranked.reserve(sanisizer::product<I<
decltype(curlabout.all_ranked.size())> >(universe.size(), curlabout.num_samples));
642 for (
const auto& x : out_ranked[l]) {
643 curlabout.all_ranked.insert(curlabout.all_ranked.end(), x.begin(), x.end());
Create an intersection of genes.
Classifier that integrates multiple reference datasets.
Definition train_integrated.hpp:203
Index_ test_nrow() const
Definition train_integrated.hpp:237
std::size_t num_references() const
Definition train_integrated.hpp:230
const std::vector< Index_ > & subset() const
Definition train_integrated.hpp:246
std::size_t num_profiles(std::size_t r) const
Definition train_integrated.hpp:267
std::size_t num_labels(std::size_t r) const
Definition train_integrated.hpp:254
Classifier trained from a single reference.
Definition train_single.hpp:85
const std::vector< Index_ > & subset() const
Definition train_single.hpp:134
const Markers< Index_ > & markers() const
Definition train_single.hpp:125
Common definitions for singlepp.
Cell type classification using the SingleR algorithm in C++.
Definition classify_single.hpp:20
std::vector< std::vector< std::vector< Index_ > > > Markers
Definition Markers.hpp:40
Intersection< Index_ > intersect_genes(Index_ test_nrow, const Id_ *test_id, Index_ ref_nrow, const Id_ *ref_id)
Definition Intersection.hpp:54
TrainIntegratedInput< Value_, Index_, Label_ > prepare_integrated_input(const tatami::Matrix< Value_, Index_ > &ref, const Label_ *labels, const TrainedSingle< Index_, Float_ > &trained)
Definition train_integrated.hpp:66
std::vector< std::pair< Index_, Index_ > > Intersection
Definition Intersection.hpp:35
TrainedIntegrated< Index_ > train_integrated(const std::vector< TrainIntegratedInput< Value_, Index_, Label_ > > &inputs, const TrainIntegratedOptions &options)
Definition train_integrated.hpp:452
Options for train_integrated().
Definition train_integrated.hpp:286
int num_threads
Definition train_integrated.hpp:291
Train a classifier from a single reference.