MultiIndex
|
00001 00002 // Copyright 2012 Yandex Artem Babenko 00003 #ifndef SEARCHER_H_ 00004 #define SEARCHER_H_ 00005 00006 #include <algorithm> 00007 #include <map> 00008 00009 #include <boost/archive/binary_iarchive.hpp> 00010 #include <boost/archive/binary_oarchive.hpp> 00011 00012 #include <boost/serialization/serialization.hpp> 00013 #include <boost/serialization/set.hpp> 00014 #include <boost/serialization/vector.hpp> 00015 00016 #include <mkl_cblas.h> 00017 00018 #include "data_util.h" 00019 #include "ordered_lists_merger.h" 00020 #include "perfomance_util.h" 00021 00022 extern int THREADS_COUNT; 00023 00024 extern Dimensions SPACE_DIMENSION; 00025 00026 extern enum PointType point_type; 00027 00032 typedef vector<pair<Distance, ClusterId> > NearestSubspaceCentroids; 00033 00037 template<class Record, class MetaInfo> 00038 class MultiSearcher { 00039 public: 00043 MultiSearcher(); 00052 void Init(const string& index_files_prefix, 00053 const string& coarse_vocabs_filename, 00054 const string& fine_vocabs_filename, 00055 const RerankMode& mode, 00056 const int subspace_centroids_to_consider, 00057 bool do_rerank); 00065 void GetNearestNeighbours(const Point& point, int k, 00066 vector<pair<Distance, MetaInfo> >* neighbours) const; 00070 PerfTester& GetPerfTester(); 00071 private: 00078 void DeserializeData(const string& index_files_prefix, 00079 const string& coarse_vocabs_filename, 00080 const string& fine_vocabs_filename); 00087 void GetNearestSubspacesCentroids(const Point& point, 00088 const int subspace_centroins_count, 00089 vector<NearestSubspaceCentroids>* subspaces_short_lists) const; 00090 00096 bool TraverseNextMultiIndexCell(const Point& point, 00097 vector<pair<Distance, MetaInfo> >* nearest_subpoints) const; 00104 inline void GetCellEdgesInMultiIndexArray(const vector<int>& cell_coordinates, 00105 int* cell_start, int* cell_finish) const; 00110 void InitBlasStructures(); 00114 vector<Centroids> coarse_vocabs_; 00118 vector<Centroids> fine_vocabs_; 00122 mutable OrderedListsMerger<Distance, ClusterId> merger_; 00126 bool do_rerank_; 00130 mutable PerfTester perf_tester_; 00134 string index_files_prefix_; 00138 MultiIndex<Record> multiindex_; 00142 RerankMode rerank_mode_; 00146 vector<float*> coarse_vocabs_matrices_; 00150 vector<vector<float> > coarse_centroids_norms_; 00154 mutable Coord* products_; 00158 mutable vector<Coord> query_norms_; 00162 mutable float* residual_; 00167 int subspace_centroids_to_consider_; 00171 mutable int found_neghbours_count_; 00172 }; 00173 00174 template<class Record, class MetaInfo> 00175 inline void RecordToMetainfoAndDistance(const Coord* point, 00176 const Record& record, 00177 pair<Distance, MetaInfo>* result, 00178 const vector<int>& cell_coordinates, 00179 const vector<Centroids>& fine_vocabs) { 00180 } 00181 00183 00184 template<class Record, class MetaInfo> 00185 MultiSearcher<Record, MetaInfo>::MultiSearcher() { 00186 } 00187 00188 template<class Record, class MetaInfo> 00189 void MultiSearcher<Record, MetaInfo>::DeserializeData(const string& index_files_prefix, 00190 const string& coarse_vocabs_filename, 00191 const string& fine_vocabs_filename) { 00192 cout << "Data deserializing started...\n"; 00193 ifstream cell_edges(string(index_files_prefix + "_cell_edges.bin").c_str(), ios::binary); 00194 if(!cell_edges.good()) { 00195 throw std::logic_error("Bad input cell edges stream"); 00196 } 00197 boost::archive::binary_iarchive arc_cell_edges(cell_edges); 00198 arc_cell_edges >> multiindex_.cell_edges; 00199 cout << "Cell edges deserialized...\n"; 00200 ifstream multi_array(string(index_files_prefix + "_multi_array.bin").c_str(), ios::binary); 00201 if(!multi_array.good()) { 00202 throw std::logic_error("Bad input cell edges stream"); 00203 } 00204 boost::archive::binary_iarchive arc_multi_array(multi_array); 00205 arc_multi_array >> multiindex_.multiindex; 00206 cout << "Multiindex deserialized...\n"; 00207 ReadVocabularies<float>(coarse_vocabs_filename, SPACE_DIMENSION, &coarse_vocabs_); 00208 cout << "Coarse vocabs deserialized...\n"; 00209 ReadFineVocabs<float>(fine_vocabs_filename, &fine_vocabs_); 00210 cout << "Fine vocabs deserialized...\n"; 00211 } 00212 00213 template<class Record, class MetaInfo> 00214 void MultiSearcher<Record, MetaInfo>::Init(const string& index_files_prefix, 00215 const string& coarse_vocabs_filename, 00216 const string& fine_vocabs_filename, 00217 const RerankMode& mode, 00218 const int subspace_centroids_to_consider, 00219 const bool do_rerank) { 00220 do_rerank_ = do_rerank; 00221 index_files_prefix_ = index_files_prefix; 00222 subspace_centroids_to_consider_ = subspace_centroids_to_consider; 00223 DeserializeData(index_files_prefix, coarse_vocabs_filename, fine_vocabs_filename); 00224 rerank_mode_ = mode; 00225 merger_.GetYieldedItems().table.resize(std::pow((float)subspace_centroids_to_consider, 00226 (int)coarse_vocabs_.size())); 00227 for(int i = 0; i < coarse_vocabs_.size(); ++i) { 00228 merger_.GetYieldedItems().dimensions.push_back(subspace_centroids_to_consider); 00229 } 00230 InitBlasStructures(); 00231 } 00232 00233 template<class Record, class MetaInfo> 00234 void MultiSearcher<Record, MetaInfo>::InitBlasStructures(){ 00235 coarse_vocabs_matrices_.resize(coarse_vocabs_.size()); 00236 coarse_centroids_norms_.resize(coarse_vocabs_.size(), vector<float>(coarse_vocabs_[0].size())); 00237 for(int coarse_id = 0; coarse_id < coarse_vocabs_matrices_.size(); ++coarse_id) { 00238 coarse_vocabs_matrices_[coarse_id] = new float[coarse_vocabs_[0].size() * coarse_vocabs_[0][0].size()]; 00239 for(int i = 0; i < coarse_vocabs_[0].size(); ++i) { 00240 Coord norm = 0; 00241 for(int j = 0; j < coarse_vocabs_[0][0].size(); ++j) { 00242 coarse_vocabs_matrices_[coarse_id][coarse_vocabs_[0][0].size() * i + j] = coarse_vocabs_[coarse_id][i][j]; 00243 norm += coarse_vocabs_[coarse_id][i][j] * coarse_vocabs_[coarse_id][i][j]; 00244 } 00245 coarse_centroids_norms_[coarse_id][i] = norm; 00246 } 00247 } 00248 products_ = new Coord[coarse_vocabs_[0].size()]; 00249 query_norms_.resize(coarse_vocabs_[0].size()); 00250 residual_ = new Coord[coarse_vocabs_[0][0].size() * coarse_vocabs_.size()]; 00251 } 00252 00253 template<class Record, class MetaInfo> 00254 PerfTester& MultiSearcher<Record, MetaInfo>::GetPerfTester() { 00255 return perf_tester_; 00256 } 00257 00258 template<class Record, class MetaInfo> 00259 void MultiSearcher<Record, MetaInfo>::GetNearestSubspacesCentroids(const Point& point, 00260 const int subspace_centroins_count, 00261 vector<NearestSubspaceCentroids>* 00262 subspaces_short_lists) const { 00263 subspaces_short_lists->resize(coarse_vocabs_.size()); 00264 Dimensions subspace_dimension = point.size() / coarse_vocabs_.size(); 00265 for(int subspace_index = 0; subspace_index < coarse_vocabs_.size(); ++subspace_index) { 00266 Dimensions start_dim = subspace_index * subspace_dimension; 00267 Dimensions final_dim = std::min((Dimensions)point.size(), start_dim + subspace_dimension); 00268 Coord query_norm = cblas_sdot(final_dim - start_dim, &(point[start_dim]), 1, &(point[start_dim]), 1); 00269 std::fill(query_norms_.begin(), query_norms_.end(), query_norm); 00270 cblas_saxpy(coarse_vocabs_[0].size(), 1, &(coarse_centroids_norms_[subspace_index][0]), 1, &(query_norms_[0]), 1); 00271 cblas_sgemv(CblasRowMajor, CblasNoTrans, coarse_vocabs_[0].size(), subspace_dimension, -2.0, 00272 coarse_vocabs_matrices_[subspace_index], subspace_dimension, &(point[start_dim]), 1, 1, &(query_norms_[0]), 1); 00273 subspaces_short_lists->at(subspace_index).resize(query_norms_.size()); 00274 for(int i = 0; i < query_norms_.size(); ++i) { 00275 subspaces_short_lists->at(subspace_index)[i] = std::make_pair(query_norms_[i], i); 00276 } 00277 std::nth_element(subspaces_short_lists->at(subspace_index).begin(), 00278 subspaces_short_lists->at(subspace_index).begin() + subspace_centroins_count, 00279 subspaces_short_lists->at(subspace_index).end()); 00280 subspaces_short_lists->at(subspace_index).resize(subspace_centroins_count); 00281 std::sort(subspaces_short_lists->at(subspace_index).begin(), 00282 subspaces_short_lists->at(subspace_index).end()); 00283 } 00284 } 00285 00286 template<class Record, class MetaInfo> 00287 void MultiSearcher<Record, MetaInfo>::GetCellEdgesInMultiIndexArray(const vector<int>& cell_coordinates, 00288 int* cell_start, int* cell_finish) const { 00289 int global_index = multiindex_.cell_edges.GetCellGlobalIndex(cell_coordinates); 00290 *cell_start = multiindex_.cell_edges.table[global_index]; 00291 if(global_index + 1 == multiindex_.cell_edges.table.size()) { 00292 *cell_finish = multiindex_.multiindex.size(); 00293 } else { 00294 *cell_finish = multiindex_.cell_edges.table[global_index + 1]; 00295 } 00296 } 00297 00298 template<class Record, class MetaInfo> 00299 bool MultiSearcher<Record, MetaInfo>::TraverseNextMultiIndexCell(const Point& point, 00300 vector<pair<Distance, MetaInfo> >* 00301 nearest_subpoints) const { 00302 MergedItemIndices cell_inner_indices; 00303 clock_t before = clock(); 00304 if(!merger_.GetNextMergedItemIndices(&cell_inner_indices)) { 00305 return false; 00306 } 00307 clock_t after = clock(); 00308 perf_tester_.cell_coordinates_time += after - before; 00309 vector<int> cell_coordinates(cell_inner_indices.size()); 00310 for(int list_index = 0; list_index < merger_.lists_ptr->size(); ++list_index) { 00311 cell_coordinates[list_index] = merger_.lists_ptr->at(list_index)[cell_inner_indices[list_index]].second; 00312 } 00313 int cell_start, cell_finish; 00314 before = clock(); 00315 GetCellEdgesInMultiIndexArray(cell_coordinates, &cell_start, &cell_finish); 00316 after = clock(); 00317 perf_tester_.cell_edges_time += after - before; 00318 if(cell_start >= cell_finish) { 00319 return true; 00320 } 00321 vector<Record>::const_iterator it = multiindex_.multiindex.begin() + cell_start; 00322 GetResidual(point, cell_coordinates, coarse_vocabs_, residual_); 00323 cell_finish = std::min((int)cell_finish, cell_start + (int)nearest_subpoints->size() - found_neghbours_count_); 00324 for(int array_index = cell_start; array_index < cell_finish; ++array_index) { 00325 if(rerank_mode_ == USE_RESIDUALS) { 00326 RecordToMetainfoAndDistance<Record, MetaInfo>(residual_, *it, 00327 &(nearest_subpoints->at(found_neghbours_count_)), 00328 cell_coordinates, fine_vocabs_); 00329 } else if(rerank_mode_ == USE_INIT_POINTS) { 00330 RecordToMetainfoAndDistance<Record, MetaInfo>(&(point[0]), *it, 00331 &(nearest_subpoints->at(found_neghbours_count_)), 00332 cell_coordinates, fine_vocabs_); 00333 } 00334 perf_tester_.NextNeighbour(); 00335 ++found_neghbours_count_; 00336 ++it; 00337 } 00338 return true; 00339 } 00340 00341 00342 template<class Record, class MetaInfo> 00343 void MultiSearcher<Record, MetaInfo>::GetNearestNeighbours(const Point& point, int k, 00344 vector<pair<Distance, MetaInfo> >* neighbours) const { 00345 perf_tester_.handled_queries_count += 1; 00346 neighbours->resize(k); 00347 perf_tester_.ResetQuerywiseStatistic(); 00348 clock_t start = clock(); 00349 perf_tester_.search_start = start; 00350 clock_t before = clock(); 00351 vector<NearestSubspaceCentroids> subspaces_short_lists; 00352 GetNearestSubspacesCentroids(point, subspace_centroids_to_consider_, &subspaces_short_lists); 00353 clock_t after = clock(); 00354 perf_tester_.nearest_subcentroids_time += after - before; 00355 clock_t before_merger = clock(); 00356 merger_.setLists(subspaces_short_lists); 00357 clock_t after_merger = clock(); 00358 perf_tester_.merger_init_time += after_merger - before_merger; 00359 clock_t before_traversal = clock(); 00360 found_neghbours_count_ = 0; 00361 bool traverse_next_cell = true; 00362 int cells_visited = 0; 00363 while(found_neghbours_count_ < k && traverse_next_cell) { 00364 perf_tester_.cells_traversed += 1; 00365 traverse_next_cell = TraverseNextMultiIndexCell(point, neighbours); 00366 cells_visited += 1; 00367 } 00368 clock_t after_traversal = clock(); 00369 perf_tester_.full_traversal_time += after_traversal - before_traversal; 00370 if(do_rerank_) { 00371 if(neighbours->size() > 10000) { 00372 std::nth_element(neighbours->begin(), neighbours->begin() + 10000, neighbours->end()); 00373 neighbours->resize(10000); 00374 } 00375 std::sort(neighbours->begin(), neighbours->end()); 00376 } 00377 clock_t finish = clock(); 00378 perf_tester_.full_search_time += finish - start; 00379 } 00380 00381 template<> 00382 inline void RecordToMetainfoAndDistance<RerankADC8, PointId>(const Coord* point, const RerankADC8& record, 00383 pair<Distance, PointId>* result, 00384 const vector<int>& cell_coordinates, 00385 const vector<Centroids>& fine_vocabs) { 00386 result->second = record.pid; 00387 int coarse_clusters_count = cell_coordinates.size(); 00388 int fine_clusters_count = fine_vocabs.size(); 00389 int coarse_to_fine_ratio = fine_clusters_count / coarse_clusters_count; 00390 int subvectors_dim = SPACE_DIMENSION / fine_clusters_count; 00391 char* rerank_info_ptr = (char*)&record + sizeof(record.pid); 00392 for(int centroid_index = 0; centroid_index < fine_clusters_count; ++centroid_index) { 00393 int start_dim = centroid_index * subvectors_dim; 00394 int final_dim = start_dim + subvectors_dim; 00395 FineClusterId pid_nearest_centroid = *((FineClusterId*)rerank_info_ptr); 00396 rerank_info_ptr += sizeof(FineClusterId); 00397 int current_coarse_index = centroid_index / coarse_to_fine_ratio; 00398 Distance subvector_distance = 0; 00399 for(int i = start_dim; i < final_dim; ++i) { 00400 Coord diff = fine_vocabs[centroid_index][pid_nearest_centroid][i - start_dim] - point[i]; 00401 subvector_distance += diff * diff; 00402 } 00403 result->first += subvector_distance; 00404 } 00405 } 00406 00407 template<> 00408 inline void RecordToMetainfoAndDistance<RerankADC16, PointId>(const Coord* point, const RerankADC16& record, 00409 pair<Distance, PointId>* result, 00410 const vector<int>& cell_coordinates, 00411 const vector<Centroids>& fine_vocabs) { 00412 result->second = record.pid; 00413 int coarse_clusters_count = cell_coordinates.size(); 00414 int fine_clusters_count = fine_vocabs.size(); 00415 int coarse_to_fine_ratio = fine_clusters_count / coarse_clusters_count; 00416 int subvectors_dim = SPACE_DIMENSION / fine_clusters_count; 00417 char* rerank_info_ptr = (char*)&record + sizeof(record.pid); 00418 for(int centroid_index = 0; centroid_index < fine_clusters_count; ++centroid_index) { 00419 int start_dim = centroid_index * subvectors_dim; 00420 int final_dim = start_dim + subvectors_dim; 00421 FineClusterId pid_nearest_centroid = *((FineClusterId*)rerank_info_ptr); 00422 rerank_info_ptr += sizeof(FineClusterId); 00423 int current_coarse_index = centroid_index / coarse_to_fine_ratio; 00424 Distance subvector_distance = 0; 00425 for(int i = start_dim; i < final_dim; ++i) { 00426 Coord diff = fine_vocabs[centroid_index][pid_nearest_centroid][i - start_dim] - point[i]; 00427 subvector_distance += diff * diff; 00428 } 00429 result->first += subvector_distance; 00430 } 00431 } 00432 00433 template class MultiSearcher<RerankADC8, PointId>; 00434 template class MultiSearcher<RerankADC16, PointId>; 00435 template class MultiSearcher<PointId, PointId>; 00436 00437 #endif 00438