21 #ifndef COLLAGE_RANKER_H_
22 #define COLLAGE_RANKER_H_
24 #include <torch/script.h>
25 #include <torch/torch.h>
26 #include <cereal/types/complex.hpp>
27 #include <cereal/types/string.hpp>
28 #include <cereal/types/vector.hpp>
47 template <
typename DType>
48 std::vector<std::vector<DType>>
to_std_matrix(
const at::Tensor& tensor_features) {
49 if (tensor_features.sizes().size() != 2) {
50 throw std::runtime_error(
"Not 2x<dim> matrix.");
53 size_t num_rows{ size_t(tensor_features.sizes()[0]) };
54 size_t num_cols{ size_t(tensor_features.sizes()[1]) };
56 std::vector<std::vector<DType>> mat;
57 mat.reserve(num_rows);
60 float* data_ptr =
static_cast<float*
>(tensor_features.data_ptr());
61 for (std::size_t ir = 0; ir < num_rows; ++ir) {
62 std::vector<DType> row;
63 row.assign(data_ptr, data_ptr + num_cols);
66 mat.emplace_back(std::move(row));
72 template <c10::ScalarType TensorDType_ = at::kFloat,
typename OrigDType_ =
float>
73 at::Tensor
to_tensor(std::vector<OrigDType_>& orig_vec) {
77 return torch::tensor(orig_vec, TensorDType_);
80 template <c10::ScalarType TensorDType_ = at::kFloat,
typename OrigDType_ =
float>
81 at::Tensor
to_tensor(std::vector<std::vector<OrigDType_>>& orig_mat) {
82 do_assert_cond(orig_mat.size() > 0,
"Matrix cannot be empty.");
85 std::vector<at::Tensor> meta;
86 meta.reserve(orig_mat.size());
88 for (
auto&& vec : orig_mat) {
89 meta.emplace_back(torch::tensor(vec.data(), TensorDType_));
92 return torch::cat(meta, 0);
/// Table of predefined regions of interest; each entry holds four floats.
/// The values (e.g. `{ 0.0, 0.0, 1.0, 1.0 }` first, then entries stepping the
/// first coordinate by 0.2 with a fixed `0.4, 0.6` tail) suggest normalized
/// `{ x, y, width, height }` boxes. NOTE(review): confirm the layout, and
/// confirm that the indices returned by get_RoI()/get_RoIs() index this table.
106 const std::vector<std::vector<float>>
RoIs = {
107 { 0.0, 0.0, 1.0, 1.0 }, { 0.1, 0.2, 0.4, 0.6 }, { 0.3, 0.2, 0.4, 0.6 }, { 0.5, 0.2, 0.4, 0.6 },
109 { 0.0, 0.0, 0.4, 0.6 }, { 0.2, 0.0, 0.4, 0.6 }, { 0.4, 0.0, 0.4, 0.6 }, { 0.6, 0.0, 0.4, 0.6 },
111 { 0.0, 0.4, 0.4, 0.6 }, { 0.2, 0.4, 0.4, 0.6 }, { 0.4, 0.4, 0.4, 0.6 }, { 0.6, 0.4, 0.4, 0.6 },
/// Returns a tensor derived from `_data`; does not modify the ranker (const).
/// By its name, presumably the L2 norm or an L2-normalized copy of `_data`.
/// NOTE(review): confirm which, and along which dimension, against the
/// definition in canvas-query-ranker.cpp.
127 at::Tensor
get_L2norm(
const at::Tensor& _data)
const;
/// Scores one query `feature` vector for RoI index `region`; const, so it
/// does not modify the ranker.
/// NOTE(review): presumably `region` indexes `region_data` and the result
/// holds one score per dataset frame. Confirm against canvas-query-ranker.cpp.
132 std::vector<float>
score_image(
const std::vector<float>& feature, std::size_t region)
const;
/// Combines several score vectors into a single score vector; const, so it
/// does not modify the ranker.
/// NOTE(review): by its name, presumably an element-wise mean. Confirm
/// whether the inner vectors must all have the same length.
133 std::vector<float>
average_scores(
const std::vector<std::vector<float>>& scores)
const;
Type representing a query over the canvas rectangles (currently text and bitmap subqueries).
Definition: query-types.h:325
Definition: canvas-query-ranker.h:95
torch::jit::script::Module resnet152
Definition: canvas-query-ranker.h:99
torch::Tensor kw_pca_mat
Definition: canvas-query-ranker.h:103
KeywordRanker * _p_core
Definition: canvas-query-ranker.h:96
void score(const CanvasQuery &, ScoreModel &model, size_t temporal, UsedTools &used_tools, const PrimaryFrameFeatures &_dataset_features, const DatasetFrames &_dataset_frames)
Definition: canvas-query-ranker.cpp:151
static const size_t models_input_height
Definition: canvas-query-ranker.h:116
at::Tensor get_features(const CanvasQuery &, UsedTools &used_tools)
Definition: canvas-query-ranker.cpp:195
std::vector< std::size_t > get_RoIs(const CanvasQuery &collage) const
Definition: canvas-query-ranker.cpp:314
torch::Tensor bias
Definition: canvas-query-ranker.h:101
torch::Tensor kw_pca_mean_vec
Definition: canvas-query-ranker.h:104
std::size_t get_RoI(const CanvasSubquery &image) const
Definition: canvas-query-ranker.cpp:320
torch::jit::script::Module resnext101
Definition: canvas-query-ranker.h:100
CanvasQueryRanker(const Settings &_settings, KeywordRanker *p_core)
Definition: canvas-query-ranker.cpp:37
static const size_t models_input_width
Definition: canvas-query-ranker.h:115
std::vector< FeatureMatrix > region_data
Definition: canvas-query-ranker.h:124
static const size_t models_num_channels
Definition: canvas-query-ranker.h:117
std::vector< float > average_scores(const std::vector< std::vector< float >> &scores) const
Definition: canvas-query-ranker.cpp:352
std::vector< float > score_image(const std::vector< float > &feature, std::size_t region) const
Definition: canvas-query-ranker.cpp:345
bool _loaded
Definition: canvas-query-ranker.h:97
torch::Tensor weights
Definition: canvas-query-ranker.h:102
at::Tensor get_L2norm(const at::Tensor &_data) const
Definition: canvas-query-ranker.cpp:186
const std::vector< std::vector< float > > RoIs
Definition: canvas-query-ranker.h:106
Definition: dataset-frames.h:162
Definition: keyword-ranker.h:52
Definition: common-types.h:33
std::vector< std::vector< DType > > to_std_matrix(const at::Tensor &tensor_features)
Definition: canvas-query-ranker.h:48
at::Tensor to_tensor(std::vector< OrigDType_ > &orig_vec)
Definition: canvas-query-ranker.h:73
std::variant< CanvasSubqueryBitmap, CanvasSubqueryText > CanvasSubquery
Definition: query-types.h:320
#define do_assert_debug(assertion, msg)
Assert executed only if RUN_ASSERTS is true.
Definition: static-logger.hpp:225
The parsed current configuration of the core.
Definition: settings.h:190