1#ifndef SINGLEPP_TRAIN_INTEGRATED_HPP
2#define SINGLEPP_TRAIN_INTEGRATED_HPP
30template<
typename Value_,
typename Index_,
typename Label_>
35 const tatami::Matrix<Value_, Index_>* ref;
38 const std::vector<Index_>* test_subset;
40 std::shared_ptr<const Intersection<Index_> > intersection;
64template<
typename Value_,
typename Index_,
typename Label_,
typename Float_>
66 const tatami::Matrix<Value_, Index_>& ref,
72 output.labels = labels;
74 output.ref_markers = &(trained.
markers());
75 output.test_subset = &(trained.
subset());
77 output.test_nrow = ref.nrow();
104template<
typename Index_,
typename Value_,
typename Label_,
typename Float_>
108 const tatami::Matrix<Value_, Index_>& ref,
109 const Label_* labels,
114 output.labels = labels;
116 output.ref_markers = &(trained.
markers());
117 output.test_subset = &(trained.
subset());
119 output.test_nrow = test_nrow;
120 output.intersection = std::shared_ptr<const Intersection<Index_> >(std::shared_ptr<Intersection<Index_> >{}, &intersection);
150template<
typename Index_,
typename Id_,
typename Value_,
typename Label_,
typename Float_>
154 const tatami::Matrix<Value_, Index_>& ref,
156 const Label_* labels,
161 output.labels = labels;
163 output.ref_markers = &(trained.
markers());
164 output.test_subset = &(trained.
subset());
166 output.test_nrow = test_nrow;
167 auto intersection =
intersect_genes(test_nrow, test_id, ref.nrow(), ref_id);
168 output.intersection = std::shared_ptr<const Intersection<Index_> >(
new Intersection<Index_>(std::move(intersection)));
175template<
typename Index_>
176struct IntegratedReference {
177 struct DensePerLabel {
179 std::vector<Index_> markers;
180 RankedVector<Index_, Index_> all_ranked;
183 struct SparsePerLabel {
185 std::vector<Index_> markers;
186 RankedVector<Index_, Index_> negative_ranked, positive_ranked;
187 std::vector<std::size_t> negative_indptrs, positive_indptrs;
190 std::optional<std::vector<DensePerLabel> > dense;
191 std::optional<std::vector<SparsePerLabel> > sparse;
201template<
typename Index_>
209 my_universe(std::move(universe)),
210 my_references(std::move(references))
213 const auto& references()
const {
214 return my_references;
222 std::vector<Index_> my_universe;
223 std::vector<IntegratedReference<Index_> > my_references;
230 return my_references.size();
245 const std::vector<Index_>&
subset()
const {
254 const auto& ref = my_references[r];
255 if (ref.dense.has_value()) {
256 return ref.dense->size();
258 return ref.sparse->size();
267 std::size_t num_prof = 0;
268 const auto& ref = my_references[r];
269 if (ref.dense.has_value()) {
270 for (
const auto& lab : *(ref.dense)) {
271 num_prof += sanisizer::sum<std::size_t>(num_prof, lab.num_samples);
274 for (
const auto& lab : *(ref.sparse)) {
275 num_prof += sanisizer::sum<std::size_t>(num_prof, lab.num_samples);
296template<
bool ref_sparse_,
typename Value_,
typename Index_,
typename Label_>
297void train_integrated_per_reference_simple(
299 const std::vector<Index_>& universe,
300 const std::vector<Index_>& remap_test_to_universe,
302 const std::vector<Index_>& positions,
303 std::vector<std::vector<RankedVector<Index_, Index_> > >& out_ranked,
304 typename std::conditional<ref_sparse_, std::vector<std::vector<RankedVector<Index_, Index_> > >&,
bool>::type other_ranked
306 const auto& ref = *(input.ref);
307 const auto NC = ref.ncol();
308 const auto num_universe = universe.size();
310 tatami::parallelize([&](
int, Index_ start, Index_ len) {
311 auto vbuffer = sanisizer::create<std::vector<Value_> >(num_universe);
312 auto ibuffer = [&](){
313 if constexpr(ref_sparse_) {
314 return sanisizer::create<std::vector<Index_> >(num_universe);
320 RankedVector<Value_, Index_> tmp_ranked;
321 tmp_ranked.reserve(num_universe);
325 tatami::VectorPtr<Index_> universe_ptr(tatami::VectorPtr<Index_>{}, &universe);
326 auto ext = tatami::consecutive_extractor<ref_sparse_>(ref,
false, start, len, std::move(universe_ptr));
328 for (Index_ c = start, end = start + len; c < end; ++c) {
331 if constexpr(ref_sparse_) {
332 auto info = ext->fetch(vbuffer.data(), ibuffer.data());
333 for (I<
decltype(info.number)> i = 0; i < info.number; ++i) {
334 const auto remapped = remap_test_to_universe[info.index[i]];
335 assert(sanisizer::is_less_than(remapped, num_universe));
336 tmp_ranked.emplace_back(info.value[i], remapped);
339 auto ptr = ext->fetch(vbuffer.data());
340 for (I<
decltype(num_universe)> i = 0; i < num_universe; ++i) {
341 tmp_ranked.emplace_back(ptr[i], i);
345 std::sort(tmp_ranked.begin(), tmp_ranked.end());
347 if constexpr(ref_sparse_) {
348 const auto tStart = tmp_ranked.begin(), tEnd = tmp_ranked.end();
349 auto zero_ranges = find_zero_ranges<Value_, Index_>(tStart, tEnd);
350 simplify_ranks<Value_, Index_>(tStart, zero_ranges.first, out_ranked[input.labels[c]][positions[c]]);
351 simplify_ranks<Value_, Index_>(zero_ranges.second, tEnd, other_ranked[input.labels[c]][positions[c]]);
353 simplify_ranks(tmp_ranked, out_ranked[input.labels[c]][positions[c]]);
359template<
bool ref_sparse_,
typename Value_,
typename Index_,
typename Label_>
360void train_integrated_per_reference_intersect(
361 const TrainIntegratedInput<Value_, Label_, Index_>& input,
362 const std::vector<Index_>& remap_test_to_universe,
363 const Index_ test_nrow,
364 const TrainIntegratedOptions& options,
365 const std::vector<Index_>& positions,
366 std::vector<std::vector<RankedVector<Index_, Index_> > >& out_ranked,
367 typename std::conditional<ref_sparse_, std::vector<std::vector<RankedVector<Index_, Index_> > >&,
bool>::type other_ranked
369 const auto& ref = *(input.ref);
370 const auto NC = ref.ncol();
372 std::vector<Index_> ref_subset;
373 sanisizer::reserve(ref_subset, input.intersection->size());
374 auto remap_ref_subset_to_universe = sanisizer::create<std::vector<Index_> >(ref.nrow(), test_nrow);
375 for (
const auto& pair : *(input.intersection)) {
376 const auto rdex = remap_test_to_universe[pair.first];
377 if (rdex != test_nrow) {
378 ref_subset.push_back(pair.second);
379 remap_ref_subset_to_universe[pair.second] = rdex;
382 std::sort(ref_subset.begin(), ref_subset.end());
384 typename std::conditional<ref_sparse_, bool, std::vector<Index_> >::type remap_dense_to_universe;
385 if constexpr(!ref_sparse_) {
386 remap_dense_to_universe.reserve(ref_subset.size());
387 for (
auto r : ref_subset) {
388 remap_dense_to_universe.push_back(remap_ref_subset_to_universe[r]);
392 tatami::parallelize([&](
int, Index_ start, Index_ len) {
393 const auto ref_subset_size = ref_subset.size();
394 auto vbuffer = sanisizer::create<std::vector<Value_> >(ref_subset_size);
395 auto ibuffer = [&]() {
396 if constexpr(ref_sparse_) {
397 return sanisizer::create<std::vector<Index_> >(ref_subset_size);
403 RankedVector<Value_, Index_> tmp_ranked;
404 tmp_ranked.reserve(ref_subset_size);
405 tatami::VectorPtr<Index_> to_extract_ptr(tatami::VectorPtr<Index_>{}, &ref_subset);
406 auto ext = tatami::consecutive_extractor<ref_sparse_>(ref,
false, start, len, std::move(to_extract_ptr));
408 for (Index_ c = start, end = start + len; c < end; ++c) {
411 if constexpr(ref_sparse_) {
412 auto info = ext->fetch(vbuffer.data(), ibuffer.data());
413 for (I<
decltype(info.number)> i = 0; i < info.number; ++i) {
414 tmp_ranked.emplace_back(info.value[i], remap_ref_subset_to_universe[info.index[i]]);
417 auto ptr = ext->fetch(vbuffer.data());
418 for (I<
decltype(ref_subset_size)> i = 0; i < ref_subset_size; ++i) {
419 tmp_ranked.emplace_back(ptr[i], remap_dense_to_universe[i]);
423 std::sort(tmp_ranked.begin(), tmp_ranked.end());
425 if constexpr(ref_sparse_) {
426 const auto tStart = tmp_ranked.begin(), tEnd = tmp_ranked.end();
427 auto zero_ranges = find_zero_ranges<Value_, Index_>(tStart, tEnd);
428 simplify_ranks<Value_, Index_>(tStart, zero_ranges.first, out_ranked[input.labels[c]][positions[c]]);
429 simplify_ranks<Value_, Index_>(zero_ranges.second, tEnd, other_ranked[input.labels[c]][positions[c]]);
431 simplify_ranks(tmp_ranked, out_ranked[input.labels[c]][positions[c]]);
434 }, NC, options.num_threads);
450template<
typename Value_,
typename Index_,
typename Label_>
452 std::vector<Index_> universe;
453 const auto nrefs = inputs.size();
454 auto references = sanisizer::create<std::vector<IntegratedReference<Index_> > >(nrefs);
457 Index_ test_nrow = 0;
459 test_nrow = inputs.front().test_nrow;
460 for (
const auto& in : inputs) {
461 if (!sanisizer::is_equal(in.test_nrow, test_nrow)) {
462 throw std::runtime_error(
"inconsistent number of rows in the test dataset across entries of 'inputs'");
469 auto remap_test_to_universe = sanisizer::create<std::vector<Index_> >(test_nrow, test_nrow);
471 auto present = sanisizer::create<std::vector<char> >(test_nrow);
472 auto count_refs = sanisizer::create<std::vector<I<
decltype(nrefs)> > >(test_nrow);
473 universe.reserve(test_nrow);
475 for (
const auto& in : inputs) {
476 const auto& markers = *(in.ref_markers);
477 const auto& test_subset = *(in.test_subset);
479 for (
const auto& labmrk : markers) {
480 for (
const auto& mrk : labmrk) {
481 for (
const auto y : mrk) {
482 const auto ty = test_subset[y];
485 universe.push_back(ty);
491 if (in.intersection) {
492 for (
const auto& pp : *(in.intersection)) {
493 count_refs[pp.first] += 1;
496 for (
auto& x : count_refs) {
502 std::sort(universe.begin(), universe.end());
503 const auto num_universe = universe.size();
504 I<
decltype(num_universe)> keep = 0;
505 for (I<
decltype(num_universe)> u = 0; u < num_universe; ++u) {
506 const auto marker = universe[u];
507 if (count_refs[marker] == nrefs) {
508 universe[keep] = marker;
509 remap_test_to_universe[marker] = keep;
513 universe.resize(keep);
514 universe.shrink_to_fit();
517 auto is_active = sanisizer::create<std::vector<char> >(test_nrow);
518 std::vector<Index_> active_genes;
519 active_genes.reserve(test_nrow);
521 for (I<
decltype(nrefs)> r = 0; r < nrefs; ++r) {
522 const auto& curinput = inputs[r];
523 const auto& currefmarkers = *(curinput.ref_markers);
524 const auto& test_subset = *(curinput.test_subset);
525 const auto nlabels = currefmarkers.size();
526 auto& currefout = references[r];
528 const Index_ NC = curinput.ref->ncol();
529 const bool is_sparse = curinput.ref->is_sparse();
531 currefout.sparse.emplace(sanisizer::as_size_type<I<
decltype(*(currefout.sparse))> >(nlabels));
533 currefout.dense.emplace(sanisizer::as_size_type<I<
decltype(*(currefout.dense))> >(nlabels));
537 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
538 active_genes.clear();
539 for (
const auto& labmark : currefmarkers[l]) {
540 for (
const auto y : labmark) {
541 const auto ty = test_subset[y];
542 if (!is_active[ty]) {
543 is_active[ty] =
true;
544 active_genes.push_back(ty);
549 std::vector<Index_> markers;
550 markers.reserve(active_genes.size());
552 for (
const auto a : active_genes) {
553 const auto universe_index = remap_test_to_universe[a];
554 if (universe_index != test_nrow) {
555 markers.push_back(universe_index);
557 is_active[a] =
false;
561 (*(currefout.sparse))[l].markers.swap(markers);
563 (*(currefout.dense))[l].markers.swap(markers);
568 std::vector<Index_> positions;
569 positions.reserve(NC);
570 auto samples_per_label = sanisizer::create<std::vector<Index_> >(nlabels);
571 for (Index_ c = 0; c < NC; ++c) {
572 auto& pos = samples_per_label[curinput.labels[c]];
573 positions.push_back(pos);
577 if (curinput.ref->is_sparse()) {
578 auto negative_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
579 auto positive_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
580 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
581 const auto num_samples = samples_per_label[l];
582 sanisizer::resize(negative_ranked[l], num_samples);
583 sanisizer::resize(positive_ranked[l], num_samples);
586 if (curinput.intersection) {
587 train_integrated_per_reference_intersect<true>(curinput, remap_test_to_universe, test_nrow, options, positions, negative_ranked, positive_ranked);
589 train_integrated_per_reference_simple<true, Value_>(curinput, universe, remap_test_to_universe, options, positions, negative_ranked, positive_ranked);
592 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
593 auto& curlabout = (*(currefout.sparse))[l];
594 const auto num_samples = samples_per_label[l];
595 curlabout.num_samples = num_samples;
597 I<
decltype(curlabout.negative_ranked.size())> num_neg = 0;
598 for (
const auto& x : negative_ranked[l]) {
599 num_neg = sanisizer::sum<I<
decltype(num_neg)> >(num_neg, x.size());
602 I<
decltype(curlabout.positive_ranked.size())> num_pos = 0;
603 for (
const auto& x : positive_ranked[l]) {
604 num_pos = sanisizer::sum<I<
decltype(num_pos)> >(num_pos, x.size());
607 curlabout.negative_ranked.reserve(num_neg);
608 curlabout.negative_indptrs.reserve(sanisizer::sum<I<
decltype(curlabout.negative_indptrs.size())> >(num_samples, 1));
609 curlabout.negative_indptrs.push_back(0);
610 for (
const auto& x : negative_ranked[l]) {
611 curlabout.negative_ranked.insert(curlabout.negative_ranked.end(), x.begin(), x.end());
612 curlabout.negative_indptrs.push_back(curlabout.negative_ranked.size());
615 curlabout.positive_ranked.reserve(num_pos);
616 curlabout.positive_indptrs.reserve(sanisizer::sum<I<
decltype(curlabout.positive_indptrs.size())> >(num_samples, 1));
617 curlabout.positive_indptrs.push_back(0);
618 for (
const auto& x : positive_ranked[l]) {
619 curlabout.positive_ranked.insert(curlabout.positive_ranked.end(), x.begin(), x.end());
620 curlabout.positive_indptrs.push_back(curlabout.positive_ranked.size());
625 auto out_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
626 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
627 const auto num_samples = samples_per_label[l];
628 sanisizer::resize(out_ranked[l], num_samples);
631 if (curinput.intersection) {
632 train_integrated_per_reference_intersect<false>(curinput, remap_test_to_universe, test_nrow, options, positions, out_ranked,
true);
634 train_integrated_per_reference_simple<false, Value_>(curinput, universe, remap_test_to_universe, options, positions, out_ranked,
true);
637 for (I<
decltype(nlabels)> l = 0; l < nlabels; ++l) {
638 auto& curlabout = (*(currefout.dense))[l];
639 curlabout.num_samples = samples_per_label[l];
640 curlabout.all_ranked.reserve(sanisizer::product<I<
decltype(curlabout.all_ranked.size())> >(universe.size(), curlabout.num_samples));
641 for (
const auto& x : out_ranked[l]) {
642 curlabout.all_ranked.insert(curlabout.all_ranked.end(), x.begin(), x.end());
Create an intersection of genes.
Classifier that integrates multiple reference datasets.
Definition train_integrated.hpp:202
Index_ test_nrow() const
Definition train_integrated.hpp:236
std::size_t num_references() const
Definition train_integrated.hpp:229
const std::vector< Index_ > & subset() const
Definition train_integrated.hpp:245
std::size_t num_profiles(std::size_t r) const
Definition train_integrated.hpp:266
std::size_t num_labels(std::size_t r) const
Definition train_integrated.hpp:253
Classifier trained from a single reference.
Definition train_single.hpp:85
const std::vector< Index_ > & subset() const
Definition train_single.hpp:134
const Markers< Index_ > & markers() const
Definition train_single.hpp:125
Common definitions for singlepp.
Cell type classification using the SingleR algorithm in C++.
Definition classify_single.hpp:20
std::vector< std::vector< std::vector< Index_ > > > Markers
Definition Markers.hpp:40
Intersection< Index_ > intersect_genes(Index_ test_nrow, const Id_ *test_id, Index_ ref_nrow, const Id_ *ref_id)
Definition Intersection.hpp:54
TrainIntegratedInput< Value_, Index_, Label_ > prepare_integrated_input(const tatami::Matrix< Value_, Index_ > &ref, const Label_ *labels, const TrainedSingle< Index_, Float_ > &trained)
Definition train_integrated.hpp:65
std::vector< std::pair< Index_, Index_ > > Intersection
Definition Intersection.hpp:35
TrainedIntegrated< Index_ > train_integrated(const std::vector< TrainIntegratedInput< Value_, Index_, Label_ > > &inputs, const TrainIntegratedOptions &options)
Definition train_integrated.hpp:451
Options for train_integrated().
Definition train_integrated.hpp:285
int num_threads
Definition train_integrated.hpp:290
Train a classifier from a single reference.