MultiIndex
|
00001 00003 // Copyright 2012 Yandex Artem Babenko 00004 #pragma once 00005 00006 00007 #include <bitset> 00008 #include <fstream> 00009 #include <ios> 00010 #include <iostream> 00011 #include <map> 00012 #include <set> 00013 #include <vector> 00014 00015 #include <boost/bind.hpp> 00016 #include <boost/thread.hpp> 00017 00018 #include "mkl_cblas.h" 00019 00020 #include "multitable.hpp" 00021 00022 using std::bitset; 00023 using std::cout; 00024 using std::ifstream; 00025 using std::ios; 00026 using std::endl; 00027 using std::multimap; 00028 using std::pair; 00029 using std::set; 00030 using std::string; 00031 using std::vector; 00032 00037 typedef float Coord; 00042 typedef float Distance; 00047 typedef int Dimensions; 00052 typedef int PointId; 00057 typedef int ClusterId; 00062 typedef vector<Coord> Point; 00067 typedef vector<Point> Points; 00072 typedef vector<PointId> PointIds; 00077 typedef set<PointId> SetPoints; 00082 typedef vector<ClusterId> ClusterIds; 00088 typedef vector<ClusterId> CoarseQuantization; 00093 typedef unsigned char FineClusterId; 00098 typedef vector<FineClusterId> FineQuantization; 00104 typedef vector<SetPoints> ClustersToPoints; 00110 typedef std::vector<ClusterId> PointsToClusters; 00115 typedef std::vector<Point> Centroids; 00116 00121 enum PointType { 00122 FVEC, 00123 BVEC 00124 }; 00125 00131 enum RerankMode { 00132 USE_RESIDUALS, 00133 USE_INIT_POINTS 00134 }; 00135 00139 template<class Record> 00140 struct MultiIndex { 00141 vector<Record> multiindex; 00142 Multitable<int> cell_edges; 00143 }; 00144 00150 Distance Eucldistance(const Point& x, const Point& y); 00159 Distance Eucldistance(const Point& x, const Point& y, Dimensions start, Dimensions finish); 00160 00164 template<class T, class U> 00165 inline U Round(T number) { 00166 return (U)(number + 0.5); 00167 } 00168 00177 template<class T, class U> 00178 void ReadPoints(const string& filename, 00179 vector<vector<U> >* points, 00180 int count) { 00181 ifstream input; 00182 input.open(filename.c_str(), ios::binary); 00183 if(!input.good()) { 00184 throw std::logic_error("Invalid filename"); 00185 } 00186 points->resize(count); 00187 int dimension; 00188 for(PointId pid = 0; pid < count; ++pid) { 00189 input.read((char*)&dimension, sizeof(dimension)); 00190 if(dimension <= 0) { 00191 throw std::logic_error("Bad file content: non-positive dimension"); 00192 } 00193 points->at(pid).resize(dimension); 00194 for(Dimensions d = 0; d < dimension; ++d) { 00195 T buffer; 00196 input.read((char*)&(buffer), sizeof(T)); 00197 points->at(pid)[d] = Round<T, U>(buffer); 00198 } 00199 } 00200 } 00201 00209 template<class T, class U> 00210 void ReadVector(ifstream& input, vector<U>* v) { 00211 if(!input.good()) { 00212 throw std::logic_error("Bad input stream"); 00213 } 00214 int dimension; 00215 input.read((char*)&dimension, sizeof(dimension)); 00216 if(dimension <= 0) { 00217 throw std::logic_error("Bad file content: non-positive dimension"); 00218 } 00219 v->resize(dimension); 00220 for(Dimensions d = 0; d < dimension; ++d) { 00221 T buffer; 00222 input.read((char*)&buffer, sizeof(buffer)); 00223 v->at(d) = Round<T, U>(buffer); 00224 } 00225 } 00226 00234 template<class T> 00235 void ReadVocabulary(ifstream& input, 00236 Dimensions dimension, 00237 int vocabulary_size, 00238 Centroids* centroids) { 00239 if(!input.good()) { 00240 throw std::logic_error("Bad input stream"); 00241 } 00242 centroids->resize(vocabulary_size); 00243 for(ClusterId centroid_index = 0; centroid_index < centroids->size(); ++centroid_index) { 00244 centroids->at(centroid_index).resize(dimension); 00245 for(Dimensions dimension_index = 0; dimension_index < dimension; ++dimension_index) { 00246 T buffer; 00247 input.read((char*)&buffer, sizeof(buffer)); 00248 centroids->at(centroid_index)[dimension_index] = Round<T, Coord>(buffer); 00249 } 00250 } 00251 } 00252 00262 template<class T> 00263 void ReadVocabularies(const string& filename, 00264 Dimensions space_dimension, 00265 vector<Centroids>* centroids) { 00266 ifstream vocabulary; 00267 vocabulary.open(filename.c_str(), ios::binary); 00268 if(!vocabulary.good()) { 00269 throw std::logic_error("Bad vocabulary file"); 00270 } 00271 int dimension; 00272 vocabulary.read((char*)&dimension, sizeof(dimension)); 00273 if(dimension <= 0) { 00274 throw std::logic_error("Bad file content: non-positive dimension"); 00275 } 00276 int vocabs_count = space_dimension / dimension; 00277 if(space_dimension < dimension) { 00278 throw std::logic_error("Space dimension is less than vocabulary dimension"); 00279 } 00280 centroids->resize(vocabs_count); 00281 int vocabulary_size; 00282 vocabulary.read((char*)&vocabulary_size, sizeof(vocabulary_size)); 00283 for(int vocab_item = 0; vocab_item < vocabs_count; ++vocab_item) { 00284 ReadVocabulary<T>(vocabulary, dimension, vocabulary_size, &(centroids->at(vocab_item))); 00285 } 00286 } 00287 00293 template<class T> 00294 void ReadFineVocabs(const string& fine_vocabs_filename, vector<Centroids>* fine_vocabs) { 00295 ifstream fine_vocabs_stream; 00296 fine_vocabs_stream.open(fine_vocabs_filename.c_str(), ios::binary); 00297 if(!fine_vocabs_stream.good()) { 00298 throw std::logic_error("Bad fine vocabulary file"); 00299 } 00300 int vocabs_count, centroids_count, vocabs_dim; 00301 fine_vocabs_stream.read((char*)&vocabs_count, sizeof(vocabs_count)); 00302 if(vocabs_count < 1) { 00303 throw std::logic_error("Bad fine vocabulary file content: number of vocabularies < 1"); 00304 } 00305 fine_vocabs_stream.read((char*)¢roids_count, sizeof(centroids_count)); 00306 if(centroids_count < 1) { 00307 throw std::logic_error("Bad fine vocabulary file content: vocabulary capacity < 1"); 00308 } 00309 fine_vocabs_stream.read((char*)&vocabs_dim, sizeof(vocabs_dim)); 00310 if(vocabs_dim < 1) { 00311 throw std::logic_error("Bad fine vocabulary file content: vocabulary dimension < 1"); 00312 } 00313 fine_vocabs->resize(vocabs_count); 00314 for(int voc_index = 0; voc_index < vocabs_count; ++voc_index) { 00315 ReadVocabulary<T>(fine_vocabs_stream, vocabs_dim, centroids_count, &(fine_vocabs->at(voc_index))); 00316 } 00317 } 00318 00327 void GetSubpoints(const Points& points, 00328 const Dimensions start_dim, 00329 const Dimensions final_dim, 00330 Points* subpoints); 00331 00340 ClusterId GetNearestClusterId(const Point& point, const Centroids& centroids, 00341 const Dimensions start_dim, const Dimensions final_dim); 00342 00350 void GetResidual(const Point& point, const CoarseQuantization& coarse_quantizations, 00351 const vector<Centroids>& centroids, Point* residual); 00359 void GetResidual(const Point& point, const CoarseQuantization& coarse_quantizations, 00360 const vector<Centroids>& centroids, Coord* residual); 00361 00370 void GetNearestClusterIdsForPointSubset(const Points& points, const Centroids& centroids, 00371 const PointId start_pid, const PointId final_pid, 00372 vector<ClusterId>* nearest); 00373 00384 void GetNearestClusterIdsForSubpoints(const Points& points, const Centroids& centroids, 00385 const Dimensions start_dim, const Dimensions final_dim, 00386 int threads_count, vector<ClusterId>* nearest); 00387 00395 void GetPointsCoarseQuaintizations(const Points& points, const vector<Centroids>& centroids, 00396 const int threads_count, 00397 vector<CoarseQuantization>* coarse_quantizations); 00398 00399 00403 struct IndexConfig { 00404 RerankMode rerank_mode; 00405 vector<Centroids> fine_vocabs; 00406 }; 00407 00412 struct RerankADC8 { 00413 PointId pid; 00414 FineClusterId quantizations[8]; 00415 template<class Archive> 00416 void serialize(Archive& arc, unsigned int version) { 00417 arc & pid; 00418 arc & quantizations; 00419 } 00420 }; 00421 00426 struct RerankADC16 { 00427 PointId pid; 00428 FineClusterId quantizations[16]; 00429 template<class Archive> 00430 void serialize(Archive& arc, unsigned int version) { 00431 arc & pid; 00432 arc & quantizations; 00433 } 00434 }; 00435 00436 00437 00438 00439 00440 00441 00442 00443 00444 00445 00446 00447 00448 00449 00450 00451 00452 00453