singlepp
A C++ library for cell type classification
Loading...
Searching...
No Matches
train_integrated.hpp
Go to the documentation of this file.
1#ifndef SINGLEPP_TRAIN_INTEGRATED_HPP
2#define SINGLEPP_TRAIN_INTEGRATED_HPP
3
4#include "defs.hpp"
5
6#include "train_single.hpp"
7#include "Intersection.hpp"
8
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
13
19namespace singlepp {
20
// Input for a single reference to train_integrated(); usually created via prepare_integrated_input().
// All members except 'intersection' are non-owning raw pointers, so the pointed-to objects must outlive any use of this input.
30template<typename Value_, typename Index_, typename Label_>
// Reference matrix; columns are reference samples and rows are features.
35 const tatami::Matrix<Value_, Index_>* ref;
// Label of each reference column; used directly as an index into per-label storage, so presumably in [0, number of labels).
36 const Label_* labels;
// Marker genes from the single-reference classifier, stored as indices into '*test_subset'.
37 const Markers<Index_>* ref_markers;
// Maps each marker index to a row index of the test dataset.
38 const std::vector<Index_>* test_subset;
// Number of rows in the test dataset; train_integrated() requires this to be identical across all inputs.
39 Index_ test_nrow;
// Pairs of (test row, reference row) for shared features. Null when the test and reference rows are assumed to be the same features.
40 std::shared_ptr<const Intersection<Index_> > intersection;
44};
45
// Prepare the input for train_integrated() when the reference has the same features (rows) as the test dataset.
// @param ref Reference matrix; its number of rows is also used as the number of test rows.
// @param labels Label for each column of 'ref'.
// @param trained Single-reference classifier whose markers and feature subset are reused.
// @return Input that stores raw pointers to 'ref', 'labels' and the markers/subset of 'trained'; these must outlive it.
// 'intersection' is not assigned here, so train_integrated() uses its non-intersection path for this reference
// (assuming 'output' is default-constructed -- TODO confirm).
64template<typename Value_, typename Index_, typename Label_, typename Float_>
66 const tatami::Matrix<Value_, Index_>& ref,
67 const Label_* labels,
69) {
71 output.ref = &ref;
72 output.labels = labels;
73
74 output.ref_markers = &(trained.markers());
75 output.test_subset = &(trained.subset());
76
77 output.test_nrow = ref.nrow(); // remember, test and ref are assumed to have the same features.
78 return output;
79}
80
// Prepare the input for train_integrated() when test and reference features are matched by a precomputed intersection.
// @param test_nrow Number of rows in the test dataset.
// @param intersection Pairs of (test row, reference row) for shared features. Only a non-owning pointer is kept,
//   so it must outlive the returned object.
// @param ref Reference matrix, stored as a raw pointer.
// @param labels Label for each column of 'ref', stored as a raw pointer.
// @param trained Single-reference classifier whose markers and feature subset are referenced by raw pointer.
// @return Input referring to all of the above.
104template<typename Index_, typename Value_, typename Label_, typename Float_>
106 Index_ test_nrow,
107 const Intersection<Index_>& intersection,
108 const tatami::Matrix<Value_, Index_>& ref,
109 const Label_* labels,
110 const TrainedSingle<Index_, Float_>& trained
111) {
113 output.ref = &ref;
114 output.labels = labels;
115
116 output.ref_markers = &(trained.markers());
117 output.test_subset = &(trained.subset());
118
119 output.test_nrow = test_nrow;
// Aliasing constructor with an empty owner: the shared_ptr points at 'intersection' but never deletes it.
120 output.intersection = std::shared_ptr<const Intersection<Index_> >(std::shared_ptr<Intersection<Index_> >{}, &intersection);
121 return output;
122}
123
// Prepare the input for train_integrated() when test and reference features are matched by identifier.
// The intersection is computed here with intersect_genes() and owned by the returned object through a shared_ptr.
// @param test_nrow Number of rows in the test dataset.
// @param test_id Identifiers of the test rows (presumably of length 'test_nrow').
// @param ref Reference matrix, stored as a raw pointer.
// @param ref_id Identifiers of the reference rows (presumably of length 'ref.nrow()').
// @param labels Label for each column of 'ref', stored as a raw pointer.
// @param trained Single-reference classifier whose markers and feature subset are referenced by raw pointer.
// @return Input that owns its intersection but still refers to 'ref', 'labels' and 'trained'.
150template<typename Index_, typename Id_, typename Value_, typename Label_, typename Float_>
152 Index_ test_nrow,
153 const Id_* test_id,
154 const tatami::Matrix<Value_, Index_>& ref,
155 const Id_* ref_id,
156 const Label_* labels,
157 const TrainedSingle<Index_, Float_>& trained
158) {
160 output.ref = &ref;
161 output.labels = labels;
162
163 output.ref_markers = &(trained.markers());
164 output.test_subset = &(trained.subset());
165
166 output.test_nrow = test_nrow;
167 auto intersection = intersect_genes(test_nrow, test_id, ref.nrow(), ref_id);
// NOTE(review): std::make_shared would avoid the separate control-block allocation.
168 output.intersection = std::shared_ptr<const Intersection<Index_> >(new Intersection<Index_>(std::move(intersection)));
169 return output;
170}
171
175template<typename Index_>
176struct IntegratedReference {
177 struct DensePerLabel {
178 Index_ num_samples;
179 std::vector<Index_> markers; // indices to 'universe'
180 RankedVector<Index_, Index_> all_ranked; // .second contains indices to 'universe'
181 };
182
183 struct SparsePerLabel {
184 Index_ num_samples;
185 std::vector<Index_> markers; // indices to 'universe'
186 RankedVector<Index_, Index_> negative_ranked, positive_ranked; // .second contains indices to 'universe'
187 std::vector<std::size_t> negative_indptrs, positive_indptrs;
188 };
189
190 std::optional<std::vector<DensePerLabel> > dense;
191 std::optional<std::vector<SparsePerLabel> > sparse;
192};
201template<typename Index_>
203public:
207 TrainedIntegrated(Index_ test_nrow, std::vector<Index_> universe, std::vector<IntegratedReference<Index_> > references) :
208 my_test_nrow(test_nrow),
209 my_universe(std::move(universe)),
210 my_references(std::move(references))
211 {}
212
213 const auto& references() const {
214 return my_references;
215 }
220private:
// Number of rows (features) in the test dataset.
221 Index_ my_test_nrow;
// Test-matrix row indices forming the universe; train_integrated() builds this sorted and unique. Exposed via subset().
222 std::vector<Index_> my_universe;
// Per-reference training data, in the same order as the inputs given to train_integrated().
223 std::vector<IntegratedReference<Index_> > my_references;
224
225public:
229 std::size_t num_references() const {
230 return my_references.size();
231 }
232
236 Index_ test_nrow() const {
237 return my_test_nrow;
238 }
239
245 const std::vector<Index_>& subset() const {
246 return my_universe;
247 }
248
253 std::size_t num_labels(std::size_t r) const {
254 const auto& ref = my_references[r];
255 if (ref.dense.has_value()) {
256 return ref.dense->size();
257 } else {
258 return ref.sparse->size();
259 }
260 }
261
266 std::size_t num_profiles(std::size_t r) const {
267 std::size_t num_prof = 0;
268 const auto& ref = my_references[r];
269 if (ref.dense.has_value()) {
270 for (const auto& lab : *(ref.dense)) {
271 num_prof += sanisizer::sum<std::size_t>(num_prof, lab.num_samples);
272 }
273 } else {
274 for (const auto& lab : *(ref.sparse)) {
275 num_prof += sanisizer::sum<std::size_t>(num_prof, lab.num_samples);
276 }
277 }
278 return num_prof;
279 }
280};
281
292
// Rank the reference profiles for a reference whose rows are the same features as the test dataset (no intersection).
// For each reference column 'c', the values of the 'universe' rows are ranked, with each entry's .second set to its
// universe index, and the result is stored at 'out_ranked[input.labels[c]][positions[c]]'.
// For sparse references ('ref_sparse_'), entries before the zero range go to 'out_ranked' and entries after it go
// to 'other_ranked'. For dense references, 'other_ranked' is an unused placeholder.
// Columns are split across 'options.num_threads' workers via tatami::parallelize.
296template<bool ref_sparse_, typename Value_, typename Index_, typename Label_>
297void train_integrated_per_reference_simple(
299 const std::vector<Index_>& universe,
300 const std::vector<Index_>& remap_test_to_universe,
301 const TrainIntegratedOptions& options,
302 const std::vector<Index_>& positions,
303 std::vector<std::vector<RankedVector<Index_, Index_> > >& out_ranked,
304 typename std::conditional<ref_sparse_, std::vector<std::vector<RankedVector<Index_, Index_> > >&, bool>::type other_ranked
305) {
306 const auto& ref = *(input.ref);
307 const auto NC = ref.ncol();
308 const auto num_universe = universe.size();
309
310 tatami::parallelize([&](int, Index_ start, Index_ len) {
311 auto vbuffer = sanisizer::create<std::vector<Value_> >(num_universe);
// The index buffer is only needed for sparse extraction; 'false' is a placeholder for dense references.
312 auto ibuffer = [&](){
313 if constexpr(ref_sparse_) {
314 return sanisizer::create<std::vector<Index_> >(num_universe);
315 } else {
316 return false;
317 }
318 }();
319
320 RankedVector<Value_, Index_> tmp_ranked;
321 tmp_ranked.reserve(num_universe);
322
323 // 'universe' technically refers to the row indices of the test matrix,
324 // but in simple mode, the rows of the test and reference are the same, so we can use it directly here.
// Non-owning pointer (aliasing constructor with an empty owner); 'universe' outlives the extractor.
325 tatami::VectorPtr<Index_> universe_ptr(tatami::VectorPtr<Index_>{}, &universe);
326 auto ext = tatami::consecutive_extractor<ref_sparse_>(ref, false, start, len, std::move(universe_ptr));
327
328 for (Index_ c = start, end = start + len; c < end; ++c) {
329 tmp_ranked.clear();
330
331 if constexpr(ref_sparse_) {
332 auto info = ext->fetch(vbuffer.data(), ibuffer.data());
333 for (I<decltype(info.number)> i = 0; i < info.number; ++i) {
334 const auto remapped = remap_test_to_universe[info.index[i]];
335 assert(sanisizer::is_less_than(remapped, num_universe));
336 tmp_ranked.emplace_back(info.value[i], remapped);
337 }
338 } else {
339 auto ptr = ext->fetch(vbuffer.data());
340 for (I<decltype(num_universe)> i = 0; i < num_universe; ++i) {
341 tmp_ranked.emplace_back(ptr[i], i); // a.k.a. remap_test_to_universe[universe[i]].
342 }
343 }
344
345 std::sort(tmp_ranked.begin(), tmp_ranked.end());
346
// Sparse: split the sorted entries around the zero range into the two outputs.
347 if constexpr(ref_sparse_) {
348 const auto tStart = tmp_ranked.begin(), tEnd = tmp_ranked.end();
349 auto zero_ranges = find_zero_ranges<Value_, Index_>(tStart, tEnd);
350 simplify_ranks<Value_, Index_>(tStart, zero_ranges.first, out_ranked[input.labels[c]][positions[c]]);
351 simplify_ranks<Value_, Index_>(zero_ranges.second, tEnd, other_ranked[input.labels[c]][positions[c]]);
352 } else {
353 simplify_ranks(tmp_ranked, out_ranked[input.labels[c]][positions[c]]);
354 }
355 }
356 }, NC, options.num_threads);
357}
358
359template<bool ref_sparse_, typename Value_, typename Index_, typename Label_>
360void train_integrated_per_reference_intersect(
361 const TrainIntegratedInput<Value_, Label_, Index_>& input,
362 const std::vector<Index_>& remap_test_to_universe,
363 const Index_ test_nrow,
364 const TrainIntegratedOptions& options,
365 const std::vector<Index_>& positions,
366 std::vector<std::vector<RankedVector<Index_, Index_> > >& out_ranked,
367 typename std::conditional<ref_sparse_, std::vector<std::vector<RankedVector<Index_, Index_> > >&, bool>::type other_ranked
368) {
369 const auto& ref = *(input.ref);
370 const auto NC = ref.ncol();
371
372 std::vector<Index_> ref_subset;
373 sanisizer::reserve(ref_subset, input.intersection->size());
374 auto remap_ref_subset_to_universe = sanisizer::create<std::vector<Index_> >(ref.nrow(), test_nrow); // all entries of remap_test_to_universe are less than test_nrow.
375 for (const auto& pair : *(input.intersection)) {
376 const auto rdex = remap_test_to_universe[pair.first];
377 if (rdex != test_nrow) {
378 ref_subset.push_back(pair.second);
379 remap_ref_subset_to_universe[pair.second] = rdex;
380 }
381 }
382 std::sort(ref_subset.begin(), ref_subset.end());
383
384 typename std::conditional<ref_sparse_, bool, std::vector<Index_> >::type remap_dense_to_universe;
385 if constexpr(!ref_sparse_) {
386 remap_dense_to_universe.reserve(ref_subset.size());
387 for (auto r : ref_subset) {
388 remap_dense_to_universe.push_back(remap_ref_subset_to_universe[r]);
389 }
390 }
391
392 tatami::parallelize([&](int, Index_ start, Index_ len) {
393 const auto ref_subset_size = ref_subset.size();
394 auto vbuffer = sanisizer::create<std::vector<Value_> >(ref_subset_size);
395 auto ibuffer = [&]() {
396 if constexpr(ref_sparse_) {
397 return sanisizer::create<std::vector<Index_> >(ref_subset_size);
398 } else {
399 return false;
400 }
401 }();
402
403 RankedVector<Value_, Index_> tmp_ranked;
404 tmp_ranked.reserve(ref_subset_size);
405 tatami::VectorPtr<Index_> to_extract_ptr(tatami::VectorPtr<Index_>{}, &ref_subset);
406 auto ext = tatami::consecutive_extractor<ref_sparse_>(ref, false, start, len, std::move(to_extract_ptr));
407
408 for (Index_ c = start, end = start + len; c < end; ++c) {
409 tmp_ranked.clear();
410
411 if constexpr(ref_sparse_) {
412 auto info = ext->fetch(vbuffer.data(), ibuffer.data());
413 for (I<decltype(info.number)> i = 0; i < info.number; ++i) {
414 tmp_ranked.emplace_back(info.value[i], remap_ref_subset_to_universe[info.index[i]]);
415 }
416 } else {
417 auto ptr = ext->fetch(vbuffer.data());
418 for (I<decltype(ref_subset_size)> i = 0; i < ref_subset_size; ++i) {
419 tmp_ranked.emplace_back(ptr[i], remap_dense_to_universe[i]);
420 }
421 }
422
423 std::sort(tmp_ranked.begin(), tmp_ranked.end());
424
425 if constexpr(ref_sparse_) {
426 const auto tStart = tmp_ranked.begin(), tEnd = tmp_ranked.end();
427 auto zero_ranges = find_zero_ranges<Value_, Index_>(tStart, tEnd);
428 simplify_ranks<Value_, Index_>(tStart, zero_ranges.first, out_ranked[input.labels[c]][positions[c]]);
429 simplify_ranks<Value_, Index_>(zero_ranges.second, tEnd, other_ranked[input.labels[c]][positions[c]]);
430 } else {
431 simplify_ranks(tmp_ranked, out_ranked[input.labels[c]][positions[c]]);
432 }
433 }
434 }, NC, options.num_threads);
435}
// Train a classifier that integrates several references.
// First builds the 'universe': the sorted union of all references' marker genes (as test-row indices), restricted
// to genes present in every reference. Then, for each reference and label, it stores the marker genes (as universe
// indices) and the ranked profiles of that label's samples, whose entries are indexed by universe position.
// Throws std::runtime_error if the entries of 'inputs' disagree on 'test_nrow'.
450template<typename Value_, typename Index_, typename Label_>
452 std::vector<Index_> universe;
453 const auto nrefs = inputs.size();
454 auto references = sanisizer::create<std::vector<IntegratedReference<Index_> > >(nrefs);
455
456 // Checking that the number of genes in the test dataset are consistent.
457 Index_ test_nrow = 0;
458 if (inputs.size()) {
459 test_nrow = inputs.front().test_nrow;
460 for (const auto& in : inputs) {
461 if (!sanisizer::is_equal(in.test_nrow, test_nrow)) {
462 throw std::runtime_error("inconsistent number of rows in the test dataset across entries of 'inputs'");
463 }
464 }
465 }
466
467 // Identify the union of all marker genes as the universe, but excluding those markers that are not present in intersections.
468 // Note that 'universe' contains sorted and unique row indices for the test matrix, where 'remap_test_to_universe[universe[i]] == i'.
// Rows outside the universe keep the sentinel value 'test_nrow'.
469 auto remap_test_to_universe = sanisizer::create<std::vector<Index_> >(test_nrow, test_nrow);
470 {
471 auto present = sanisizer::create<std::vector<char> >(test_nrow);
472 auto count_refs = sanisizer::create<std::vector<I<decltype(nrefs)> > >(test_nrow);
473 universe.reserve(test_nrow);
474
475 for (const auto& in : inputs) {
476 const auto& markers = *(in.ref_markers);
477 const auto& test_subset = *(in.test_subset);
478
479 for (const auto& labmrk : markers) {
480 for (const auto& mrk : labmrk) {
481 for (const auto y : mrk) {
482 const auto ty = test_subset[y];
483 if (!present[ty]) {
484 present[ty] = true;
485 universe.push_back(ty);
486 }
487 }
488 }
489 }
490
// Count how many references contain each test row: only the intersected rows when an intersection is
// present, otherwise every row (test and reference rows are the same features).
491 if (in.intersection) {
492 for (const auto& pp : *(in.intersection)) {
493 count_refs[pp.first] += 1;
494 }
495 } else {
496 for (auto& x : count_refs) {
497 x += 1;
498 }
499 }
500 }
501
// Keep only the markers found in every reference, compacting 'universe' in place and recording each survivor's position.
502 std::sort(universe.begin(), universe.end());
503 const auto num_universe = universe.size();
504 I<decltype(num_universe)> keep = 0;
505 for (I<decltype(num_universe)> u = 0; u < num_universe; ++u) {
506 const auto marker = universe[u];
507 if (count_refs[marker] == nrefs) {
508 universe[keep] = marker;
509 remap_test_to_universe[marker] = keep;
510 ++keep;
511 }
512 }
513 universe.resize(keep);
514 universe.shrink_to_fit();
515 }
516
// Scratch flags and list, reused for every label to deduplicate that label's marker genes.
517 auto is_active = sanisizer::create<std::vector<char> >(test_nrow);
518 std::vector<Index_> active_genes;
519 active_genes.reserve(test_nrow);
520
521 for (I<decltype(nrefs)> r = 0; r < nrefs; ++r) {
522 const auto& curinput = inputs[r];
523 const auto& currefmarkers = *(curinput.ref_markers);
524 const auto& test_subset = *(curinput.test_subset);
525 const auto nlabels = currefmarkers.size();
526 auto& currefout = references[r];
527
528 const Index_ NC = curinput.ref->ncol();
529 const bool is_sparse = curinput.ref->is_sparse();
530 if (is_sparse) {
531 currefout.sparse.emplace(sanisizer::as_size_type<I<decltype(*(currefout.sparse))> >(nlabels));
532 } else {
533 currefout.dense.emplace(sanisizer::as_size_type<I<decltype(*(currefout.dense))> >(nlabels));
534 }
535
536 // Assembling the per-label markers for this reference.
537 for (I<decltype(nlabels)> l = 0; l < nlabels; ++l) {
538 active_genes.clear();
539 for (const auto& labmark : currefmarkers[l]) {
540 for (const auto y : labmark) {
541 const auto ty = test_subset[y];
542 if (!is_active[ty]) {
543 is_active[ty] = true;
544 active_genes.push_back(ty);
545 }
546 }
547 }
548
549 std::vector<Index_> markers;
550 markers.reserve(active_genes.size());
551
552 for (const auto a : active_genes) {
553 const auto universe_index = remap_test_to_universe[a];
554 if (universe_index != test_nrow) { // ignoring genes not in the intersection.
555 markers.push_back(universe_index);
556 }
557 is_active[a] = false;
558 }
559
560 if (is_sparse) {
561 (*(currefout.sparse))[l].markers.swap(markers);
562 } else {
563 (*(currefout.dense))[l].markers.swap(markers);
564 }
565 }
566
567 // Pre-allocating the ranked vectors.
// 'positions[c]' is the index of column 'c' among the columns sharing its label.
568 std::vector<Index_> positions;
569 positions.reserve(NC);
570 auto samples_per_label = sanisizer::create<std::vector<Index_> >(nlabels);
571 for (Index_ c = 0; c < NC; ++c) {
572 auto& pos = samples_per_label[curinput.labels[c]];
573 positions.push_back(pos);
574 ++pos;
575 }
576
// Sparse references: each sample's ranks are split at the zero range, with entries before it going to 'negative_ranked'
// and entries after it going to 'positive_ranked'. Both are then concatenated per label, with offset vectors of num_samples + 1 entries.
577 if (curinput.ref->is_sparse()) {
578 auto negative_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
579 auto positive_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
580 for (I<decltype(nlabels)> l = 0; l < nlabels; ++l) {
581 const auto num_samples = samples_per_label[l];
582 sanisizer::resize(negative_ranked[l], num_samples);
583 sanisizer::resize(positive_ranked[l], num_samples);
584 }
585
586 if (curinput.intersection) {
587 train_integrated_per_reference_intersect<true>(curinput, remap_test_to_universe, test_nrow, options, positions, negative_ranked, positive_ranked);
588 } else {
589 train_integrated_per_reference_simple<true, Value_>(curinput, universe, remap_test_to_universe, options, positions, negative_ranked, positive_ranked);
590 }
591
592 for (I<decltype(nlabels)> l = 0; l < nlabels; ++l) {
593 auto& curlabout = (*(currefout.sparse))[l];
594 const auto num_samples = samples_per_label[l];
595 curlabout.num_samples = num_samples;
596
597 I<decltype(curlabout.negative_ranked.size())> num_neg = 0;
598 for (const auto& x : negative_ranked[l]) {
599 num_neg = sanisizer::sum<I<decltype(num_neg)> >(num_neg, x.size());
600 }
601
602 I<decltype(curlabout.positive_ranked.size())> num_pos = 0;
603 for (const auto& x : positive_ranked[l]) {
604 num_pos = sanisizer::sum<I<decltype(num_pos)> >(num_pos, x.size());
605 }
606
607 curlabout.negative_ranked.reserve(num_neg);
608 curlabout.negative_indptrs.reserve(sanisizer::sum<I<decltype(curlabout.negative_indptrs.size())> >(num_samples, 1));
609 curlabout.negative_indptrs.push_back(0);
610 for (const auto& x : negative_ranked[l]) {
611 curlabout.negative_ranked.insert(curlabout.negative_ranked.end(), x.begin(), x.end());
612 curlabout.negative_indptrs.push_back(curlabout.negative_ranked.size());
613 }
614
615 curlabout.positive_ranked.reserve(num_pos);
616 curlabout.positive_indptrs.reserve(sanisizer::sum<I<decltype(curlabout.positive_indptrs.size())> >(num_samples, 1));
617 curlabout.positive_indptrs.push_back(0);
618 for (const auto& x : positive_ranked[l]) {
619 curlabout.positive_ranked.insert(curlabout.positive_ranked.end(), x.begin(), x.end());
620 curlabout.positive_indptrs.push_back(curlabout.positive_ranked.size());
621 }
622 }
623
624 } else {
// Dense references: each sample's ranked vector is concatenated, in sample order, into 'all_ranked'.
625 auto out_ranked = sanisizer::create<std::vector<std::vector<RankedVector<Index_, Index_> > > >(nlabels);
626 for (I<decltype(nlabels)> l = 0; l < nlabels; ++l) {
627 const auto num_samples = samples_per_label[l];
628 sanisizer::resize(out_ranked[l], num_samples);
629 }
630
631 if (curinput.intersection) {
632 train_integrated_per_reference_intersect<false>(curinput, remap_test_to_universe, test_nrow, options, positions, out_ranked, true);
633 } else {
634 train_integrated_per_reference_simple<false, Value_>(curinput, universe, remap_test_to_universe, options, positions, out_ranked, true);
635 }
636
637 for (I<decltype(nlabels)> l = 0; l < nlabels; ++l) {
638 auto& curlabout = (*(currefout.dense))[l];
639 curlabout.num_samples = samples_per_label[l];
640 curlabout.all_ranked.reserve(sanisizer::product<I<decltype(curlabout.all_ranked.size())> >(universe.size(), curlabout.num_samples));
641 for (const auto& x : out_ranked[l]) {
642 curlabout.all_ranked.insert(curlabout.all_ranked.end(), x.begin(), x.end());
643 }
644 }
645 }
646 }
647
648 return TrainedIntegrated<Index_>(test_nrow, std::move(universe), std::move(references));
649}
650
651}
652
653#endif
Create an intersection of genes.
Classifier that integrates multiple reference datasets.
Definition train_integrated.hpp:202
Index_ test_nrow() const
Definition train_integrated.hpp:236
std::size_t num_references() const
Definition train_integrated.hpp:229
const std::vector< Index_ > & subset() const
Definition train_integrated.hpp:245
std::size_t num_profiles(std::size_t r) const
Definition train_integrated.hpp:266
std::size_t num_labels(std::size_t r) const
Definition train_integrated.hpp:253
Classifier trained from a single reference.
Definition train_single.hpp:85
const std::vector< Index_ > & subset() const
Definition train_single.hpp:134
const Markers< Index_ > & markers() const
Definition train_single.hpp:125
Common definitions for singlepp.
Cell type classification using the SingleR algorithm in C++.
Definition classify_single.hpp:20
std::vector< std::vector< std::vector< Index_ > > > Markers
Definition Markers.hpp:40
Intersection< Index_ > intersect_genes(Index_ test_nrow, const Id_ *test_id, Index_ ref_nrow, const Id_ *ref_id)
Definition Intersection.hpp:54
TrainIntegratedInput< Value_, Index_, Label_ > prepare_integrated_input(const tatami::Matrix< Value_, Index_ > &ref, const Label_ *labels, const TrainedSingle< Index_, Float_ > &trained)
Definition train_integrated.hpp:65
std::vector< std::pair< Index_, Index_ > > Intersection
Definition Intersection.hpp:35
TrainedIntegrated< Index_ > train_integrated(const std::vector< TrainIntegratedInput< Value_, Index_, Label_ > > &inputs, const TrainIntegratedOptions &options)
Definition train_integrated.hpp:451
Input to train_integrated().
Definition train_integrated.hpp:31
Options for train_integrated().
Definition train_integrated.hpp:285
int num_threads
Definition train_integrated.hpp:290
Train a classifier from a single reference.