MultiIndex
searcher.h
Go to the documentation of this file.
00001 
00002 // Copyright 2012 Yandex Artem Babenko
00003 #ifndef SEARCHER_H_
00004 #define SEARCHER_H_
00005 
00006 #include <algorithm>
00007 #include <map>
00008 
00009 #include <boost/archive/binary_iarchive.hpp>
00010 #include <boost/archive/binary_oarchive.hpp>
00011 
00012 #include <boost/serialization/serialization.hpp>
00013 #include <boost/serialization/set.hpp>
00014 #include <boost/serialization/vector.hpp>
00015 
00016 #include <mkl_cblas.h>
00017 
00018 #include "data_util.h"
00019 #include "ordered_lists_merger.h"
00020 #include "perfomance_util.h"
00021 
00022 extern int THREADS_COUNT;
00023 
00024 extern Dimensions SPACE_DIMENSION;
00025 
00026 extern enum PointType point_type;
00027 
00032 typedef vector<pair<Distance, ClusterId> > NearestSubspaceCentroids;
00033 
00037 template<class Record, class MetaInfo>
00038 class MultiSearcher {
00039  public:
00043   MultiSearcher();
00052   void Init(const string& index_files_prefix,
00053             const string& coarse_vocabs_filename,
00054             const string& fine_vocabs_filename,
00055             const RerankMode& mode,
00056             const int subspace_centroids_to_consider,
00057             bool do_rerank);
00065   void GetNearestNeighbours(const Point& point, int k, 
00066                             vector<pair<Distance, MetaInfo> >* neighbours) const;
00070   PerfTester& GetPerfTester();
00071  private:
00078   void DeserializeData(const string& index_files_prefix,
00079                        const string& coarse_vocabs_filename,
00080                        const string& fine_vocabs_filename);
00087   void GetNearestSubspacesCentroids(const Point& point,
00088                                     const int subspace_centroins_count,
00089                                     vector<NearestSubspaceCentroids>* subspaces_short_lists) const;
00090 
00096   bool TraverseNextMultiIndexCell(const Point& point,
00097                                   vector<pair<Distance, MetaInfo> >* nearest_subpoints) const;
00104 inline void GetCellEdgesInMultiIndexArray(const vector<int>& cell_coordinates,
00105                                           int* cell_start, int* cell_finish) const;
00110   void InitBlasStructures();
00114   vector<Centroids> coarse_vocabs_;
00118   vector<Centroids> fine_vocabs_;
00122   mutable OrderedListsMerger<Distance, ClusterId> merger_;
00126   bool do_rerank_;
00130   mutable PerfTester perf_tester_;
00134   string index_files_prefix_;
00138   MultiIndex<Record> multiindex_;
00142   RerankMode rerank_mode_;
00146   vector<float*> coarse_vocabs_matrices_;
00150   vector<vector<float> > coarse_centroids_norms_;
00154   mutable Coord* products_;
00158   mutable vector<Coord> query_norms_;
00162   mutable float* residual_;
00167   int subspace_centroids_to_consider_;
00171   mutable int found_neghbours_count_;
00172 };
00173 
00174 template<class Record, class MetaInfo>
00175 inline void RecordToMetainfoAndDistance(const Coord* point,
00176                                         const Record& record,
00177                                         pair<Distance, MetaInfo>* result,
00178                                         const vector<int>& cell_coordinates,
00179                                         const vector<Centroids>& fine_vocabs) {
00180 }
00181 
00183 
00184 template<class Record, class MetaInfo>
00185 MultiSearcher<Record, MetaInfo>::MultiSearcher() {
00186 }
00187 
00188 template<class Record, class MetaInfo>
00189 void MultiSearcher<Record, MetaInfo>::DeserializeData(const string& index_files_prefix,
00190                                                       const string& coarse_vocabs_filename,
00191                                                       const string& fine_vocabs_filename) {
00192   cout << "Data deserializing started...\n";
00193   ifstream cell_edges(string(index_files_prefix + "_cell_edges.bin").c_str(), ios::binary);
00194   if(!cell_edges.good()) {
00195     throw std::logic_error("Bad input cell edges stream");
00196   }
00197   boost::archive::binary_iarchive arc_cell_edges(cell_edges);
00198   arc_cell_edges >> multiindex_.cell_edges;
00199   cout << "Cell edges deserialized...\n";
00200   ifstream multi_array(string(index_files_prefix + "_multi_array.bin").c_str(), ios::binary);
00201   if(!multi_array.good()) {
00202     throw std::logic_error("Bad input cell edges stream");
00203   }
00204   boost::archive::binary_iarchive arc_multi_array(multi_array);
00205   arc_multi_array >> multiindex_.multiindex;
00206   cout << "Multiindex deserialized...\n";
00207   ReadVocabularies<float>(coarse_vocabs_filename, SPACE_DIMENSION, &coarse_vocabs_);
00208   cout << "Coarse vocabs deserialized...\n";
00209   ReadFineVocabs<float>(fine_vocabs_filename, &fine_vocabs_);
00210   cout << "Fine vocabs deserialized...\n";
00211 }
00212 
00213 template<class Record, class MetaInfo>
00214 void MultiSearcher<Record, MetaInfo>::Init(const string& index_files_prefix,
00215                                            const string& coarse_vocabs_filename,
00216                                            const string& fine_vocabs_filename,
00217                                            const RerankMode& mode,
00218                                            const int subspace_centroids_to_consider,
00219                                            const bool do_rerank) {
00220   do_rerank_ = do_rerank;
00221   index_files_prefix_ = index_files_prefix;
00222   subspace_centroids_to_consider_ = subspace_centroids_to_consider;
00223   DeserializeData(index_files_prefix, coarse_vocabs_filename, fine_vocabs_filename);
00224   rerank_mode_ = mode;
00225   merger_.GetYieldedItems().table.resize(std::pow((float)subspace_centroids_to_consider,
00226                                                          (int)coarse_vocabs_.size()));
00227   for(int i = 0; i < coarse_vocabs_.size(); ++i) {
00228     merger_.GetYieldedItems().dimensions.push_back(subspace_centroids_to_consider);
00229   }
00230   InitBlasStructures();
00231 }
00232 
00233 template<class Record, class MetaInfo>
00234 void MultiSearcher<Record, MetaInfo>::InitBlasStructures(){
00235   coarse_vocabs_matrices_.resize(coarse_vocabs_.size());
00236   coarse_centroids_norms_.resize(coarse_vocabs_.size(), vector<float>(coarse_vocabs_[0].size()));
00237   for(int coarse_id = 0; coarse_id < coarse_vocabs_matrices_.size(); ++coarse_id) {
00238     coarse_vocabs_matrices_[coarse_id] = new float[coarse_vocabs_[0].size() * coarse_vocabs_[0][0].size()];
00239     for(int i = 0; i < coarse_vocabs_[0].size(); ++i) {
00240       Coord norm = 0;
00241       for(int j = 0; j < coarse_vocabs_[0][0].size(); ++j) {
00242         coarse_vocabs_matrices_[coarse_id][coarse_vocabs_[0][0].size() * i + j] = coarse_vocabs_[coarse_id][i][j];
00243         norm += coarse_vocabs_[coarse_id][i][j] * coarse_vocabs_[coarse_id][i][j];
00244       }
00245       coarse_centroids_norms_[coarse_id][i] = norm;
00246     }
00247   }
00248   products_ = new Coord[coarse_vocabs_[0].size()];
00249   query_norms_.resize(coarse_vocabs_[0].size());
00250   residual_ = new Coord[coarse_vocabs_[0][0].size() * coarse_vocabs_.size()];
00251 }
00252 
00253 template<class Record, class MetaInfo>
00254 PerfTester& MultiSearcher<Record, MetaInfo>::GetPerfTester() {
00255   return perf_tester_;
00256 }
00257 
00258 template<class Record, class MetaInfo>
00259 void MultiSearcher<Record, MetaInfo>::GetNearestSubspacesCentroids(const Point& point,
00260                                                                    const int subspace_centroins_count,
00261                                                                    vector<NearestSubspaceCentroids>*
00262                                                                    subspaces_short_lists) const {
00263   subspaces_short_lists->resize(coarse_vocabs_.size());
00264   Dimensions subspace_dimension = point.size() / coarse_vocabs_.size();
00265   for(int subspace_index = 0; subspace_index < coarse_vocabs_.size(); ++subspace_index) {
00266     Dimensions start_dim = subspace_index * subspace_dimension;
00267     Dimensions final_dim = std::min((Dimensions)point.size(), start_dim + subspace_dimension);
00268     Coord query_norm = cblas_sdot(final_dim - start_dim, &(point[start_dim]), 1, &(point[start_dim]), 1);
00269     std::fill(query_norms_.begin(), query_norms_.end(), query_norm);
00270     cblas_saxpy(coarse_vocabs_[0].size(), 1, &(coarse_centroids_norms_[subspace_index][0]), 1, &(query_norms_[0]), 1);
00271     cblas_sgemv(CblasRowMajor, CblasNoTrans, coarse_vocabs_[0].size(), subspace_dimension, -2.0,
00272                 coarse_vocabs_matrices_[subspace_index], subspace_dimension, &(point[start_dim]), 1, 1, &(query_norms_[0]), 1);
00273     subspaces_short_lists->at(subspace_index).resize(query_norms_.size());
00274     for(int i = 0; i < query_norms_.size(); ++i) {
00275       subspaces_short_lists->at(subspace_index)[i] = std::make_pair(query_norms_[i], i);
00276     }
00277     std::nth_element(subspaces_short_lists->at(subspace_index).begin(),
00278                      subspaces_short_lists->at(subspace_index).begin() + subspace_centroins_count,
00279                      subspaces_short_lists->at(subspace_index).end());
00280     subspaces_short_lists->at(subspace_index).resize(subspace_centroins_count);
00281     std::sort(subspaces_short_lists->at(subspace_index).begin(),
00282               subspaces_short_lists->at(subspace_index).end());
00283   }
00284 }
00285 
00286 template<class Record, class MetaInfo>
00287 void MultiSearcher<Record, MetaInfo>::GetCellEdgesInMultiIndexArray(const vector<int>& cell_coordinates,
00288                                                                     int* cell_start, int* cell_finish) const {
00289   int global_index = multiindex_.cell_edges.GetCellGlobalIndex(cell_coordinates);
00290   *cell_start = multiindex_.cell_edges.table[global_index];
00291   if(global_index + 1 == multiindex_.cell_edges.table.size()) {
00292     *cell_finish = multiindex_.multiindex.size();
00293   } else {
00294     *cell_finish = multiindex_.cell_edges.table[global_index + 1];
00295   }
00296 }
00297 
00298 template<class Record, class MetaInfo>
00299 bool MultiSearcher<Record, MetaInfo>::TraverseNextMultiIndexCell(const Point& point,
00300                                                                  vector<pair<Distance, MetaInfo> >*
00301                                                                              nearest_subpoints) const {
00302   MergedItemIndices cell_inner_indices;
00303   clock_t before = clock();
00304   if(!merger_.GetNextMergedItemIndices(&cell_inner_indices)) {
00305     return false;
00306   }
00307   clock_t after = clock();
00308   perf_tester_.cell_coordinates_time += after - before;
00309   vector<int> cell_coordinates(cell_inner_indices.size());
00310   for(int list_index = 0; list_index < merger_.lists_ptr->size(); ++list_index) {
00311     cell_coordinates[list_index] = merger_.lists_ptr->at(list_index)[cell_inner_indices[list_index]].second;
00312   }
00313   int cell_start, cell_finish;
00314   before = clock();
00315   GetCellEdgesInMultiIndexArray(cell_coordinates, &cell_start, &cell_finish);
00316   after = clock();
00317   perf_tester_.cell_edges_time += after - before;
00318   if(cell_start >= cell_finish) {
00319     return true;
00320   }
00321   vector<Record>::const_iterator it = multiindex_.multiindex.begin() + cell_start;
00322   GetResidual(point, cell_coordinates, coarse_vocabs_, residual_);
00323   cell_finish = std::min((int)cell_finish, cell_start + (int)nearest_subpoints->size() - found_neghbours_count_);
00324   for(int array_index = cell_start; array_index < cell_finish; ++array_index) {
00325     if(rerank_mode_ == USE_RESIDUALS) {
00326       RecordToMetainfoAndDistance<Record, MetaInfo>(residual_, *it,
00327                                                     &(nearest_subpoints->at(found_neghbours_count_)),
00328                                                     cell_coordinates, fine_vocabs_);
00329     } else if(rerank_mode_ == USE_INIT_POINTS) {
00330       RecordToMetainfoAndDistance<Record, MetaInfo>(&(point[0]), *it,
00331                                                     &(nearest_subpoints->at(found_neghbours_count_)),
00332                                                     cell_coordinates, fine_vocabs_);
00333     }
00334     perf_tester_.NextNeighbour();
00335     ++found_neghbours_count_;
00336     ++it;
00337   }
00338   return true;
00339 }
00340 
00341 
00342 template<class Record, class MetaInfo>
00343 void MultiSearcher<Record, MetaInfo>::GetNearestNeighbours(const Point& point, int k, 
00344                                                            vector<pair<Distance, MetaInfo> >* neighbours) const {
00345   perf_tester_.handled_queries_count += 1;
00346   neighbours->resize(k);
00347   perf_tester_.ResetQuerywiseStatistic();
00348   clock_t start = clock();
00349   perf_tester_.search_start = start;
00350   clock_t before = clock();
00351   vector<NearestSubspaceCentroids> subspaces_short_lists;
00352   GetNearestSubspacesCentroids(point, subspace_centroids_to_consider_, &subspaces_short_lists);
00353   clock_t after = clock();
00354   perf_tester_.nearest_subcentroids_time += after - before;
00355   clock_t before_merger = clock();
00356   merger_.setLists(subspaces_short_lists);
00357   clock_t after_merger = clock();
00358   perf_tester_.merger_init_time += after_merger - before_merger;
00359   clock_t before_traversal = clock();
00360   found_neghbours_count_ = 0;
00361   bool traverse_next_cell = true;
00362   int cells_visited = 0;
00363   while(found_neghbours_count_ < k && traverse_next_cell) {
00364     perf_tester_.cells_traversed += 1;
00365     traverse_next_cell = TraverseNextMultiIndexCell(point, neighbours);
00366     cells_visited += 1;
00367   }
00368   clock_t after_traversal = clock();
00369   perf_tester_.full_traversal_time += after_traversal - before_traversal;
00370   if(do_rerank_) {
00371     if(neighbours->size() > 10000) {
00372       std::nth_element(neighbours->begin(), neighbours->begin() + 10000, neighbours->end());
00373       neighbours->resize(10000);
00374     }
00375     std::sort(neighbours->begin(), neighbours->end());
00376   }
00377   clock_t finish = clock();
00378   perf_tester_.full_search_time += finish - start;
00379 }
00380 
00381 template<>
00382 inline void RecordToMetainfoAndDistance<RerankADC8, PointId>(const Coord* point, const RerankADC8& record,
00383                                                              pair<Distance, PointId>* result,
00384                                                              const vector<int>& cell_coordinates,
00385                                                              const vector<Centroids>& fine_vocabs) {
00386   result->second = record.pid;
00387   int coarse_clusters_count = cell_coordinates.size();
00388   int fine_clusters_count = fine_vocabs.size();
00389   int coarse_to_fine_ratio = fine_clusters_count / coarse_clusters_count;
00390   int subvectors_dim = SPACE_DIMENSION / fine_clusters_count;
00391   char* rerank_info_ptr = (char*)&record + sizeof(record.pid);
00392   for(int centroid_index = 0; centroid_index < fine_clusters_count; ++centroid_index) {
00393     int start_dim = centroid_index * subvectors_dim;
00394     int final_dim = start_dim + subvectors_dim;
00395     FineClusterId pid_nearest_centroid = *((FineClusterId*)rerank_info_ptr);
00396     rerank_info_ptr += sizeof(FineClusterId);
00397     int current_coarse_index = centroid_index / coarse_to_fine_ratio;
00398     Distance subvector_distance = 0;
00399     for(int i = start_dim; i < final_dim; ++i) {
00400       Coord diff = fine_vocabs[centroid_index][pid_nearest_centroid][i - start_dim] - point[i];
00401         subvector_distance += diff * diff;
00402     }
00403     result->first += subvector_distance;
00404   }
00405 }
00406 
00407 template<>
00408 inline void RecordToMetainfoAndDistance<RerankADC16, PointId>(const Coord* point, const RerankADC16& record,
00409                                                               pair<Distance, PointId>* result,
00410                                                               const vector<int>& cell_coordinates,
00411                                                               const vector<Centroids>& fine_vocabs) {
00412   result->second = record.pid;
00413   int coarse_clusters_count = cell_coordinates.size();
00414   int fine_clusters_count = fine_vocabs.size();
00415   int coarse_to_fine_ratio = fine_clusters_count / coarse_clusters_count;
00416   int subvectors_dim = SPACE_DIMENSION / fine_clusters_count;
00417   char* rerank_info_ptr = (char*)&record + sizeof(record.pid);
00418   for(int centroid_index = 0; centroid_index < fine_clusters_count; ++centroid_index) {
00419     int start_dim = centroid_index * subvectors_dim;
00420     int final_dim = start_dim + subvectors_dim;
00421     FineClusterId pid_nearest_centroid = *((FineClusterId*)rerank_info_ptr);
00422     rerank_info_ptr += sizeof(FineClusterId);
00423     int current_coarse_index = centroid_index / coarse_to_fine_ratio;
00424     Distance subvector_distance = 0;
00425     for(int i = start_dim; i < final_dim; ++i) {
00426       Coord diff = fine_vocabs[centroid_index][pid_nearest_centroid][i - start_dim] - point[i];
00427       subvector_distance += diff * diff;
00428     }
00429     result->first += subvector_distance;
00430   }
00431 }
00432 
00433 template class MultiSearcher<RerankADC8, PointId>;
00434 template class MultiSearcher<RerankADC16, PointId>;
00435 template class MultiSearcher<PointId, PointId>;
00436 
00437 #endif
00438 
 All Classes Files Functions Variables Typedefs Enumerations Enumerator