1#ifndef SINGLEPP_TRAIN_INTEGRATED_HPP
2#define SINGLEPP_TRAIN_INTEGRATED_HPP
6#include "scaled_ranks.hpp"
11#include <unordered_set>
12#include <unordered_map>
// Fragment of a templated aggregate (presumably TrainIntegratedInput, judging
// by the generated cross-references in this listing; its opening line is not
// visible here). The template arguments default to DefaultValue, DefaultIndex
// and DefaultLabel.
33template<
typename Value_ = DefaultValue,
typename Index_ = DefaultIndex,
typename Label_ = DefaultLabel>
// Pointer to the reference matrix. It is a raw pointer, so the pointee
// presumably has to outlive this object - TODO confirm.
40 const tatami::Matrix<Value_, Index_>* ref;
// Marker indices, stored as one inner vector per label.
44 std::vector<std::vector<Index_> > markers;
// Flag recording whether an intersection is in use; defaults to false.
46 bool with_intersection =
false;
// Fragment of the function that prepares an integrated-training input from a
// reference matrix (presumably prepare_integrated_input; the name line is not
// visible in this listing). Several lines of the definition are missing here.
74template<
typename Value_,
typename Index_,
typename Label_,
typename Float_>
76 const tatami::Matrix<Value_, Index_>& ref,
// The test dataset is taken to have as many rows as the reference.
81 output.test_nrow = ref.nrow();
83 output.labels = labels;
// old_markers is defined on a line not shown; it has one entry per label.
87 auto nlabels = old_markers.size();
92 auto& new_markers = output.markers;
93 new_markers.reserve(nlabels);
// Scratch set used to de-duplicate the marker indices of one label.
// NOTE(review): no clear() is visible between labels - presumably it sits on
// a hidden line; confirm against the full source.
94 std::unordered_set<Index_> unified;
96 for (
decltype(nlabels) i = 0; i < nlabels; ++i) {
// Pool every marker list belonging to label i into the set.
98 for (
const auto& x : old_markers[i]) {
99 unified.insert(x.begin(), x.end());
// Materialize the de-duplicated indices as this label's marker vector.
101 new_markers.emplace_back(unified.begin(), unified.end());
102 auto& cur_new_markers = new_markers.back();
// Each stored index is then visited by mutable reference; the loop body is not
// visible here.
103 for (
auto& y : cur_new_markers) {
// Fragment of the function that prepares an integrated-training input when a
// caller-supplied intersection is used (presumably
// prepare_integrated_input_intersect; the name line is not visible in this
// listing). Several lines of the definition are missing here.
134template<
typename Index_,
typename Value_,
typename Label_,
typename Float_>
138 const tatami::Matrix<Value_, Index_>& ref,
139 const Label_* labels,
// Here the number of test rows is supplied by the caller rather than being
// taken from the reference.
143 output.test_nrow = test_nrow;
145 output.labels = labels;
// old_markers is defined on a line not shown; it has one entry per label.
149 auto nlabels = old_markers.size();
150 auto& new_markers = output.markers;
151 new_markers.resize(nlabels);
// Scratch set used to de-duplicate the marker indices of one label.
// NOTE(review): no clear() is visible between labels - presumably it sits on
// a hidden line; confirm against the full source.
154 std::unordered_set<Index_> unified;
156 for (
decltype(nlabels) i = 0; i < nlabels; ++i) {
157 const auto& cur_old_markers = old_markers[i];
// Pool every marker list belonging to label i into the set.
160 for (
const auto& x : cur_old_markers) {
161 unified.insert(x.begin(), x.end());
164 auto& cur_new_markers = new_markers[i];
165 cur_new_markers.reserve(unified.size());
// Translate each pooled index through test_subset (defined on a hidden line)
// before storing it.
166 for (
auto y : unified) {
167 cur_new_markers.push_back(test_subset[y]);
// Record that an intersection is in use and keep a pointer to the caller's
// object; the pointer is raw, so `intersection` presumably must outlive
// `output` - TODO confirm.
171 output.with_intersection =
true;
172 output.user_intersection = &intersection;
// Signature fragment of a further overload that accepts a reference matrix,
// its labels and a TrainedSingleIntersect object. The function name, any
// leading parameters and the body are not visible in this listing.
180template<
typename Index_,
typename Value_,
typename Label_,
typename Float_>
183 const tatami::Matrix<Value_, Index_>& ref,
184 const Label_* labels,
185 const TrainedSingleIntersect<Index_, Float_>& trained)
// Fragment of an overload that builds the intersection itself from row
// identifiers of type Id_ instead of receiving one from the caller. The name
// line and several other lines are not visible in this listing.
219template<
typename Index_,
typename Id_,
typename Value_,
typename Label_,
typename Float_>
223 const tatami::Matrix<Value_, Index_>& ref,
225 const Label_* labels,
// Compute the intersection between the test and reference identifiers.
228 auto intersection =
intersect_genes(test_nrow, test_id, ref.nrow(), ref_id);
// No caller-owned intersection exists here, so the pointer is cleared and the
// freshly computed one is moved into the output via swap.
230 output.user_intersection = NULL;
231 output.auto_intersection.swap(intersection);
// Fragment of the class holding the trained integrated classifier (presumably
// TrainedIntegrated, per the generated cross-references in this listing). The
// class header, method signatures and closing braces are not visible here.
239template<
typename Index_>
// Number of references: one entry of `markers` per reference.
246 return markers.size();
// Number of labels in reference r.
254 return markers[r].size();
// Iterates over the per-label entries of reference r in `ranked`; the loop
// body is not visible (presumably it accumulates a profile count - confirm).
263 for (
const auto& ref : ranked[r]) {
// Indices forming the shared universe; the training code fills and sorts it.
276 std::vector<Index_> universe;
// One flag per reference.
278 std::vector<uint8_t> check_availability;
// Per reference, a set of universe positions.
279 std::vector<std::unordered_set<Index_> > available;
// Indexed as markers[reference][label] -> marker indices.
280 std::vector<std::vector<std::vector<Index_> > > markers;
// Indexed as ranked[reference][label][sample] -> ranked vector for that sample.
281 std::vector<std::vector<std::vector<internal::RankedVector<Index_, Index_> > > > ranked;
// Fills in the slot of `output` belonging to reference `ref_i`: stores that
// reference's markers re-expressed as positions in output.universe, then
// builds one simplified rank vector per reference column, grouped by label.
//
// remap_to_universe maps a marker index to its position in output.universe.
// It is taken by const reference so the map is not copied for every reference.
//
// NOTE(review): several lines of this definition (remaining parameters,
// closing braces, some statements) are not visible in this listing.
// NOTE(review): RefLabel_ is not used on any visible line - confirm it can be
// deduced or is supplied explicitly by the caller.
303template<
typename Value_,
typename RefLabel_,
typename Input_,
typename Index_>
304void train_integrated_per_reference(
308 const std::unordered_map<Index_, Index_>& remap_to_universe,
311 auto curlab = curinput.labels;
312 const auto& ref = *(curinput.ref);
// Take the markers out of a mutable input by swapping; a const input cannot be
// swapped from and has to be copied instead.
315 auto& curmarkers = output.markers[ref_i];
316 if constexpr(!std::is_const<Input_>::value) {
317 curmarkers.swap(curinput.markers);
319 curmarkers = curinput.markers;
// Rewrite every marker as its universe position. This relies on every marker
// being a key of remap_to_universe; a missing key would dereference end().
321 for (
auto& outer : curmarkers) {
322 for (
auto& x : outer) {
323 x = remap_to_universe.find(x)->second;
328 auto& cur_ranked = output.ranked[ref_i];
329 std::vector<Index_> positions;
331 auto nlabels = curmarkers.size();
332 Index_ NC = ref.ncol();
333 positions.reserve(NC);
// positions[c] records the current per-label counter for column c, i.e. the
// slot of that column within its label (the counter update is presumably on a
// hidden line - confirm).
335 std::vector<Index_> samples_per_label(nlabels);
336 for (Index_ c = 0; c < NC; ++c) {
337 auto& pos = samples_per_label[curlab[c]];
338 positions.push_back(pos);
// Pre-size the output so each column can write to its own slot.
342 cur_ranked.resize(nlabels);
343 for (
decltype(nlabels) l = 0; l < nlabels; ++l) {
344 cur_ranked[l].resize(samples_per_label[l]);
// Without an intersection, the universe indices are used directly as rows of
// the reference matrix.
348 if (!curinput.with_intersection) {
// Wraps output.universe without taking ownership (looks like the aliasing
// smart-pointer constructor with an empty owner - confirm VectorPtr's type).
352 tatami::VectorPtr<Index_> universe_ptr(tatami::VectorPtr<Index_>{}, &(output.universe));
354 tatami::parallelize([&](
int, Index_ start, Index_ len) {
355 std::vector<Value_> buffer(output.universe.size());
356 internal::RankedVector<Value_, Index_> tmp_ranked;
357 tmp_ranked.reserve(output.universe.size());
358 auto ext = tatami::consecutive_extractor<false>(&ref,
false, start, len, universe_ptr);
360 for (Index_ c = start, end = start + len; c < end; ++c) {
361 auto ptr = ext->fetch(buffer.data());
// Pair each extracted value with its universe position. The counter uses
// Index_, matching the element type of tmp_ranked and the other universe loops.
364 for (
Index_ i = 0, end = output.universe.size(); i < end; ++i, ++ptr) {
365 tmp_ranked.emplace_back(*ptr, i);
367 std::sort(tmp_ranked.begin(), tmp_ranked.end());
369 auto& final_ranked = cur_ranked[curlab[c]][positions[c]];
370 simplify_ranks(tmp_ranked, final_ranked);
// Flag this reference in check_availability (presumably the start of the
// with-intersection branch; the branch boundary is not visible here).
375 output.check_availability[ref_i] = 1;
// Prefer the caller-supplied intersection when one is set.
379 const auto& intersection = (curinput.user_intersection == NULL ? curinput.auto_intersection : *(curinput.user_intersection));
// Lookup table from the first member of each intersection pair to the second.
380 std::unordered_map<Index_, Index_> intersection_map;
381 intersection_map.reserve(intersection.size());
382 for (
const auto& in : intersection) {
383 intersection_map[in.first] = in.second;
// For every universe entry found in the intersection, keep the pair
// (mapped index, universe position) and mark the position as available.
386 std::vector<std::pair<Index_, Index_> > intersection_in_universe;
387 intersection_in_universe.reserve(output.universe.size());
388 auto& cur_available = output.available[ref_i];
389 cur_available.reserve(output.universe.size());
391 for (Index_ i = 0, end = output.universe.size(); i < end; ++i) {
392 auto it = intersection_map.find(output.universe[i]);
393 if (it != intersection_map.end()) {
394 intersection_in_universe.emplace_back(it->second, i);
395 cur_available.insert(i);
// Sorting by the mapped index puts the rows to extract in ascending order.
398 std::sort(intersection_in_universe.begin(), intersection_in_universe.end());
400 std::vector<Index_> to_extract;
401 to_extract.reserve(intersection_in_universe.size());
402 for (
const auto& p : intersection_in_universe) {
403 to_extract.push_back(p.first);
405 tatami::VectorPtr<Index_> to_extract_ptr(tatami::VectorPtr<Index_>{}, &to_extract);
407 tatami::parallelize([&](
int, Index_ start, Index_ len) {
408 std::vector<Value_> buffer(to_extract.size());
409 internal::RankedVector<Value_, Index_> tmp_ranked;
410 tmp_ranked.reserve(to_extract.size());
411 auto ext = tatami::consecutive_extractor<false>(&ref,
false, start, len, to_extract_ptr);
413 for (Index_ c = start, end = start + len; c < end; ++c) {
414 auto ptr = ext->fetch(buffer.data());
// Pair each extracted value with the universe position it corresponds to
// (the pointer advance is presumably on a hidden line - confirm).
417 for (
const auto& p : intersection_in_universe) {
418 tmp_ranked.emplace_back(*ptr, p.second);
421 std::sort(tmp_ranked.begin(), tmp_ranked.end());
423 auto& final_ranked = cur_ranked[curlab[c]][positions[c]];
424 simplify_ranks(tmp_ranked, final_ranked);
// Internal driver that assembles a TrainedIntegrated object from a collection
// of per-reference inputs: it validates the test row count, builds the shared
// universe of marker indices, and then processes each reference in turn.
// Closing braces and the return statement are not visible in this listing.
430template<
typename Value_,
typename Index_,
typename Inputs_>
431TrainedIntegrated<Index_> train_integrated(Inputs_& inputs,
const TrainIntegratedOptions& options) {
432 TrainedIntegrated<Index_> output;
// One slot per reference in each per-reference container.
433 auto nrefs = inputs.size();
434 output.check_availability.resize(nrefs);
435 output.available.resize(nrefs);
436 output.markers.resize(nrefs);
437 output.ranked.resize(nrefs);
// -1 (converted to Index_) acts as the "not yet known" marker for test_nrow.
// The first input sets the value; later inputs must either also carry the
// marker or agree with it, otherwise an exception is thrown.
440 output.test_nrow = -1;
441 for (
const auto& in : inputs) {
442 if (output.test_nrow ==
static_cast<Index_
>(-1)) {
443 output.test_nrow = in.test_nrow;
444 }
else if (in.test_nrow !=
static_cast<Index_
>(-1) && in.test_nrow != output.test_nrow) {
445 throw std::runtime_error(
"inconsistent number of rows in the test dataset across entries of 'inputs'");
// Collect the union of all marker indices across every input and label.
450 std::unordered_map<Index_, Index_> remap_to_universe;
451 std::unordered_set<Index_> subset_tmp;
452 for (
const auto& in : inputs) {
453 for (
const auto& mrk : in.markers) {
454 subset_tmp.insert(mrk.begin(), mrk.end());
// The universe is the sorted list of unique markers; remap_to_universe is the
// inverse lookup from a marker index to its position in that list.
458 output.universe.insert(output.universe.end(), subset_tmp.begin(), subset_tmp.end());
459 std::sort(output.universe.begin(), output.universe.end());
460 remap_to_universe.reserve(output.universe.size());
461 for (Index_ i = 0, end = output.universe.size(); i < end; ++i) {
462 remap_to_universe[output.universe[i]] = i;
// Fill in the per-reference parts of the output, one reference at a time.
465 for (
decltype(nrefs) r = 0; r < nrefs; ++r) {
466 train_integrated_per_reference<Value_>(r, inputs[r], output, remap_to_universe, options);
// Public train_integrated() entry point: a thin wrapper that forwards `inputs`
// and `options` to the internal implementation. The signature line is not
// visible in this listing.
487template<
typename Value_,
typename Index_,
typename Label_>
489 return internal::train_integrated<Value_, Index_>(inputs, options);
// Second public train_integrated() overload, likewise forwarding to the
// internal implementation. Its signature line is not visible in this listing,
// so how it differs from the other overload cannot be told from here.
502template<
typename Value_,
typename Index_,
typename Label_>
504 return internal::train_integrated<Value_, Index_>(inputs, options);
Create an intersection of genes.
Classifier that integrates multiple reference datasets.
Definition train_integrated.hpp:240
std::size_t num_references() const
Definition train_integrated.hpp:245
std::size_t num_profiles(std::size_t r) const
Definition train_integrated.hpp:261
std::size_t num_labels(std::size_t r) const
Definition train_integrated.hpp:253
Classifier built from an intersection of genes.
Definition train_single.hpp:218
const Markers< Index_ > & get_markers() const
Definition train_single.hpp:260
const std::vector< Index_ > & get_test_subset() const
Definition train_single.hpp:269
Classifier trained from a single reference.
Definition train_single.hpp:89
const std::vector< Index_ > & get_subset() const
Definition train_single.hpp:136
const Markers< Index_ > & get_markers() const
Definition train_single.hpp:128
Common definitions for singlepp.
Cell type classification using the SingleR algorithm in C++.
Definition classify_single.hpp:20
Intersection< Index_ > intersect_genes(Index_ test_nrow, const Id_ *test_id, Index_ ref_nrow, const Id_ *ref_id)
Definition Intersection.hpp:54
TrainIntegratedInput< Value_, Index_, Label_ > prepare_integrated_input(const tatami::Matrix< Value_, Index_ > &ref, const Label_ *labels, const TrainedSingle< Index_, Float_ > &trained)
Definition train_integrated.hpp:75
TrainIntegratedInput< Value_, Index_, Label_ > prepare_integrated_input_intersect(Index_ test_nrow, const Intersection< Index_ > &intersection, const tatami::Matrix< Value_, Index_ > &ref, const Label_ *labels, const TrainedSingleIntersect< Index_, Float_ > &trained)
Definition train_integrated.hpp:135
std::vector< std::pair< Index_, Index_ > > Intersection
Definition Intersection.hpp:35
TrainedIntegrated< Index_ > train_integrated(const std::vector< TrainIntegratedInput< Value_, Index_, Label_ > > &inputs, const TrainIntegratedOptions &options)
Definition train_integrated.hpp:488
Options for train_integrated().
Definition train_integrated.hpp:290
int num_threads
Definition train_integrated.hpp:295
Train a classifier from a single reference.