diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..d23cc2425 --- /dev/null +++ b/.clang-format @@ -0,0 +1,192 @@ +--- +Language: Cpp +# BasedOnStyle: Microsoft +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignArrayOfStructures: None +AlignConsecutiveMacros: None +AlignConsecutiveAssignments: None +AlignConsecutiveBitFields: None +AlignConsecutiveDeclarations: None +AlignEscapedNewlines: Right +AlignOperands: Align +AlignTrailingComments: true +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortEnumsOnASingleLine: false +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortLambdasOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: MultiLine +AttributeMacros: + - __capability +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterCaseLabel: false + AfterClass: true + AfterControlStatement: Always + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: false + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBinaryOperators: None +BreakBeforeConceptDeclarations: true +BreakBeforeBraces: Custom +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +QualifierAlignment: Leave +CompactNamespaces: false 
+ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DeriveLineEnding: true +DerivePointerAlignment: false +DisableFormat: false +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +ExperimentalAutoDetectBinPacking: false +PackConstructorInitializers: BinPack +BasedOnStyle: '' +ConstructorInitializerAllOnOneLineOrOnePerLine: false +AllowAllConstructorInitializersOnNextLine: true +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IfMacros: + - KJ_IF_MAYBE +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + SortPriority: 0 + CaseSensitive: false + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + SortPriority: 0 + CaseSensitive: false + - Regex: '.*' + Priority: 1 + SortPriority: 0 + CaseSensitive: false +IncludeIsMainRegex: '(Test)?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentCaseLabels: false +IndentCaseBlocks: false +IndentGotoLabels: true +IndentPPDirectives: None +IndentExternBlock: AfterExternBlock +IndentRequires: false +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertTrailingCommas: None +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +LambdaBodyIndentation: Signature +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCBreakBeforeNestedBlockParam: true +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakOpenParenthesis: 0 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 1000 +PenaltyIndentedWhitespace: 0 +PointerAlignment: Right +PPIndentWidth: -1 +ReferenceAlignment: Pointer 
+ReflowComments: true +RemoveBracesLLVM: false +SeparateDefinitionBlocks: Leave +ShortNamespaceLines: 1 +SortIncludes: CaseSensitive +SortJavaStaticImport: Before +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeParensOptions: + AfterControlStatements: true + AfterForeachMacros: true + AfterFunctionDefinitionName: false + AfterFunctionDeclarationName: false + AfterIfMacros: true + AfterOverloadedOperator: false + BeforeNonEmptyParentheses: false +SpaceAroundPointerQualifiers: Default +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: Never +SpacesInConditionalStatement: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +BitFieldColonSpacing: Both +Standard: Latest +StatementAttributeLikeMacros: + - Q_EMIT +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 4 +UseCRLF: false +UseTab: Never +WhitespaceSensitiveMacros: + - STRINGIZE + - PP_STRINGIZE + - BOOST_PP_STRINGIZE + - NS_SWIFT_NAME + - CF_SWIFT_NAME +... + diff --git a/.gitignore b/.gitignore index 750f6dec5..4642f4e8e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,9 @@ *.obj *.elf +# Log files +*.tlog + # Linker output *.ilk *.map @@ -60,4 +63,402 @@ dkms.conf /build /ipch /packages -/out/build/x64-Debug \ No newline at end of file +/out/build/x64-Debug +/sgKey.snk +## files generated by popular Visual Studio add-ons. 
+## +## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore + +# User-specific files +*.rsuser +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Mono auto generated files +mono_crash.* + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio 2015/2017 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# Visual Studio 2017 auto generated files +Generated\ Files/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUnit +*.VisualState.xml +TestResult.xml +nunit-*.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# Benchmark Results +BenchmarkDotNet.Artifacts/ + +# .NET Core +project.lock.json +project.fragment.lock.json +artifacts/ + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Files built by Visual Studio +*_i.c +*_p.c +*_h.h +*.ilk +*.meta +*.obj +*.iobj +*.pch +*.pdb +*.ipdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*_wpftmp.csproj +*.log +*.tlog +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# Visual Studio Trace Files +*.e2e + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# AxoCover is a Code Coverage Tool +.axoCover/* +!.axoCover/settings.json + +# Coverlet is a free, cross platform Code Coverage Tool 
+coverage*.json +coverage*.xml +coverage*.info + +# Visual Studio code coverage results +*.coverage +*.coveragexml + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# Note: Comment the next line if you want to checkin your web deploy settings, +# but database connection strings (with potential passwords) will be unencrypted +*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# NuGet Symbol Packages +*.snupkg +# The packages folder can be ignored because of Package Restore +**/[Pp]ackages/* +# except build/, which is used as an MSBuild target. 
+!**/[Pp]ackages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/[Pp]ackages/repositories.config +# NuGet v3's project.json files produces more ignorable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt +*.appx +*.appxbundle +*.appxupload + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!?*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +orleans.codegen.cs + +# Including strong name files can present a security risk +# (https://github.com/github/gitignore/pull/2483#issue-259490424) +#*.snk + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm +ServiceFabricBackup/ +*.rptproj.bak + +# SQL Server files +*.mdf +*.ldf +*.ndf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings +*.rptproj.rsuser +*- [Bb]ackup.rdl +*- [Bb]ackup ([0-9]).rdl +*- [Bb]ackup ([0-9][0-9]).rdl + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat +node_modules/ + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
+*.vbw + +# Visual Studio 6 auto-generated project file (contains which files were open etc.) +*.vbp + +# Visual Studio 6 workspace and project file (working project files containing files to include in project) +*.dsw +*.dsp + +# Visual Studio 6 technical files +*.ncb +*.aps + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# CodeRush personal settings +.cr/personal + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc + +# Cake - Uncomment if you are using it +# tools/** +# !tools/packages.config + +# Tabs Studio +*.tss + +# Telerik's JustMock configuration file +*.jmconfig + +# BizTalk build output +*.btp.cs +*.btm.cs +*.odx.cs +*.xsd.cs + +# OpenCover UI analysis results +OpenCover/ + +# Azure Stream Analytics local run output +ASALocalRun/ + +# MSBuild Binary and Structured Log +*.binlog + +# NVidia Nsight GPU debugger configuration file +*.nvuser + +# MFractors (Xamarin productivity tool) working folder +.mfractor/ + +# Local History for Visual Studio +.localhistory/ + +# Visual Studio History (VSHistory) files +.vshistory/ + +# BeatPulse healthcheck temp database +healthchecksdb + +# Backup folder for Package Reference Convert tool in Visual Studio 2017 +MigrationBackup/ + +# Ionide (cross platform F# VS Code tools) working folder +.ionide/ + +# Fody - auto-generated XML schema +FodyWeavers.xsd + +# VS Code files for those working on multiple tools +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Windows Installer files from build outputs +*.cab +*.msi +*.msix +*.msm +*.msp + +# JetBrains Rider +*.sln.iml diff --git a/.gitmodules b/.gitmodules 
index a6fa563cf..c37db6a11 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,12 @@ path = ThirdParty/zstd url = https://github.com/facebook/zstd branch = release +[submodule "ThirdParty/spdk"] + path = ThirdParty/spdk + url = https://github.com/spdk/spdk +[submodule "ThirdParty/isal-l_crypto"] + path = ThirdParty/isal-l_crypto + url = https://github.com/intel/isa-l_crypto +[submodule "ThirdParty/RocksDB"] + path = ThirdParty/RocksDB + url = https://github.com/PtilopsisL/rocksdb diff --git a/AnnService.users.props b/AnnService.users.props index e65231357..15f2fb5bd 100644 --- a/AnnService.users.props +++ b/AnnService.users.props @@ -10,9 +10,10 @@ $(SolutionDir)\$(Platform)\$(Configuration)\ - $(SolutionDir)\$(Platform)\$(Configuration)\ + $(SolutionDir)\$(Platform)\$(Configuration)\ + 3.9 diff --git a/AnnService/Aggregator.vcxproj b/AnnService/Aggregator.vcxproj index 4946c13f9..a969443ce 100644 --- a/AnnService/Aggregator.vcxproj +++ b/AnnService/Aggregator.vcxproj @@ -42,15 +42,16 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte + Spectre @@ -132,11 +133,13 @@ true _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) stdcpp17 + Guard true true CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + true @@ -158,26 +161,26 @@ - - - - - - - - + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
- - - - - - - - + + + + + + + + \ No newline at end of file diff --git a/AnnService/BalancedDataPartition.vcxproj b/AnnService/BalancedDataPartition.vcxproj index 6ce001f60..f4a712655 100644 --- a/AnnService/BalancedDataPartition.vcxproj +++ b/AnnService/BalancedDataPartition.vcxproj @@ -42,13 +42,13 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte @@ -95,6 +95,7 @@ true /Zc:twoPhase- %(AdditionalOptions) stdcpp17 + Guard true @@ -148,5 +149,12 @@ + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + \ No newline at end of file diff --git a/AnnService/CMakeLists.txt b/AnnService/CMakeLists.txt index 5aaa1cbca..851996f7e 100644 --- a/AnnService/CMakeLists.txt +++ b/AnnService/CMakeLists.txt @@ -10,6 +10,39 @@ include_directories(${Zstd}/lib) file(GLOB_RECURSE HDR_FILES ${AnnService}/inc/Core/*.h ${AnnService}/inc/Helper/*.h) file(GLOB_RECURSE SRC_FILES ${AnnService}/src/Core/*.cpp ${AnnService}/src/Helper/*.cpp) +set(SPDK_LIBRARIES "") +if (SPDK) + set(Spdk ${PROJECT_SOURCE_DIR}/ThirdParty/spdk/build) + set(Dpdk ${PROJECT_SOURCE_DIR}/ThirdParty/spdk/dpdk/build) + set(IsalLCrypto ${PROJECT_SOURCE_DIR}/ThirdParty/isal-l_crypto/.libs/libisal_crypto.a) + set(SpdkLibPrefix ${Spdk}/lib/libspdk_) + set(DpdkLibPrefix ${Dpdk}/lib/librte_) + if ((NOT EXISTS ${SpdkLibPrefix}bdev_nvme.a) OR (NOT EXISTS ${DpdkLibPrefix}pci.a) OR (NOT EXISTS ${IsalLCrypto})) + set (SPDK_LIBRARIES "") + message (FATAL_ERROR "Cound not find SPDK, DPDK and IsalLCrypto!") + list(REMOVE_ITEM HDR_FILES + ${AnnService}/inc/Core/SPANN/ExtraSPDKController.h + ) + + list(REMOVE_ITEM SRC_FILES + ${AnnService}/src/Core/SPANN/ExtraSPDKController.cpp + ) + else() + include_directories(${Spdk}/include) + add_definitions(-DSPDK) + set(SPDK_LIBRARIES -Wl,--whole-archive ${SpdkLibPrefix}bdev_nvme.a 
${SpdkLibPrefix}bdev.a ${SpdkLibPrefix}nvme.a ${SpdkLibPrefix}vfio_user.a ${SpdkLibPrefix}sock.a ${SpdkLibPrefix}dma.a ${SpdkLibPrefix}notify.a ${SpdkLibPrefix}accel.a ${SpdkLibPrefix}event_bdev.a ${SpdkLibPrefix}event_accel.a ${SpdkLibPrefix}vmd.a ${SpdkLibPrefix}event_vmd.a ${SpdkLibPrefix}event_sock.a ${SpdkLibPrefix}event_iobuf.a ${SpdkLibPrefix}event.a ${SpdkLibPrefix}env_dpdk.a ${SpdkLibPrefix}log.a ${SpdkLibPrefix}thread.a ${SpdkLibPrefix}rpc.a ${SpdkLibPrefix}init.a ${SpdkLibPrefix}jsonrpc.a ${SpdkLibPrefix}json.a ${SpdkLibPrefix}trace.a ${SpdkLibPrefix}util.a ${DpdkLibPrefix}mempool.a ${DpdkLibPrefix}mempool_ring.a ${DpdkLibPrefix}eal.a ${DpdkLibPrefix}ring.a ${DpdkLibPrefix}telemetry.a ${DpdkLibPrefix}bus_pci.a ${DpdkLibPrefix}kvargs.a ${DpdkLibPrefix}pci.a -Wl,--no-whole-archive dl rt isal ${IsalLCrypto} uuid) + message (STATUS "Found SPDK:${SPDK_LIBRARIES}") + endif() +else() + list(REMOVE_ITEM HDR_FILES + ${AnnService}/inc/Core/SPANN/ExtraSPDKController.h + ) + + list(REMOVE_ITEM SRC_FILES + ${AnnService}/src/Core/SPANN/ExtraSPDKController.cpp + ) +endif() + list(REMOVE_ITEM HDR_FILES ${AnnService}/inc/Core/Common/DistanceUtils.h ${AnnService}/inc/Core/Common/SIMDUtils.h @@ -38,9 +71,10 @@ if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") endif() add_library (SPTAGLib SHARED ${SRC_FILES} ${HDR_FILES}) -target_link_libraries (SPTAGLib DistanceUtils libzstd_shared) +target_link_libraries (SPTAGLib DistanceUtils ${RocksDB_LIBRARIES} ${uring_LIBRARIES} libzstd_shared ${NUMA_LIBRARY} ${TBB_LIBRARIES} ${SPDK_LIBRARIES}) add_library (SPTAGLibStatic STATIC ${SRC_FILES} ${HDR_FILES}) -target_link_libraries (SPTAGLibStatic DistanceUtils libzstd_static) +target_link_libraries (SPTAGLibStatic DistanceUtils ${RocksDB_LIBRARIES} ${uring_LIBRARIES} libzstd_static ${NUMA_LIBRARY_STATIC} ${TBB_LIBRARIES} ${SPDK_LIBRARIES}) + if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") target_compile_options(SPTAGLibStatic PRIVATE -fPIC) endif() @@ -87,23 +121,31 @@ endif() file(GLOB_RECURSE 
SSD_SERVING_HDR_FILES ${AnnService}/inc/SSDServing/*.h) file(GLOB_RECURSE SSD_SERVING_FILES ${AnnService}/src/SSDServing/*.cpp) -#if(NOT WIN32) -# list(REMOVE_ITEM SSD_SERVING_HDR_FILES -# ${VECTORSEARCH_INC_DIR}/AsyncFileReader.h -# ) -#elseif(WIN32) -# list(REMOVE_ITEM SSD_SERVING_HDR_FILES -# ${VECTORSEARCH_INC_DIR}/AsyncFileReaderLinux.h -# ) -#endif() + +file(GLOB_RECURSE SPFRESH_HDR_FILES ${AnnService}/inc/SPFresh/*.h) +file(GLOB_RECURSE SPFRESH_FILES ${AnnService}/src/SPFresh/*.cpp) + +file(GLOB_RECURSE KEYVALUETEST_HDR_FILES ${AnnService}/inc/KeyValueTest/*.h) +file(GLOB_RECURSE KEYVALUETEST_FILES ${AnnService}/src/KeyValueTest/*.cpp) + +file(GLOB_RECURSE USEFULTOOL_FILES ${AnnService}/src/UsefulTool/*.cpp) add_executable(ssdserving ${SSD_SERVING_HDR_FILES} ${SSD_SERVING_FILES}) -target_link_libraries(ssdserving SPTAGLibStatic ${Boost_LIBRARIES}) +add_executable(spfresh ${SPFRESH_HDR_FILES} ${SPFRESH_FILES}) +add_executable(keyvaluetest ${KEYVALUETEST_HDR_FILES} ${KEYVALUETEST_FILES}) +add_executable(usefultool ${USEFULTOOL_FILES}) +target_link_libraries(ssdserving SPTAGLibStatic ${Boost_LIBRARIES} ${RocksDB_LIBRARIES} ${uring_LIBRARIES} ${TBB_LIBRARIES} ${SPDK_LIBRARIES}) +target_link_libraries(spfresh SPTAGLibStatic ${Boost_LIBRARIES} ${RocksDB_LIBRARIES} ${uring_LIBRARIES} ${TBB_LIBRARIES} ${SPDK_LIBRARIES}) +target_link_libraries(keyvaluetest SPTAGLibStatic ${Boost_LIBRARIES} ${RocksDB_LIBRARIES} ${uring_LIBRARIES} ${TBB_LIBRARIES} ${SPDK_LIBRARIES}) +target_link_libraries(usefultool SPTAGLibStatic ${Boost_LIBRARIES} ${RocksDB_LIBRARIES} ${uring_LIBRARIES} ${TBB_LIBRARIES} ${SPDK_LIBRARIES}) target_compile_definitions(ssdserving PRIVATE _exe) +target_compile_definitions(spfresh PRIVATE _exe) +target_compile_definitions(keyvaluetest PRIVATE _exe) +target_compile_definitions(usefultool PRIVATE _exe) # for Test add_library(ssdservingLib ${SSD_SERVING_HDR_FILES} ${SSD_SERVING_FILES}) -target_link_libraries(ssdservingLib SPTAGLibStatic ${Boost_LIBRARIES}) 
+target_link_libraries(ssdservingLib SPTAGLibStatic ${Boost_LIBRARIES} ${RocksDB_LIBRARIES} ${uring_LIBRARIES} ${TBB_LIBRARIES} ${SPDK_LIBRARIES}) find_package(MPI) if (MPI_FOUND) diff --git a/AnnService/Client.vcxproj b/AnnService/Client.vcxproj index 9381af598..fdf22990f 100644 --- a/AnnService/Client.vcxproj +++ b/AnnService/Client.vcxproj @@ -42,15 +42,16 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte + Spectre @@ -105,10 +106,13 @@ CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + true _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) stdcpp17 + Level3 + Guard @@ -125,26 +129,26 @@ - - - - - - - - + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - + + + + + + + + \ No newline at end of file diff --git a/AnnService/CoreLibrary.vcxproj b/AnnService/CoreLibrary.vcxproj index 1c73e1873..721d74d4b 100644 --- a/AnnService/CoreLibrary.vcxproj +++ b/AnnService/CoreLibrary.vcxproj @@ -1,14 +1,6 @@  - - Debug - Win32 - - - Release - Win32 - Debug x64 @@ -27,43 +19,25 @@ - - Application - true - v142 - MultiByte - - - Application - false - v142 - true - MultiByte - StaticLibrary true - v142 + v143 MultiByte StaticLibrary false - v142 + v143 true MultiByte + Spectre - - - - - - @@ -78,15 +52,6 @@ $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ $(OutLibDir) - - - Level3 - Disabled - true - true - /Zc:twoPhase- %(AdditionalOptions) - - Level3 @@ -97,25 +62,10 @@ _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) Guard ProgramDatabase - /Zc:twoPhase- /Zc:__cplusplus %(AdditionalOptions) + /Zc:twoPhase- /Zc:__cplusplus /bigobj %(AdditionalOptions) stdcpp17 - - - Level3 - MaxSpeed - true - true - true - true - /Zc:twoPhase- %(AdditionalOptions) - - - true - true - - Level3 @@ -126,8 +76,13 @@ true true 
_MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) - /Zc:twoPhase- /Zc:__cplusplus %(AdditionalOptions) + /Zc:twoPhase- /Zc:__cplusplus /bigobj %(AdditionalOptions) stdcpp17 + Speed + AdvancedVectorExtensions + AnySuitable + true + Fast true @@ -135,15 +90,18 @@ + + + @@ -159,14 +117,21 @@ + + - + + + + + + @@ -177,6 +142,7 @@ + @@ -198,9 +164,14 @@ + + + + + @@ -221,13 +192,5 @@ - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - + \ No newline at end of file diff --git a/AnnService/CoreLibrary.vcxproj.filters b/AnnService/CoreLibrary.vcxproj.filters index 7c749f00c..3f7b3fa11 100644 --- a/AnnService/CoreLibrary.vcxproj.filters +++ b/AnnService/CoreLibrary.vcxproj.filters @@ -85,6 +85,12 @@ Header Files\Core + + Header Files\Core + + + Header Files\Core + Header Files\Core @@ -196,7 +202,7 @@ Header Files\Core\SPANN - + Header Files\Core\SPANN @@ -220,6 +226,48 @@ Header Files\Core\Common + + Header Files\Core\SPANN + + + Header Files\Core\SPANN + + + Header Files\Helper + + + Header Files\Core\Common + + + Header Files\Core\SPANN + + + Header Files\Core\SPANN + + + Header Files\Core\SPANN + + + Header Files\Core\SPANN + + + Header Files\Helper + + + Header Files\Core\Common + + + Header Files\Core\SPANN + + + Header Files\Core\SPANN + + + Header Files\Core\Common + + + Header Files\Core\Common + @@ -237,6 +285,12 @@ Source Files\Core + + Source Files\Core + + + Source Files\Core + Source Files\Helper @@ -282,14 +336,23 @@ Source Files\Core\Common + + Source Files\Core\Common + + + Source Files\Core\Common + Source Files\Core\SPANN Source Files\Helper - - Source Files\Core\Common + + Source Files\Core\SPANN + + + Source Files\Core\SPANN diff --git a/AnnService/GPUCoreLibrary.vcxproj b/AnnService/GPUCoreLibrary.vcxproj index 2296ce147..e2f78a12a 100644 --- 
a/AnnService/GPUCoreLibrary.vcxproj +++ b/AnnService/GPUCoreLibrary.vcxproj @@ -15,6 +15,7 @@ GPUCoreLibrary 10.0 GPUCoreLibrary + $(CUDA_PATH) @@ -22,14 +23,15 @@ StaticLibrary true MultiByte - v142 + v143 + x64 StaticLibrary false true MultiByte - v142 + v143 @@ -50,6 +52,7 @@ false + $(CudaToolkitIncludeDir);$(PublicIncludeDirectories) @@ -60,6 +63,7 @@ true 4819;%(DisableSpecificWarnings) MultiThreadedDebug + Default true @@ -70,10 +74,13 @@ 64 compute_70,sm_70;compute_75,sm_75;compute_80,sm_80; WIN32;%(Defines) - MTd - /openmp /std:c++14 /Zc:__cplusplus /FS + InheritFromHost + /openmp /std:c++17 /Zc:__cplusplus /FS true InheritFromHost + %(Include) + Static + 64 @@ -86,6 +93,8 @@ true 4819;%(DisableSpecificWarnings) MultiThreaded + Default + ProgramDatabase true @@ -97,7 +106,7 @@ 64 compute_70,sm_70;compute_75,sm_75;compute_80,sm_80 - /openmp /std:c++14 /Zc:__cplusplus /FS + /openmp /std:c++17 /Zc:__cplusplus /FS true O2 MT @@ -106,11 +115,13 @@ + + @@ -130,6 +141,8 @@ + + @@ -186,7 +199,7 @@ - + @@ -196,27 +209,5 @@ - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
- - - - - - - - - - + \ No newline at end of file diff --git a/AnnService/GPUCoreLibrary.vcxproj.filters b/AnnService/GPUCoreLibrary.vcxproj.filters index 79e050332..c9500e6d2 100644 --- a/AnnService/GPUCoreLibrary.vcxproj.filters +++ b/AnnService/GPUCoreLibrary.vcxproj.filters @@ -163,9 +163,6 @@ Header Files\Core\Common - - Header Files\Core\Common\cuda - Header Files\Core\Common\cuda @@ -226,6 +223,12 @@ Header Files\Core\Common + + Header Files\Core\Common\cuda + + + Header Files\Core\Common\cuda + @@ -270,6 +273,12 @@ Source Files\Helper + + Source Files\Core\SPANN + + + Source Files\Core\Common + @@ -290,9 +299,6 @@ Source Files\Core\Common - - Source Files\Core\SPANN - Source Files\Core\Common @@ -302,5 +308,8 @@ Source Files\Helper + + Source Files\Core\Common + \ No newline at end of file diff --git a/AnnService/GPUIndexBuilder.vcxproj b/AnnService/GPUIndexBuilder.vcxproj index 1d9c0ddc7..a9ac5367e 100644 --- a/AnnService/GPUIndexBuilder.vcxproj +++ b/AnnService/GPUIndexBuilder.vcxproj @@ -22,14 +22,14 @@ Application true MultiByte - v142 + v143 Application false true MultiByte - v142 + v143 @@ -71,7 +71,7 @@ compute_70,sm_70;compute_75,sm_75;compute_80,sm_80 WIN32;%(Defines) MTd - /openmp /std:c++14 /Zc:__cplusplus /FS + /openmp /std:c++17 /Zc:__cplusplus /FS true InheritFromHost @@ -96,13 +96,14 @@ 64 compute_70,sm_70;compute_75,sm_75;compute_80,sm_80 - /openmp /std:c++14 /Zc:__cplusplus /FS + /openmp /std:c++17 /Zc:__cplusplus /FS true O2 MT + @@ -112,27 +113,5 @@ - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
- - - - - - - - - - + \ No newline at end of file diff --git a/AnnService/GPUSSDServing.vcxproj b/AnnService/GPUSSDServing.vcxproj index 97d63367a..7b1e1bada 100644 --- a/AnnService/GPUSSDServing.vcxproj +++ b/AnnService/GPUSSDServing.vcxproj @@ -22,14 +22,14 @@ Application true MultiByte - v142 + v143 Application false true MultiByte - v142 + v143 @@ -71,9 +71,11 @@ compute_70,sm_70;compute_75,sm_75;compute_80,sm_80 WIN32;%(Defines) MTd - /openmp /std:c++14 /Zc:__cplusplus /FS /D "_exe" + /openmp /std:c++17 /Zc:__cplusplus /FS /D "_exe" true InheritFromHost + Static + 64 @@ -97,7 +99,7 @@ 64 compute_70,sm_70;compute_75,sm_75;compute_80,sm_80 - /openmp /std:c++14 /Zc:__cplusplus /FS /D "_exe" + /openmp /std:c++17 /Zc:__cplusplus /FS /D "_exe" true O2 MT @@ -119,27 +121,5 @@ - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - - - + \ No newline at end of file diff --git a/AnnService/GPUSSDServing.vcxproj.filters b/AnnService/GPUSSDServing.vcxproj.filters index a8eb95b2f..41835709b 100644 --- a/AnnService/GPUSSDServing.vcxproj.filters +++ b/AnnService/GPUSSDServing.vcxproj.filters @@ -26,8 +26,6 @@ - - Source Files - + \ No newline at end of file diff --git a/AnnService/IndexBuilder.vcxproj b/AnnService/IndexBuilder.vcxproj index 0900590cb..deb392b01 100644 --- a/AnnService/IndexBuilder.vcxproj +++ b/AnnService/IndexBuilder.vcxproj @@ -42,15 +42,16 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte + Spectre @@ -95,10 +96,12 @@ true /Zc:twoPhase- %(AdditionalOptions) stdcpp17 + Guard true true + true @@ -147,27 +150,5 @@ - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. 
For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - - - + \ No newline at end of file diff --git a/AnnService/IndexSearcher.vcxproj b/AnnService/IndexSearcher.vcxproj index 6d1378379..77b43c579 100644 --- a/AnnService/IndexSearcher.vcxproj +++ b/AnnService/IndexSearcher.vcxproj @@ -43,15 +43,16 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte + Spectre @@ -96,10 +97,12 @@ true /Zc:twoPhase- %(AdditionalOptions) stdcpp17 + Guard true true + true @@ -148,27 +151,5 @@ - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - - - + \ No newline at end of file diff --git a/AnnService/Quantizer.vcxproj b/AnnService/Quantizer.vcxproj index 942e55e1d..a725f173f 100644 --- a/AnnService/Quantizer.vcxproj +++ b/AnnService/Quantizer.vcxproj @@ -36,43 +36,25 @@ - - Application - true - v142 - MultiByte - - - Application - false - v142 - true - MultiByte - Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte + Spectre - - - - - - @@ -134,6 +116,7 @@ true Guard ProgramDatabase + stdcpp17 Console @@ -152,6 +135,7 @@ true /Zc:twoPhase- %(AdditionalOptions) Guard + stdcpp17 Console @@ -160,30 +144,9 @@ true %(AdditionalLibraryDirectories) CoreLibrary.lib;%(AdditionalDependencies) + true - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
- - - - - - - - - - + \ No newline at end of file diff --git a/AnnService/SSDServing.vcxproj b/AnnService/SSDServing.vcxproj index f66765978..da8261b0d 100644 --- a/AnnService/SSDServing.vcxproj +++ b/AnnService/SSDServing.vcxproj @@ -1,4 +1,4 @@ - + @@ -53,15 +53,16 @@ StaticLibrary true - v142 + v143 MultiByte StaticLibrary false - v142 + v143 true MultiByte + Spectre @@ -161,6 +162,12 @@ true /Zc:twoPhase- %(AdditionalOptions) stdcpp17 + $(IntDir)$(ProjectName).pdb + Speed + AdvancedVectorExtensions + AnySuitable + true + Fast Console @@ -172,27 +179,5 @@ - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - - - + \ No newline at end of file diff --git a/AnnService/Server.vcxproj b/AnnService/Server.vcxproj index 3b38afe40..fefafe8b8 100644 --- a/AnnService/Server.vcxproj +++ b/AnnService/Server.vcxproj @@ -42,15 +42,16 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte + Spectre @@ -105,10 +106,13 @@ CoreLibrary.lib;SocketLib.lib;%(AdditionalDependencies) + true _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) stdcpp17 + Level3 + Guard @@ -133,26 +137,26 @@ - - - - - - - - + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
- - - - - - - - + + + + + + + + \ No newline at end of file diff --git a/AnnService/SocketLib.vcxproj b/AnnService/SocketLib.vcxproj index d29ac8ede..4a28ae733 100644 --- a/AnnService/SocketLib.vcxproj +++ b/AnnService/SocketLib.vcxproj @@ -41,15 +41,16 @@ StaticLibrary true - v142 + v143 MultiByte StaticLibrary false - v142 + v143 true MultiByte + Spectre @@ -88,6 +89,7 @@ _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) stdcpp17 + Level3 @@ -115,12 +117,12 @@ - + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - + \ No newline at end of file diff --git a/AnnService/inc/Core/BKT/Index.h b/AnnService/inc/Core/BKT/Index.h index f3131343c..dd9734bed 100644 --- a/AnnService/inc/Core/BKT/Index.h +++ b/AnnService/inc/Core/BKT/Index.h @@ -41,8 +41,32 @@ namespace SPTAG public: RebuildJob(COMMON::Dataset* p_data, COMMON::BKTree* p_tree, COMMON::RelativeNeighborhoodGraph* p_graph, DistCalcMethod p_distMethod) : m_data(p_data), m_tree(p_tree), m_graph(p_graph), m_distMethod(p_distMethod) {} + + void exec(void* p_workSpace, IAbortOperation* p_abort) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot support job.exec(workspace, abort)!\n"); + } + void exec(IAbortOperation* p_abort) { - m_tree->Rebuild(*m_data, m_distMethod, p_abort); + COMMON::BKTree newTrees(*m_tree); + newTrees.BuildTrees(*m_data, m_distMethod, 1, nullptr, nullptr, false, p_abort); + + std::unique_lock lock(*(m_tree->m_lock)); + const std::unordered_map* idmap = &(m_tree->GetSampleMap()); + for (auto iter = idmap->begin(); iter != idmap->end(); iter++) { + if (iter->first < 0) + { + (*m_graph)[-1 - iter->first][m_graph->m_iNeighborhoodSize - 1] = -1; + } + } + m_tree->SwapTree(newTrees); + + const std::unordered_map* newidmap = &(m_tree->GetSampleMap()); + for (auto iter = newidmap->begin(); iter != newidmap->end(); iter++) { + if 
(iter->first < 0) + { + (*m_graph)[-1 - iter->first][m_graph->m_iNeighborhoodSize - 1] = -2 - iter->second; + } + } } private: COMMON::Dataset* m_data; @@ -72,7 +96,6 @@ namespace SPTAG std::shared_timed_mutex m_dataDeleteLock; COMMON::Labelset m_deletedID; - std::unique_ptr> m_workSpacePool; Helper::ThreadPool m_threadPool; int m_iNumberOfThreads; @@ -85,6 +108,7 @@ namespace SPTAG int m_iNumberOfInitialDynamicPivots; int m_iNumberOfOtherDynamicPivots; int m_iHashTableExp; + std::unique_ptr> m_workSpaceFactory; public: Index() @@ -98,6 +122,7 @@ namespace SPTAG m_pSamples.SetName("Vector"); m_fComputeDistance = std::function(COMMON::DistanceCalcSelector(m_iDistCalcMethod)); m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() : 1; + m_workSpaceFactory = std::make_unique>(); } ~Index() {} @@ -122,6 +147,9 @@ namespace SPTAG return 1.0f - xy / (sqrt(xx) * sqrt(yy)); } inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); } + inline float GetDistance(const void* target, const SizeType idx) const { + return ComputeDistance(target, m_pSamples.At(idx)); + } inline const void* GetSample(const SizeType idx) const { return (void*)m_pSamples[idx]; } inline bool ContainSample(const SizeType idx) const { return idx >= 0 && idx < m_deletedID.R() && !m_deletedID.Contains(idx); } inline bool NeedRefine() const { return m_deletedID.Count() > (size_t)(GetNumSamples() * m_fDeletePercentageForRefine); } @@ -154,9 +182,18 @@ namespace SPTAG ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized = false, bool p_shareOwnership = false); ErrorCode SearchIndex(QueryResult &p_query, bool p_searchDeleted = false) const; + + std::shared_ptr GetIterator(const void* p_target, bool p_searchDeleted = false, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const; + ErrorCode 
SearchIndexIterativeNext(QueryResult& p_query, COMMON::WorkSpace* workSpace, int p_batch, int& resultCount, bool p_isFirst, bool p_searchDeleted) const; + ErrorCode SearchIndexIterativeEnd(std::unique_ptr workSpace) const; + bool SearchIndexIterativeFromNeareast(QueryResult& p_query, COMMON::WorkSpace* p_space, bool p_isFirst, bool p_searchDeleted = false) const; + std::unique_ptr RentWorkSpace(int batch, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const; + ErrorCode SearchIndexWithFilter(QueryResult& p_query, std::function filterFunc, int maxCheck = 0, bool p_searchDeleted = false) const; ErrorCode RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted = false) const; ErrorCode SearchTree(QueryResult &p_query) const; ErrorCode AddIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, std::shared_ptr p_metadataSet, bool p_withMetaIndex = false, bool p_normalized = false); + ErrorCode AddIndexIdx(SizeType begin, SizeType end); + ErrorCode AddIndexId(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, int& beginHead, int& endHead); ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum); ErrorCode DeleteIndex(const SizeType& p_id); @@ -164,11 +201,42 @@ namespace SPTAG std::string GetParameter(const char* p_param, const char* p_section = nullptr) const; ErrorCode UpdateIndex(); - ErrorCode RefineIndex(const std::vector>& p_indexStreams, IAbortOperation* p_abort); + ErrorCode RefineIndex(const std::vector> &p_indexStreams, + IAbortOperation *p_abort, std::vector *p_mapping); ErrorCode RefineIndex(std::shared_ptr& p_newIndex); + ErrorCode SetWorkSpaceFactory(std::unique_ptr> up_workSpaceFactory) + { + SPTAG::COMMON::IWorkSpaceFactory* raw_generic_ptr = up_workSpaceFactory.release(); + if (!raw_generic_ptr) return ErrorCode::Fail; + + + SPTAG::COMMON::IWorkSpaceFactory* raw_specialized_ptr = dynamic_cast*>(raw_generic_ptr); + if (!raw_specialized_ptr) + { + delete raw_generic_ptr; + return 
ErrorCode::Fail; + } + else + { + m_workSpaceFactory = std::unique_ptr>(raw_specialized_ptr); + return ErrorCode::Success; + } + } + + virtual ErrorCode Check() override; + virtual std::string GetPriorityID(int queryID) const override { return m_pTrees.GetPriorityID(m_pSamples, queryID, m_fComputeDistance); } private: - void SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, bool p_searchDeleted, bool p_searchDuplicated) const; + + int SearchIndexIterative(COMMON::QueryResultSet& p_query, COMMON::WorkSpace& p_space, bool p_isFirst, int batch, bool p_searchDeleted, bool p_searchDuplicated) const; + + template &, SizeType, float), bool (*checkFilter)(const std::shared_ptr &, SizeType, std::function)> + int SearchIterative(COMMON::QueryResultSet& p_query, + COMMON::WorkSpace& p_space, bool p_isFirst, int batch) const; + + void SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, bool p_searchDeleted, bool p_searchDuplicated, std::function filterFunc = nullptr) const; + template &, SizeType, float), bool(*checkFilter)(const std::shared_ptr&, SizeType, std::function)> + void Search(COMMON::QueryResultSet& p_query, COMMON::WorkSpace& p_space, std::function filterFunc) const; }; } // namespace BKT } // namespace SPTAG diff --git a/AnnService/inc/Core/Common.h b/AnnService/inc/Core/Common.h index 1bad007dd..dce9e4d51 100644 --- a/AnnService/inc/Core/Common.h +++ b/AnnService/inc/Core/Common.h @@ -1,9 +1,17 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#ifndef _SPTAG_CORE_COMMONDEFS_H_ #define _SPTAG_CORE_COMMONDEFS_H_ +#ifdef DEBUG +#define IF_DEBUG(statement) statement +#define IF_NDEBUG(statement) +#else +#define IF_DEBUG(statement) +#define IF_NDEBUG(statement) statement +#endif + #include #include #include @@ -12,8 +20,8 @@ #include #include #include +#include #include "inc/Helper/Logging.h" -#include "inc/Helper/DiskIO.h" #ifndef _MSC_VER #include @@ -40,6 +48,13 @@ inline bool fileexists(const char* path) { struct stat info; return stat(path, &info) == 0 && (info.st_mode & S_IFDIR) == 0; } +/* +inline int64_t filesize(const char* path) { + struct stat info; + if (stat(path, &info) == 0) return info.st_blocks * info.st_blksize; + return 0; +} +*/ template inline T min(T a, T b) { @@ -80,9 +95,45 @@ inline bool fileexists(const TCHAR* path) { auto dwAttr = GetFileAttributes(path); return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY) == 0; } +/* +inline int64_t filesize(const char* path) { + WIN32_FILE_ATTRIBUTE_DATA fad; + if (!GetFileAttributesEx(path, GetFileExInfoStandard, &fad)) return -1; + LARGE_INTEGER size; + size.HighPart = fad.nFileSizeHigh; + size.LowPart = fad.nFileSizeLow; + return size.QuadPart; +} +*/ + #define mkdir(a) CreateDirectory(a, NULL) + +#ifndef max +#define max(a,b) (((a) > (b)) ? (a) : (b)) +#endif + +#ifndef min +#define min(a,b) (((a) < (b)) ? 
(a) : (b)) +#endif + +FORCEINLINE +char +InterlockedCompareExchange( + _Inout_ _Interlocked_operand_ char volatile* Destination, + _In_ char Exchange, + _In_ char Comperand +) +{ + return (char)_InterlockedCompareExchange8(Destination, Exchange, Comperand); +} + #endif +inline int64_t filesize(const char* path) { + std::filesystem::path p{path}; + return std::filesystem::file_size(p); +} + namespace SPTAG { #if (__cplusplus < 201703L) @@ -90,115 +141,121 @@ namespace SPTAG #define ALIGN_FREE(ptr) _mm_free(ptr) #define PAGE_ALLOC(size) _mm_malloc(size, 512) #define PAGE_FREE(ptr) _mm_free(ptr) +#define BLOCK_ALLOC(size, val) _mm_malloc(size, val) +#define BLOCK_FREE(ptr, val) _mm_free(ptr) #else #define ALIGN_ALLOC(size) ::operator new(size, (std::align_val_t)32) #define ALIGN_FREE(ptr) ::operator delete(ptr, (std::align_val_t)32) #define PAGE_ALLOC(size) ::operator new(size, (std::align_val_t)512) #define PAGE_FREE(ptr) ::operator delete(ptr, (std::align_val_t)512) +#define BLOCK_ALLOC(size, val) ::operator new(size, (std::align_val_t)val) +#define BLOCK_FREE(ptr, val) ::operator delete(ptr, (std::align_val_t)val) #endif -typedef std::int32_t SizeType; -typedef std::int32_t DimensionType; +#define ALIGN_ROUND(size) ((size) + 31) / 32 * 32 +#define ROUND_UP(size, val) (size + val - 1) / val * val -const SizeType MaxSize = (std::numeric_limits::max)(); -const float MinDist = (std::numeric_limits::min)(); -const float MaxDist = (std::numeric_limits::max)() / 10; -const float Epsilon = 0.000001f; -const std::uint16_t PageSize = 4096; -const int PageSizeEx = 12; + typedef std::int32_t SizeType; + typedef std::int32_t DimensionType; -extern std::mt19937 rg; - -extern std::shared_ptr(*f_createIO)(); + const SizeType MaxSize = (std::numeric_limits::max)(); + const float MinDist = (std::numeric_limits::min)(); + const float MaxDist = (std::numeric_limits::max)() / 10; + const float Epsilon = 0.000001f; + const std::uint16_t PageSize = 4096; + const int PageSizeEx = 12; + 
const std::chrono::microseconds MaxTimeout = (std::chrono::microseconds::max)(); #define IOBINARY(ptr, func, bytes, ...) if (ptr->func(bytes, __VA_ARGS__) != bytes) return ErrorCode::DiskIOFail #define IOSTRING(ptr, func, ...) if (ptr->func(__VA_ARGS__) == 0) return ErrorCode::DiskIOFail -extern std::shared_ptr g_pLogger; + extern Helper::LoggerHolder& GetLoggerHolder(); + extern std::shared_ptr GetLogger(); + extern void SetLogger(std::shared_ptr); -#define LOG(l, ...) g_pLogger->Logging("SPTAG", l, __FILE__, __LINE__, __FUNCTION__, __VA_ARGS__) +#define SPTAGLIB_LOG(l, ...) GetLogger()->Logging("SPTAG", l, __FILE__, __LINE__, __FUNCTION__, __VA_ARGS__) -class MyException : public std::exception -{ -private: - std::string Exp; -public: - MyException(std::string e) { Exp = e; } + class MyException : public std::exception + { + private: + std::string Exp; + public: + MyException(std::string e) { Exp = e; } #ifdef _MSC_VER - const char* what() const { return Exp.c_str(); } + const char* what() const { return Exp.c_str(); } #else - const char* what() const noexcept { return Exp.c_str(); } + const char* what() const noexcept { return Exp.c_str(); } #endif -}; + }; -enum class ErrorCode : std::uint16_t -{ + enum class ErrorCode : std::uint16_t + { #define DefineErrorCode(Name, Value) Name = Value, #include "DefinitionList.h" #undef DefineErrorCode - Undefined -}; -static_assert(static_cast(ErrorCode::Undefined) != 0, "Empty ErrorCode!"); + Undefined + }; + static_assert(static_cast(ErrorCode::Undefined) != 0, "Empty ErrorCode!"); -enum class DistCalcMethod : std::uint8_t -{ + enum class DistCalcMethod : std::uint8_t + { #define DefineDistCalcMethod(Name) Name, #include "DefinitionList.h" #undef DefineDistCalcMethod - Undefined -}; -static_assert(static_cast(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!"); + Undefined + }; + static_assert(static_cast(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!"); -enum class VectorValueType : std::uint8_t -{ + enum 
class VectorValueType : std::uint8_t + { #define DefineVectorValueType(Name, Type) Name, #include "DefinitionList.h" #undef DefineVectorValueType - Undefined -}; -static_assert(static_cast(VectorValueType::Undefined) != 0, "Empty VectorValueType!"); + Undefined + }; + static_assert(static_cast(VectorValueType::Undefined) != 0, "Empty VectorValueType!"); -enum class IndexAlgoType : std::uint8_t -{ + enum class IndexAlgoType : std::uint8_t + { #define DefineIndexAlgo(Name) Name, #include "DefinitionList.h" #undef DefineIndexAlgo - Undefined -}; -static_assert(static_cast(IndexAlgoType::Undefined) != 0, "Empty IndexAlgoType!"); + Undefined + }; + static_assert(static_cast(IndexAlgoType::Undefined) != 0, "Empty IndexAlgoType!"); -enum class VectorFileType : std::uint8_t -{ + enum class VectorFileType : std::uint8_t + { #define DefineVectorFileType(Name) Name, #include "DefinitionList.h" #undef DefineVectorFileType - Undefined -}; -static_assert(static_cast(VectorFileType::Undefined) != 0, "Empty VectorFileType!"); + Undefined + }; + static_assert(static_cast(VectorFileType::Undefined) != 0, "Empty VectorFileType!"); -enum class TruthFileType : std::uint8_t -{ + enum class TruthFileType : std::uint8_t + { #define DefineTruthFileType(Name) Name, #include "DefinitionList.h" #undef DefineTruthFileType - Undefined -}; -static_assert(static_cast(TruthFileType::Undefined) != 0, "Empty TruthFileType!"); + Undefined + }; + static_assert(static_cast(TruthFileType::Undefined) != 0, "Empty TruthFileType!"); -template -constexpr VectorValueType GetEnumValueType() -{ - return VectorValueType::Undefined; -} + template + constexpr VectorValueType GetEnumValueType() + { + return VectorValueType::Undefined; + } #define DefineVectorValueType(Name, Type) \ @@ -212,10 +269,10 @@ constexpr VectorValueType GetEnumValueType() \ #undef DefineVectorValueType -inline std::size_t GetValueTypeSize(VectorValueType p_valueType) -{ - switch (p_valueType) + inline std::size_t 
GetValueTypeSize(VectorValueType p_valueType) { + switch (p_valueType) + { #define DefineVectorValueType(Name, Type) \ case VectorValueType::Name: \ return sizeof(Type); \ @@ -223,22 +280,52 @@ inline std::size_t GetValueTypeSize(VectorValueType p_valueType) #include "DefinitionList.h" #undef DefineVectorValueType - default: - break; - } + default: + break; + } - return 0; -} + return 0; + } -enum class QuantizerType : std::uint8_t -{ + enum class QuantizerType : std::uint8_t + { #define DefineQuantizerType(Name, Type) Name, #include "DefinitionList.h" #undef DefineQuantizerType - Undefined -}; -static_assert(static_cast(QuantizerType::Undefined) != 0, "Empty QuantizerType!"); + Undefined + }; + static_assert(static_cast(QuantizerType::Undefined) != 0, "Empty QuantizerType!"); + + enum class NumaStrategy : std::uint8_t + { +#define DefineNumaStrategy(Name) Name, +#include "DefinitionList.h" +#undef DefineNumaStrategy + + Undefined + }; + static_assert(static_cast(NumaStrategy::Undefined) != 0, "Empty NumaStrategy!"); + + enum class OrderStrategy : std::uint8_t + { +#define DefineOrderStrategy(Name) Name, +#include "DefinitionList.h" +#undef DefineOrderStrategy + + Undefined + }; + static_assert(static_cast(OrderStrategy::Undefined) != 0, "Empty OrderStrategy!"); + + enum class Storage : std::uint8_t + { +#define DefineStorage(Name) Name, +#include "DefinitionList.h" +#undef DefineStorage + + Undefined + }; + static_assert(static_cast(Storage::Undefined) != 0, "Empty Storage!"); } // namespace SPTAG diff --git a/AnnService/inc/Core/Common/BKTree.h b/AnnService/inc/Core/Common/BKTree.h index f920dd273..00019d97c 100644 --- a/AnnService/inc/Core/Common/BKTree.h +++ b/AnnService/inc/Core/Common/BKTree.h @@ -7,8 +7,8 @@ #include #include #include +#include #include - #include "inc/Core/VectorIndex.h" #include "CommonUtils.h" @@ -37,7 +37,7 @@ namespace SPTAG int _DK; DimensionType _D; DimensionType _RD; - int _T; + int _TH; DistCalcMethod _M; T* centers; T* newTCenters; 
@@ -52,26 +52,25 @@ namespace SPTAG std::function fComputeDistance; const std::shared_ptr& m_pQuantizer; - KmeansArgs(int k, DimensionType dim, SizeType datasize, int threadnum, DistCalcMethod distMethod, const std::shared_ptr& quantizer = nullptr) : _K(k), _DK(k), _D(dim), _RD(dim), _T(threadnum), _M(distMethod), m_pQuantizer(quantizer){ + KmeansArgs(int k, DimensionType dim, SizeType datasize, int threadnum, DistCalcMethod distMethod, const std::shared_ptr& quantizer = nullptr) : _K(k), _DK(k), _D(dim), _RD(dim), _TH(threadnum), _M(distMethod), m_pQuantizer(quantizer) { if (m_pQuantizer) { _RD = m_pQuantizer->ReconstructDim(); fComputeDistance = m_pQuantizer->DistanceCalcSelector(distMethod); } - else - { + else { fComputeDistance = COMMON::DistanceCalcSelector(distMethod); } centers = (T*)ALIGN_ALLOC(sizeof(T) * _K * _D); newTCenters = (T*)ALIGN_ALLOC(sizeof(T) * _K * _D); counts = new SizeType[_K]; - newCenters = new float[_T * _K * _RD]; - newCounts = new SizeType[_T * _K]; + newCenters = new float[_TH * _K * _RD]; + newCounts = new SizeType[_TH * _K]; label = new int[datasize]; - clusterIdx = new SizeType[_T * _K]; - clusterDist = new float[_T * _K]; + clusterIdx = new SizeType[_TH * _K]; + clusterDist = new float[_TH * _K]; weightedCounts = new float[_K]; - newWeightedCounts = new float[_T * _K]; + newWeightedCounts = new float[_TH * _K]; } ~KmeansArgs() { @@ -88,16 +87,16 @@ namespace SPTAG } inline void ClearCounts() { - memset(newCounts, 0, sizeof(SizeType) * _T * _K); - memset(newWeightedCounts, 0, sizeof(float) * _T * _K); + memset(newCounts, 0, sizeof(SizeType) * _TH * _K); + memset(newWeightedCounts, 0, sizeof(float) * _TH * _K); } inline void ClearCenters() { - memset(newCenters, 0, sizeof(float) * _T * _K * _RD); + memset(newCenters, 0, sizeof(float) * _TH * _K * _RD); } inline void ClearDists(float dist) { - for (int i = 0; i < _T * _K; i++) { + for (int i = 0; i < _TH * _K; i++) { clusterIdx[i] = -1; clusterDist[i] = dist; } @@ -109,7 +108,7 @@ 
namespace SPTAG for (int k = 1; k < _K; k++) pos[k] = pos[k - 1] + newCounts[k - 1]; for (int k = 0; k < _K; k++) { - if (newCounts[k] == 0) continue; + if (counts[k] == 0) continue; SizeType i = pos[k]; while (newCounts[k] > 0) { SizeType swapid = pos[label[i]] + newCounts[label[i]] - 1; @@ -157,7 +156,7 @@ namespace SPTAG } if (maxcluster != -1 && (args.clusterIdx[maxcluster] < 0 || args.clusterIdx[maxcluster] >= data.R())) - LOG(Helper::LogLevel::LL_Debug, "maxcluster:%d(%d) Error dist:%f\n", maxcluster, args.newCounts[maxcluster], args.clusterDist[maxcluster]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, "maxcluster:%d(%d) Error dist:%f\n", maxcluster, args.newCounts[maxcluster], args.clusterDist[maxcluster]); float diff = 0; std::vector reconstructVector(args._RD, 0); @@ -179,19 +178,20 @@ namespace SPTAG for (DimensionType j = 0; j < args._RD; j++) { currCenters[j] /= args.counts[k]; } + if (args._M == DistCalcMethod::Cosine) { COMMON::Utils::Normalize(currCenters, args._RD, COMMON::Utils::GetBase()); } - + if (args.m_pQuantizer) { for (DimensionType j = 0; j < args._RD; j++) reconstructVector[j] = (R)(currCenters[j]); args.m_pQuantizer->QuantizeVector(reconstructVector.data(), (uint8_t*)TCenter); } else { - for (DimensionType j = 0; j < args._RD; j++) TCenter[j] = (T)(currCenters[j]); + for (DimensionType j = 0; j < args._D; j++) TCenter[j] = (T)(currCenters[j]); } } - diff += args.fComputeDistance(args.centers + k*args._D, TCenter, args._D); + diff += DistanceUtils::ComputeDistance(TCenter, args.centers + k * args._D, args._D, DistCalcMethod::L2); } return diff; } @@ -222,62 +222,84 @@ namespace SPTAG const SizeType first, const SizeType last, KmeansArgs& args, const bool updateCenters, float lambda) { float currDist = 0; - SizeType subsize = (last - first - 1) / args._T + 1; + SizeType subsize = (last - first - 1) / args._TH + 1; -#pragma omp parallel for num_threads(args._T) shared(data, indices) reduction(+:currDist) - for (int tid = 0; tid < args._T; 
tid++) + std::vector mythreads; + mythreads.reserve(args._TH); + for (int tid = 0; tid < args._TH; tid++) { - SizeType istart = first + tid * subsize; - SizeType iend = min(first + (tid + 1) * subsize, last); - SizeType *inewCounts = args.newCounts + tid * args._K; - float *inewCenters = args.newCenters + tid * args._K * args._RD; - SizeType * iclusterIdx = args.clusterIdx + tid * args._K; - float * iclusterDist = args.clusterDist + tid * args._K; - float * iweightedCounts = args.newWeightedCounts + tid * args._K; - float idist = 0; - R* reconstructVector = nullptr; - if (args.m_pQuantizer) reconstructVector = (R*)ALIGN_ALLOC(args.m_pQuantizer->ReconstructSize()); - - for (SizeType i = istart; i < iend; i++) { - int clusterid = 0; - float smallestDist = MaxDist; - for (int k = 0; k < args._DK; k++) { - float dist = args.fComputeDistance(data[indices[i]], args.centers + k*args._D, args._D) + lambda*args.counts[k]; - if (dist > -MaxDist && dist < smallestDist) { - clusterid = k; smallestDist = dist; - } - } - args.label[i] = clusterid; - inewCounts[clusterid]++; - iweightedCounts[clusterid] += smallestDist; - idist += smallestDist; - if (updateCenters) { - if (args.m_pQuantizer) { - args.m_pQuantizer->ReconstructVector((const uint8_t*)data[indices[i]], reconstructVector); - } - else { - reconstructVector = (R*)data[indices[i]]; + mythreads.emplace_back([tid, first, last, updateCenters, lambda, subsize, &data, &indices, &args, &currDist]() { + SizeType istart = first + tid * subsize; + SizeType iend = min(first + (tid + 1) * subsize, last); + SizeType *inewCounts = args.newCounts + tid * args._K; + float *inewCenters = args.newCenters + tid * args._K * args._RD; + SizeType *iclusterIdx = args.clusterIdx + tid * args._K; + float *iclusterDist = args.clusterDist + tid * args._K; + float *iweightedCounts = args.newWeightedCounts + tid * args._K; + float idist = 0; + R *reconstructVector = nullptr; + if (args.m_pQuantizer) + reconstructVector = (R 
*)ALIGN_ALLOC(args.m_pQuantizer->ReconstructSize()); + + for (SizeType i = istart; i < iend; i++) + { + int clusterid = 0; + float smallestDist = MaxDist; + for (int k = 0; k < args._DK; k++) + { + float dist = args.fComputeDistance(data[indices[i]], args.centers + k * args._D, args._D) + + lambda * args.counts[k]; + if (dist > -MaxDist && dist < smallestDist) + { + clusterid = k; + smallestDist = dist; + } } - float* center = inewCenters + clusterid*args._RD; - for (DimensionType j = 0; j < args._RD; j++) center[j] += reconstructVector[j]; - - if (smallestDist > iclusterDist[clusterid]) { - iclusterDist[clusterid] = smallestDist; - iclusterIdx[clusterid] = indices[i]; + args.label[i] = clusterid; + inewCounts[clusterid]++; + iweightedCounts[clusterid] += smallestDist; + idist += smallestDist; + if (updateCenters) + { + if (args.m_pQuantizer) + { + args.m_pQuantizer->ReconstructVector((const uint8_t *)data[indices[i]], + reconstructVector); + } + else + { + reconstructVector = (R *)data[indices[i]]; + } + float *center = inewCenters + clusterid * args._RD; + for (DimensionType j = 0; j < args._RD; j++) + center[j] += reconstructVector[j]; + + if (smallestDist > iclusterDist[clusterid]) + { + iclusterDist[clusterid] = smallestDist; + iclusterIdx[clusterid] = indices[i]; + } } - } - else { - if (smallestDist <= iclusterDist[clusterid]) { - iclusterDist[clusterid] = smallestDist; - iclusterIdx[clusterid] = indices[i]; + else + { + if (smallestDist <= iclusterDist[clusterid]) + { + iclusterDist[clusterid] = smallestDist; + iclusterIdx[clusterid] = indices[i]; + } } } - } - if (args.m_pQuantizer) ALIGN_FREE(reconstructVector); - currDist += idist; + if (args.m_pQuantizer) + ALIGN_FREE(reconstructVector); + COMMON::Utils::atomic_float_add(&currDist, idist); + }); } - - for (int i = 1; i < args._T; i++) { + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + for (int i = 1; i < args._TH; i++) { for (int k = 0; k < args._DK; k++) { args.newCounts[k] += 
args.newCounts[i * args._K + k]; args.newWeightedCounts[k] += args.newWeightedCounts[i * args._K + k]; @@ -285,7 +307,7 @@ namespace SPTAG } if (updateCenters) { - for (int i = 1; i < args._T; i++) { + for (int i = 1; i < args._TH; i++) { float* currCenter = args.newCenters + i*args._K*args._RD; for (size_t j = 0; j < ((size_t)args._DK) * args._RD; j++) args.newCenters[j] += currCenter[j]; @@ -298,7 +320,7 @@ namespace SPTAG } } else { - for (int i = 1; i < args._T; i++) { + for (int i = 1; i < args._TH; i++) { for (int k = 0; k < args._DK; k++) { if (args.clusterIdx[i*args._K + k] != -1 && args.clusterDist[i*args._K + k] <= args.clusterDist[k]) { args.clusterDist[k] = args.clusterDist[i*args._K + k]; @@ -346,6 +368,7 @@ namespace SPTAG float adjustedLambda = InitCenters(data, indices, first, last, args, samples, 3); if (abort && abort->ShouldAbort()) return 0; + std::mt19937 rg; SizeType batchEnd = min(first + samples, last); float currDiff, currDist, minClusterDist = MaxDist; int noImprovement = 0; @@ -374,12 +397,12 @@ namespace SPTAG for (int k = 0; k < args._DK; k++) { log += std::to_string(args.counts[k]) + " "; } - LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f lambda:(%f,%f) counts:%s\n", iter, currDist, originalLambda, adjustedLambda, log.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f lambda:(%f,%f) counts:%s\n", iter, currDist, originalLambda, adjustedLambda, log.c_str()); } */ currDiff = RefineCenters(data, args); - //if (debug) LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f diff:%f\n", iter, currDist, currDiff); + //if (debug) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f diff:%f\n", iter, currDist, currDiff); if (abort && abort->ShouldAbort()) return 0; if (currDiff < 1e-3 || noImprovement >= 5) break; @@ -406,7 +429,7 @@ namespace SPTAG if (args.counts[i] > 0) availableClusters++; } CountStd = sqrt(CountStd / args._DK) / CountAvg; - if (debug) LOG(Helper::LogLevel::LL_Info, "Lambda:min(%g,%g) Max:%d Min:%d Avg:%f 
Std/Avg:%f Dist:%f NonZero/Total:%d/%d\n", originalLambda, adjustedLambda, maxCount, minCount, CountAvg, CountStd, currDist, availableClusters, args._DK); + if (debug) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Lambda:min(%g,%g) Max:%d Min:%d Avg:%f Std/Avg:%f Dist:%f NonZero/Total:%d/%d\n", originalLambda, adjustedLambda, maxCount, minCount, CountAvg, CountStd, currDist, availableClusters, args._DK); return CountStd; } @@ -418,7 +441,7 @@ namespace SPTAG float bestLambdaFactor = 100.0f, bestCountStd = (std::numeric_limits::max)(); for (float lambdaFactor = 0.001f; lambdaFactor <= 1000.0f + 1e-3; lambdaFactor *= 10) { - float CountStd; + float CountStd = 0.0; if (args.m_pQuantizer) { switch (args.m_pQuantizer->GetReconstructType()) @@ -458,7 +481,7 @@ break; } } */ - LOG(Helper::LogLevel::LL_Info, "Best Lambda Factor:%f\n", bestLambdaFactor); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Best Lambda Factor:%f\n", bestLambdaFactor); return bestLambdaFactor; } @@ -509,7 +532,8 @@ break; m_iSamples(other.m_iSamples), m_fBalanceFactor(other.m_fBalanceFactor), m_lock(new std::shared_timed_mutex), - m_pQuantizer(other.m_pQuantizer) {} + m_pQuantizer(other.m_pQuantizer), + m_bfs(0) {} ~BKTree() {} inline const BKTNode& operator[](SizeType index) const { return m_pTreeRoots[index]; } @@ -524,6 +548,13 @@ break; inline const std::unordered_map& GetSampleMap() const { return m_pSampleCenterMap; } + inline void SwapTree(BKTree& newTrees) + { + m_pTreeRoots.swap(newTrees.m_pTreeRoots); + m_pTreeStart.swap(newTrees.m_pTreeStart); + m_pSampleCenterMap.swap(newTrees.m_pSampleCenterMap); + } + template void Rebuild(const Dataset& data, DistCalcMethod distMethod, IAbortOperation* abort) { @@ -560,6 +591,7 @@ break; if (m_fBalanceFactor < 0) m_fBalanceFactor = DynamicFactorSelect(data, localindices, 0, (SizeType)localindices.size(), args, m_iSamples); + std::mt19937 rg; m_pSampleCenterMap.clear(); for (char i = 0; i < m_iTreeNumber; i++) { @@ -567,12 +599,15 @@ break; 
m_pTreeStart.push_back((SizeType)m_pTreeRoots.size()); m_pTreeRoots.emplace_back((SizeType)localindices.size()); - LOG(Helper::LogLevel::LL_Info, "Start to build BKTree %d\n", i + 1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to build BKTree %d\n", i + 1); ss.push(BKTStackItem(m_pTreeStart[i], 0, (SizeType)localindices.size(), true)); while (!ss.empty()) { - if (abort && abort->ShouldAbort()) return; - + if (abort && abort->ShouldAbort()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Abort!!!\n"); + return; + } BKTStackItem item = ss.top(); ss.pop(); m_pTreeRoots[item.index].childStart = (SizeType)m_pTreeRoots.size(); if (item.last - item.first <= m_iBKTLeafSize) { @@ -615,7 +650,7 @@ break; m_pTreeRoots[item.index].childEnd = (SizeType)m_pTreeRoots.size(); } m_pTreeRoots.emplace_back(-1); - LOG(Helper::LogLevel::LL_Info, "%d BKTree built, %zu %zu\n", i + 1, m_pTreeRoots.size() - m_pTreeStart[i], localindices.size()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d BKTree built, %zu %zu\n", i + 1, m_pTreeRoots.size() - m_pTreeStart[i], localindices.size()); } } @@ -633,13 +668,13 @@ break; SizeType treeNodeSize = (SizeType)m_pTreeRoots.size(); IOBINARY(p_out, WriteBinary, sizeof(treeNodeSize), (char*)&treeNodeSize); IOBINARY(p_out, WriteBinary, sizeof(BKTNode) * treeNodeSize, (char*)m_pTreeRoots.data()); - LOG(Helper::LogLevel::LL_Info, "Save BKT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save BKT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); return ErrorCode::Success; } ErrorCode SaveTrees(std::string sTreeFileName) const { - LOG(Helper::LogLevel::LL_Info, "Save BKT to %s\n", sTreeFileName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save BKT to %s\n", sTreeFileName.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sTreeFileName.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; return SaveTrees(ptr); @@ -658,7 +693,7 @@ break; 
m_pTreeRoots.resize(treeNodeSize); memcpy(m_pTreeRoots.data(), pBKTMemFile, sizeof(BKTNode) * treeNodeSize); if (m_pTreeRoots.size() > 0 && m_pTreeRoots.back().centerid != -1) m_pTreeRoots.emplace_back(-1); - LOG(Helper::LogLevel::LL_Info, "Load BKT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load BKT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); return ErrorCode::Success; } @@ -674,13 +709,13 @@ break; IOBINARY(p_input, ReadBinary, sizeof(BKTNode) * treeNodeSize, (char*)m_pTreeRoots.data()); if (m_pTreeRoots.size() > 0 && m_pTreeRoots.back().centerid != -1) m_pTreeRoots.emplace_back(-1); - LOG(Helper::LogLevel::LL_Info, "Load BKT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load BKT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); return ErrorCode::Success; } ErrorCode LoadTrees(std::string sTreeFileName) { - LOG(Helper::LogLevel::LL_Info, "Load BKT From %s\n", sTreeFileName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load BKT From %s\n", sTreeFileName.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sTreeFileName.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; return LoadTrees(ptr); @@ -698,8 +733,12 @@ break; int MaxBFSNodes = 100; p_space.m_currBSPTQueue.Resize(MaxBFSNodes); p_space.m_nextBSPTQueue.Resize(MaxBFSNodes); Heap* p_curr = &p_space.m_currBSPTQueue, * p_next = &p_space.m_nextBSPTQueue; - p_curr->Top().distance = 1e9; + + for (SizeType begin = node.childStart; begin < node.childEnd; begin++) { + _mm_prefetch((const char*)(data[m_pTreeRoots[begin].centerid]), _MM_HINT_T0); + } + for (SizeType begin = node.childStart; begin < node.childEnd; begin++) { SizeType index = m_pTreeRoots[begin].centerid; float dist = fComputeDistance(p_query.GetQuantizedTarget(), data[index], data.C()); @@ -711,7 +750,7 @@ break; } } - for (int level = 1; level < 2; level++) { + for (int level = 1; level <= m_bfs; 
level++) { p_next->Top().distance = 1e9; while (!p_curr->empty()) { NodeDistPair tmp = p_curr->pop(); @@ -720,6 +759,9 @@ break; p_space.m_SPTQueue.insert(tmp); } else { + for (SizeType begin = tnode.childStart; begin < tnode.childEnd; begin++) { + _mm_prefetch((const char*)(data[m_pTreeRoots[begin].centerid]), _MM_HINT_T0); + } if (!p_space.CheckAndSet(tnode.centerid)) { p_space.m_NGQueue.insert(NodeDistPair(tnode.centerid, tmp.distance)); } @@ -743,6 +785,9 @@ break; } } else { + for (SizeType begin = node.childStart; begin < node.childEnd; begin++) { + _mm_prefetch((const char*)(data[m_pTreeRoots[begin].centerid]), _MM_HINT_T0); + } for (SizeType begin = node.childStart; begin < node.childEnd; begin++) { SizeType index = m_pTreeRoots[begin].centerid; p_space.m_SPTQueue.insert(NodeDistPair(begin, fComputeDistance(p_query.GetQuantizedTarget(), data[index], data.C()))); @@ -767,6 +812,9 @@ break; if (p_space.m_iNumberOfCheckedLeaves >= p_limits) break; } else { + for (SizeType begin = tnode.childStart; begin < tnode.childEnd; begin++) { + _mm_prefetch((const char*)(data[m_pTreeRoots[begin].centerid]), _MM_HINT_T0); + } if (!p_space.CheckAndSet(tnode.centerid)) { p_space.m_NGQueue.insert(NodeDistPair(tnode.centerid, bcell.distance)); } @@ -778,6 +826,32 @@ break; } } + template + std::string GetPriorityID(const Dataset& data, int p_queryID, std::function fComputeDistance) const { + std::string ret = ""; + for (char i = 0; i < m_iTreeNumber; i++) { + const BKTNode* node = &(m_pTreeRoots[m_pTreeStart[i]]); + while (node->childStart > 0) { + float minDist = MaxDist; + SizeType minIdx = -1; + for (SizeType begin = node->childStart; begin < node->childEnd; begin++) { + _mm_prefetch((const char*)(data[m_pTreeRoots[begin].centerid]), _MM_HINT_T0); + } + for (SizeType begin = node->childStart; begin < node->childEnd; begin++) { + SizeType index = m_pTreeRoots[begin].centerid; + float dist = fComputeDistance(data[p_queryID], data[index], data.C()); + if (dist < minDist) { + 
minDist = dist; + minIdx = begin; + } + } + ret += static_cast(minIdx - node->childStart); + node = &(m_pTreeRoots[minIdx]); + } + } + return ret; + } + private: std::vector m_pTreeStart; std::vector m_pTreeRoots; diff --git a/AnnService/inc/Core/Common/Checksum.h b/AnnService/inc/Core/Common/Checksum.h new file mode 100644 index 000000000..499bdb56f --- /dev/null +++ b/AnnService/inc/Core/Common/Checksum.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_CHECKSUM_H_ +#define _SPTAG_COMMON_CHECKSUM_H_ + +#include +#include "inc/Core/Common.h" + +namespace SPTAG +{ +typedef uint8_t ChecksumType; + +namespace COMMON +{ +class Checksum +{ + public: + Checksum() : m_type(0), m_seed(0), m_skip(false) + { + } + + void Initialize(bool p_skip, uint8_t p_type, int p_seed = 0) + { + m_type = p_type; + m_seed = p_seed; + m_skip = p_skip; + } + + ChecksumType CalcChecksum(const char *p_data, int p_length) + { + uint8_t cs = m_seed; + for (int i = 0; i < p_length; i++) + cs ^= p_data[i]; + return cs; + } + + ChecksumType AppendChecksum(ChecksumType p_checksum, const char* p_data, int p_length) + { + for (int i = 0; i < p_length; i++) + p_checksum ^= p_data[i]; + return p_checksum; + } + + bool ValidateChecksum(const char *p_data, int p_length, ChecksumType p_checksum) + { + if (m_skip) return true; + return (CalcChecksum(p_data, p_length) == p_checksum); + } + + private: + uint8_t m_type; + int m_seed; + bool m_skip; +}; +} // namespace COMMON +} // namespace SPTAG + +#endif // _SPTAG_COMMON_CHECKSUM_H_ \ No newline at end of file diff --git a/AnnService/inc/Core/Common/CommonUtils.h b/AnnService/inc/Core/Common/CommonUtils.h index 755994ced..905de2cd7 100644 --- a/AnnService/inc/Core/Common/CommonUtils.h +++ b/AnnService/inc/Core/Common/CommonUtils.h @@ -11,8 +11,8 @@ #include #include +#include #include -#include #include #include #include diff --git 
a/AnnService/inc/Core/Common/Dataset.h b/AnnService/inc/Core/Common/Dataset.h index 1b2cdbd9c..a5abaf743 100644 --- a/AnnService/inc/Core/Common/Dataset.h +++ b/AnnService/inc/Core/Common/Dataset.h @@ -4,11 +4,14 @@ #ifndef _SPTAG_COMMON_DATASET_H_ #define _SPTAG_COMMON_DATASET_H_ +#include "inc/Core/VectorIndex.h" + namespace SPTAG { namespace COMMON { // structure to save Data and Graph + /* template class Dataset { @@ -39,6 +42,12 @@ namespace SPTAG } void Initialize(SizeType rows_, DimensionType cols_, SizeType rowsInBlock_, SizeType capacity_, T* data_ = nullptr, bool shareOwnership_ = true) { + if (data != nullptr) { + if (ownData) ALIGN_FREE(data); + for (T* ptr : incBlocks) ALIGN_FREE(ptr); + incBlocks.clear(); + } + rows = rows_; cols = cols_; data = data_; @@ -73,11 +82,20 @@ namespace SPTAG inline const T* At(SizeType index) const { - if (index >= rows) { - SizeType incIndex = index - rows; - return incBlocks[incIndex >> rowsInBlockEx] + ((size_t)(incIndex & rowsInBlock)) * cols; + if (index < R() && index >= 0) + { + if (index >= rows) { + SizeType incIndex = index - rows; + return incBlocks[incIndex >> rowsInBlockEx] + ((size_t)(incIndex & rowsInBlock)) * cols; + } + return data + ((size_t)index) * cols; + } + else + { + std::ostringstream oss; + oss << "Index out of range in Dataset. 
Index: " << index << " Size: " << R(); + throw std::out_of_range(oss.str()); } - return data + ((size_t)index) * cols; } T* operator[](SizeType index) @@ -90,7 +108,7 @@ namespace SPTAG return At(index); } - ErrorCode AddBatch(const T* pData, SizeType num) + ErrorCode AddBatch(SizeType num, const T* pData = nullptr) { if (R() > maxRows - num) return ErrorCode::MemoryOverFlow; @@ -98,38 +116,20 @@ namespace SPTAG while (written < num) { SizeType curBlockIdx = ((incRows + written) >> rowsInBlockEx); if (curBlockIdx >= (SizeType)incBlocks.size()) { - T* newBlock = (T*)ALIGN_ALLOC(((size_t)rowsInBlock + 1) * cols * sizeof(T)); + T* newBlock = (T*)ALIGN_ALLOC(sizeof(T) * (rowsInBlock + 1) * cols); if (newBlock == nullptr) return ErrorCode::MemoryOverFlow; + std::memset(newBlock, -1, sizeof(T) * (rowsInBlock + 1) * cols); incBlocks.push_back(newBlock); } SizeType curBlockPos = ((incRows + written) & rowsInBlock); SizeType toWrite = min(rowsInBlock + 1 - curBlockPos, num - written); - std::memcpy(incBlocks[curBlockIdx] + ((size_t)curBlockPos) * cols, pData + ((size_t)written) * cols, ((size_t)toWrite) * cols * sizeof(T)); + if (pData != nullptr) std::memcpy(incBlocks[curBlockIdx] + ((size_t)curBlockPos) * cols, pData + ((size_t)written) * cols, ((size_t)toWrite) * cols * sizeof(T)); written += toWrite; } incRows += written; return ErrorCode::Success; } - ErrorCode AddBatch(SizeType num) - { - if (R() > maxRows - num) return ErrorCode::MemoryOverFlow; - - SizeType written = 0; - while (written < num) { - SizeType curBlockIdx = (incRows + written) >> rowsInBlockEx; - if (curBlockIdx >= (SizeType)incBlocks.size()) { - T* newBlock = (T*)ALIGN_ALLOC(sizeof(T) * (rowsInBlock + 1) * cols); - if (newBlock == nullptr) return ErrorCode::MemoryOverFlow; - std::memset(newBlock, -1, sizeof(T) * (rowsInBlock + 1) * cols); - incBlocks.push_back(newBlock); - } - written += min(rowsInBlock + 1 - ((incRows + written) & rowsInBlock), num - written); - } - incRows += written; - return 
ErrorCode::Success; - } - ErrorCode Save(std::shared_ptr p_out) const { SizeType CR = R(); @@ -143,13 +143,13 @@ namespace SPTAG SizeType remain = (incRows & rowsInBlock); if (remain > 0) IOBINARY(p_out, WriteBinary, sizeof(T) * cols * remain, (char*)incBlocks[blocks]); - LOG(Helper::LogLevel::LL_Info, "Save %s (%d,%d) Finish!\n", name.c_str(), CR, cols); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s (%d,%d) Finish!\n", name.c_str(), CR, cols); return ErrorCode::Success; } ErrorCode Save(std::string sDataPointsFileName) const { - LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", name.c_str(), sDataPointsFileName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", name.c_str(), sDataPointsFileName.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sDataPointsFileName.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; return Save(ptr); @@ -162,13 +162,13 @@ namespace SPTAG Initialize(rows, cols, blockSize, capacity); IOBINARY(pInput, ReadBinary, sizeof(T) * cols * rows, (char*)data); - LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", name.c_str(), rows, cols); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", name.c_str(), rows, cols); return ErrorCode::Success; } ErrorCode Load(std::string sDataPointsFileName, SizeType blockSize, SizeType capacity) { - LOG(Helper::LogLevel::LL_Info, "Load %s From %s\n", name.c_str(), sDataPointsFileName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s From %s\n", name.c_str(), sDataPointsFileName.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sDataPointsFileName.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; return Load(ptr, blockSize, capacity); @@ -186,7 +186,7 @@ namespace SPTAG pDataPointsMemFile += sizeof(DimensionType); Initialize(R, C, blockSize, capacity, (T*)pDataPointsMemFile); - LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", 
name.c_str(), R, C); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", name.c_str(), R, C); return ErrorCode::Success; } @@ -209,19 +209,293 @@ namespace SPTAG for (SizeType i = 0; i < R; i++) { IOBINARY(output, WriteBinary, sizeof(T) * cols, (char*)At(indices[i])); } - LOG(Helper::LogLevel::LL_Info, "Save Refine %s (%d,%d) Finish!\n", name.c_str(), R, cols); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save Refine %s (%d,%d) Finish!\n", name.c_str(), R, cols); return ErrorCode::Success; } ErrorCode Refine(const std::vector& indices, std::string sDataPointsFileName) const { - LOG(Helper::LogLevel::LL_Info, "Save Refine %s To %s\n", name.c_str(), sDataPointsFileName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save Refine %s To %s\n", name.c_str(), sDataPointsFileName.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sDataPointsFileName.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; return Refine(indices, ptr); } }; + */ + template + class Dataset + { + private: + std::string name = "Data"; + SizeType rows = 0; + DimensionType cols = 1; + char* data = nullptr; + bool ownData = false; + SizeType incRows = 0; + SizeType maxRows = MaxSize; + SizeType rowsInBlock = 1024 * 1024; + SizeType rowsInBlockEx = 20; + std::shared_ptr> incBlocks; + + DimensionType colStart = 0; + DimensionType mycols = 0; + + public: + Dataset() {} + + Dataset(SizeType rows_, DimensionType cols_, SizeType rowsInBlock_, SizeType capacity_, const void* data_ = nullptr, bool shareOwnership_ = true, std::shared_ptr> incBlocks_ = nullptr, int colStart_ = 0, int rowEnd_ = -1) + { + Initialize(rows_, cols_, rowsInBlock_, capacity_, data_, shareOwnership_, incBlocks_, colStart_, rowEnd_); + } + ~Dataset() + { + if (ownData) ALIGN_FREE(data); + if (incBlocks != nullptr) { + for (char* ptr : *incBlocks) ALIGN_FREE(ptr); + incBlocks->clear(); + } + } + + void Initialize(SizeType rows_, DimensionType cols_, SizeType 
rowsInBlock_, SizeType capacity_, const void* data_ = nullptr, bool shareOwnership_ = true, std::shared_ptr> incBlocks_ = nullptr, int colStart_ = 0, int rowEnd_ = -1) + { + if (data != nullptr) { + if (ownData) ALIGN_FREE(data); + for (char* ptr : *incBlocks) ALIGN_FREE(ptr); + incBlocks->clear(); + } + + rows = rows_; + if (rowEnd_ >= colStart_) cols = rowEnd_; + else cols = cols_ * sizeof(T); + data = (char*)data_; + if (data_ == nullptr || !shareOwnership_) + { + ownData = true; + data = (char*)ALIGN_ALLOC(((size_t)rows) * cols); + if (data_ != nullptr) memcpy(data, data_, ((size_t)rows) * cols); + else std::memset(data, -1, ((size_t)rows) * cols); + } + maxRows = capacity_; + rowsInBlockEx = static_cast(ceil(log2(rowsInBlock_))); + rowsInBlock = (1 << rowsInBlockEx) - 1; + incBlocks = incBlocks_; + if (incBlocks == nullptr) incBlocks.reset(new std::vector()); + incBlocks->reserve((static_cast(capacity_) + rowsInBlock) >> rowsInBlockEx); + + colStart = colStart_; + mycols = cols_; + } + + bool IsReady() const { return data != nullptr; } + + void SetName(const std::string& name_) { name = name_; } + const std::string& Name() const { return name; } + + void SetR(SizeType R_) + { + if (R_ >= rows) + incRows = R_ - rows; + else + { + rows = R_; + incRows = 0; + } + } + + inline SizeType R() const { return rows + incRows; } + inline const DimensionType& C() const { return mycols; } + inline std::uint64_t BufferSize() const { return sizeof(SizeType) + sizeof(DimensionType) + sizeof(T) * R() * C(); } + +#define GETITEM(index) \ + if (index >= rows) { \ + SizeType incIndex = index - rows; \ + return (T*)((*incBlocks)[incIndex >> rowsInBlockEx] + ((size_t)(incIndex & rowsInBlock)) * cols + colStart); \ + } \ + return (T*)(data + ((size_t)index) * cols + colStart); \ + + inline const T* At(SizeType index) const + { + GETITEM(index) + } + + inline T* At(SizeType index) + { + GETITEM(index) + } + + inline T* operator[](SizeType index) + { + GETITEM(index) + } + + inline 
const T* operator[](SizeType index) const + { + GETITEM(index) + } + +#undef GETITEM + + ErrorCode AddBatch(SizeType num, const T* pData = nullptr) + { + if (colStart != 0) return ErrorCode::Success; + if (R() > maxRows - num) return ErrorCode::MemoryOverFlow; + + SizeType written = 0; + while (written < num) { + SizeType curBlockIdx = ((incRows + written) >> rowsInBlockEx); + if (curBlockIdx >= (SizeType)(incBlocks->size())) { + char* newBlock = (char*)ALIGN_ALLOC(((size_t)rowsInBlock + 1) * cols); + if (newBlock == nullptr) return ErrorCode::MemoryOverFlow; + std::memset(newBlock, -1, ((size_t)rowsInBlock + 1) * cols); + incBlocks->push_back(newBlock); + } + SizeType curBlockPos = ((incRows + written) & rowsInBlock); + SizeType toWrite = min(rowsInBlock + 1 - curBlockPos, num - written); + if (pData) { + for (int i = 0; i < toWrite; i++) { + std::memcpy((*incBlocks)[curBlockIdx] + ((size_t)curBlockPos + i) * cols + colStart, pData + ((size_t)written + i) * mycols, mycols * sizeof(T)); + } + } + written += toWrite; + } + incRows += written; + return ErrorCode::Success; + } + + ErrorCode Save(std::shared_ptr p_out) const + { + SizeType CR = R(); + IOBINARY(p_out, WriteBinary, sizeof(SizeType), (char*)&CR); + IOBINARY(p_out, WriteBinary, sizeof(DimensionType), (char*)&mycols); + for (SizeType i = 0; i < CR; i++) { + IOBINARY(p_out, WriteBinary, sizeof(T) * mycols, (char*)At(i)); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s (%d,%d) Finish!\n", name.c_str(), CR, mycols); + return ErrorCode::Success; + } + + ErrorCode Save(std::string sDataPointsFileName) const + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", name.c_str(), sDataPointsFileName.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(sDataPointsFileName.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; + return Save(ptr); + } + + ErrorCode Load(std::shared_ptr pInput, SizeType blockSize, SizeType capacity) + { + SizeType r; + 
DimensionType c; + IOBINARY(pInput, ReadBinary, sizeof(SizeType), (char*)&(r)); + IOBINARY(pInput, ReadBinary, sizeof(DimensionType), (char*)(&c)); + + if (data == nullptr) Initialize(r, c, blockSize, capacity); + else if (r > rows + incRows) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Read more data (%d, %d) than before (%d, %d)!\n", r, c, rows + incRows, mycols); + if(AddBatch(r - rows - incRows) != ErrorCode::Success) return ErrorCode::MemoryOverFlow; + } + + for (SizeType i = 0; i < r; i++) { + IOBINARY(pInput, ReadBinary, sizeof(T) * mycols, (char*)At(i)); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", name.c_str(), r, c); + return ErrorCode::Success; + } + + ErrorCode Load(std::string sDataPointsFileName, SizeType blockSize, SizeType capacity) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s From %s\n", name.c_str(), sDataPointsFileName.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(sDataPointsFileName.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; + return Load(ptr, blockSize, capacity); + } + + // Functions for loading models from memory mapped files + ErrorCode Load(char* pDataPointsMemFile, SizeType blockSize, SizeType capacity) + { + SizeType R; + DimensionType C; + R = *((SizeType*)pDataPointsMemFile); + pDataPointsMemFile += sizeof(SizeType); + + C = *((DimensionType*)pDataPointsMemFile); + pDataPointsMemFile += sizeof(DimensionType); + + Initialize(R, C, blockSize, capacity, (char*)pDataPointsMemFile); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", name.c_str(), R, C); + return ErrorCode::Success; + } + + ErrorCode Refine(const std::vector& indices, COMMON::Dataset& dataset) const + { + SizeType newrows = (SizeType)(indices.size()); + if (dataset.data == nullptr) dataset.Initialize(newrows, mycols, rowsInBlock + 1, static_cast(incBlocks->capacity() * (rowsInBlock + 1))); + + for (SizeType i = 0; i < newrows; i++) { + 
std::memcpy((void*)dataset.At(i), (void*)At(indices[i]), sizeof(T) * mycols); + } + return ErrorCode::Success; + } + + virtual ErrorCode Refine(const std::vector& indices, std::shared_ptr output) const + { + SizeType newrows = (SizeType)(indices.size()); + IOBINARY(output, WriteBinary, sizeof(SizeType), (char*)&newrows); + IOBINARY(output, WriteBinary, sizeof(DimensionType), (char*)&mycols); + + for (SizeType i = 0; i < newrows; i++) { + IOBINARY(output, WriteBinary, sizeof(T) * mycols, (char*)At(indices[i])); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save Refine %s (%d,%d) Finish!\n", name.c_str(), newrows, C()); + return ErrorCode::Success; + } + + virtual ErrorCode Refine(const std::vector& indices, std::string sDataPointsFileName) const + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save Refine %s To %s\n", name.c_str(), sDataPointsFileName.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(sDataPointsFileName.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; + return Refine(indices, ptr); + } + }; + + template + ErrorCode LoadOptDatasets(std::shared_ptr pVectorsInput, std::shared_ptr pGraphInput, + Dataset& pVectors, Dataset& pGraph, DimensionType pNeighborhoodSize, + SizeType blockSize, SizeType capacity) { + SizeType VR, GR; + DimensionType VC, GC; + IOBINARY(pVectorsInput, ReadBinary, sizeof(SizeType), (char*)&VR); + IOBINARY(pVectorsInput, ReadBinary, sizeof(DimensionType), (char*)&VC); + DimensionType totalC = ALIGN_ROUND(sizeof(T) * VC + sizeof(SizeType) * pNeighborhoodSize); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "OPT TotalC: %d\n", totalC); + char* data = (char*)ALIGN_ALLOC(((size_t)totalC) * VR); + std::shared_ptr> incBlocks(new std::vector()); + + pVectors.Initialize(VR, VC, blockSize, capacity, data, true, incBlocks, 0, totalC); + pVectors.SetName("Opt" + pVectors.Name()); + for (SizeType i = 0; i < VR; i++) { + IOBINARY(pVectorsInput, ReadBinary, sizeof(T) * VC, 
(char*)(pVectors.At(i))); + } + + IOBINARY(pGraphInput, ReadBinary, sizeof(SizeType), (char*)&GR); + IOBINARY(pGraphInput, ReadBinary, sizeof(DimensionType), (char*)&GC); + if (GR != VR || GC != pNeighborhoodSize) return ErrorCode::DiskIOFail; + + pGraph.Initialize(GR, GC, blockSize, capacity, data, false, incBlocks, sizeof(T) * VC, totalC); + pGraph.SetName("Opt" + pGraph.Name()); + for (SizeType i = 0; i < VR; i++) { + IOBINARY(pGraphInput, ReadBinary, sizeof(SizeType) * GC, (char*)(pGraph.At(i))); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", pVectors.Name().c_str(), pVectors.R(), pVectors.C()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s (%d,%d) Finish!\n", pGraph.Name().c_str(), pGraph.R(), pGraph.C()); + + return ErrorCode::Success; + } } } -#endif // _SPTAG_COMMON_DATASET_H_ \ No newline at end of file +#endif // _SPTAG_COMMON_DATASET_H_ diff --git a/AnnService/inc/Core/Common/DistanceUtils.h b/AnnService/inc/Core/Common/DistanceUtils.h index e46c0c84f..cfb02155d 100644 --- a/AnnService/inc/Core/Common/DistanceUtils.h +++ b/AnnService/inc/Core/Common/DistanceUtils.h @@ -4,7 +4,6 @@ #ifndef _SPTAG_COMMON_DISTANCEUTILS_H_ #define _SPTAG_COMMON_DISTANCEUTILS_H_ -#include #include #include @@ -103,17 +102,21 @@ namespace SPTAG return func(p1, p2, length); } + template static inline float ConvertCosineSimilarityToDistance(float cs) { // Cosine similarity is in [-1, 1], the higher the value, the closer are the two vectors. // However, the tree is built and searched based on "distance" between two vectors, that's >=0. The smaller the value, the closer are the two vectors. // So we do a linear conversion from a cosine similarity to a distance value. 
- return 1 - cs; //[1, 3] + int base = Utils::GetBase(); + return (1 - cs) * (base * base); } + template static inline float ConvertDistanceBackToCosineSimilarity(float d) { - return 1 - d; + int base = Utils::GetBase(); + return 1 - d / (base * base); } }; template diff --git a/AnnService/inc/Core/Common/FineGrainedLock.h b/AnnService/inc/Core/Common/FineGrainedLock.h index 98659dab2..6491cf0f3 100644 --- a/AnnService/inc/Core/Common/FineGrainedLock.h +++ b/AnnService/inc/Core/Common/FineGrainedLock.h @@ -29,14 +29,41 @@ namespace SPTAG unsigned index = hash_func((unsigned)idx); return m_locks[index]; } + + static inline unsigned hash_func(unsigned idx) + { + return ((unsigned)(idx * 99991) + _rotl(idx, 2) + 101) & PoolSize; + } + private: static const int PoolSize = 32767; std::unique_ptr m_locks; + }; - inline unsigned hash_func(unsigned idx) const + class FineGrainedRWLock { + public: + FineGrainedRWLock() { + m_locks.reset(new std::shared_timed_mutex[PoolSize + 1]); + } + ~FineGrainedRWLock() {} + + std::shared_timed_mutex& operator[](SizeType idx) { + unsigned index = hash_func((unsigned)idx); + return m_locks[index]; + } + + const std::shared_timed_mutex& operator[](SizeType idx) const { + unsigned index = hash_func((unsigned)idx); + return m_locks[index]; + } + + static inline unsigned hash_func(unsigned idx) { return ((unsigned)(idx * 99991) + _rotl(idx, 2) + 101) & PoolSize; } + private: + static const int PoolSize = 32767; + std::unique_ptr m_locks; }; } } diff --git a/AnnService/inc/Core/Common/Heap.h b/AnnService/inc/Core/Common/Heap.h index a5d544b1d..826ef376e 100644 --- a/AnnService/inc/Core/Common/Heap.h +++ b/AnnService/inc/Core/Common/Heap.h @@ -13,7 +13,7 @@ namespace SPTAG template class Heap { public: - Heap() : heap(nullptr), length(0), count(0) {} + Heap() : heap(nullptr), length(0), count(0),lastlevel(0) {} Heap(int size) { Resize(size); } @@ -24,10 +24,15 @@ namespace SPTAG count = 0; lastlevel = int(pow(2.0, floor(log2((float)size)))); } 
- ~Heap() {} + ~Heap() { heap.reset(); } inline int size() { return count; } inline bool empty() { return count == 0; } - inline void clear() { count = 0; } + inline void clear(int size) + { + if (size > length) Resize(size); + count = 0; + } + inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; } // Insert a new element in the heap. diff --git a/AnnService/inc/Core/Common/IQuantizer.h b/AnnService/inc/Core/Common/IQuantizer.h index cbf3a91be..26d23889a 100644 --- a/AnnService/inc/Core/Common/IQuantizer.h +++ b/AnnService/inc/Core/Common/IQuantizer.h @@ -5,6 +5,7 @@ #define _SPTAG_COMMON_QUANTIZER_H_ #include "inc/Core/Common.h" +#include "inc/Helper/DiskIO.h" #include #include "inc/Core/CommonDataStructure.h" #include "DistanceUtils.h" @@ -23,7 +24,7 @@ namespace SPTAG template std::function DistanceCalcSelector(SPTAG::DistCalcMethod p_method) const; - virtual void QuantizeVector(const void* vec, std::uint8_t* vecout) const = 0; + virtual void QuantizeVector(const void* vec, std::uint8_t* vecout, bool ADC = true) const = 0; virtual SizeType QuantizeSize() const = 0; @@ -58,6 +59,9 @@ namespace SPTAG virtual int GetBase() const = 0; virtual float* GetL2DistanceTables() = 0; + + template + T* GetCodebooks(); }; } } diff --git a/AnnService/inc/Core/Common/InstructionUtils.h b/AnnService/inc/Core/Common/InstructionUtils.h index e311061c6..4cd3703c1 100644 --- a/AnnService/inc/Core/Common/InstructionUtils.h +++ b/AnnService/inc/Core/Common/InstructionUtils.h @@ -6,8 +6,13 @@ #include #include +#ifndef GPU + #ifndef _MSC_VER #include +#include +#include + void cpuid(int info[4], int InfoType); #else @@ -15,6 +20,8 @@ void cpuid(int info[4], int InfoType); #define cpuid(info, x) __cpuidex(info, x, 0) #endif +#endif + namespace SPTAG { namespace COMMON { diff --git a/AnnService/inc/Core/Common/KDTree.h b/AnnService/inc/Core/Common/KDTree.h index 9a0fdef6f..90049f3c0 100644 --- a/AnnService/inc/Core/Common/KDTree.h +++ 
b/AnnService/inc/Core/Common/KDTree.h @@ -30,11 +30,11 @@ namespace SPTAG class KDTree { public: - KDTree() : m_iTreeNumber(2), m_numTopDimensionKDTSplit(5), m_iSamples(1000), m_lock(new std::shared_timed_mutex), m_pQuantizer(nullptr) {} + KDTree() : m_iTreeNumber(2), m_numTopDimensionKDTSplit(5), m_iSamples(1000), m_lock(new std::shared_timed_mutex), m_bOldVersion(false), m_pQuantizer(nullptr) {} KDTree(const KDTree& other) : m_iTreeNumber(other.m_iTreeNumber), m_numTopDimensionKDTSplit(other.m_numTopDimensionKDTSplit), - m_iSamples(other.m_iSamples), m_lock(new std::shared_timed_mutex), m_pQuantizer(other.m_pQuantizer) {} + m_iSamples(other.m_iSamples), m_lock(new std::shared_timed_mutex), m_bOldVersion(other.m_bOldVersion), m_pQuantizer(other.m_pQuantizer) {} ~KDTree() {} inline const KDTNode& operator[](SizeType index) const { return m_pTreeRoots[index]; } @@ -96,22 +96,48 @@ break; m_pTreeRoots.resize(m_iTreeNumber * localindices.size()); m_pTreeStart.resize(m_iTreeNumber, 0); -#pragma omp parallel for num_threads(numOfThreads) - for (int i = 0; i < m_iTreeNumber; i++) + std::vector mythreads; + mythreads.reserve(numOfThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < numOfThreads; tid++) { - if (abort && abort->ShouldAbort()) continue; - - Sleep(i * 100); std::srand(clock()); - - std::vector pindices(localindices.begin(), localindices.end()); - std::shuffle(pindices.begin(), pindices.end(), rg); - - m_pTreeStart[i] = i * (SizeType)pindices.size(); - LOG(Helper::LogLevel::LL_Info, "Start to build KDTree %d\n", i + 1); - SizeType iTreeSize = m_pTreeStart[i]; - DivideTree(data, pindices, 0, (SizeType)pindices.size() - 1, m_pTreeStart[i], iTreeSize, abort); - LOG(Helper::LogLevel::LL_Info, "%d KDTree built, %d %zu\n", i + 1, iTreeSize - m_pTreeStart[i], pindices.size()); + mythreads.emplace_back([&, tid]() { + size_t i = 0; + std::mt19937 rg; + while (true) + { + i = sent.fetch_add(1); + if (i < m_iTreeNumber) + { + if (abort && 
abort->ShouldAbort()) + continue; + + Sleep(i * 100); + std::srand(clock()); + + std::vector pindices(localindices.begin(), localindices.end()); + std::shuffle(pindices.begin(), pindices.end(), rg); + + m_pTreeStart[i] = i * (SizeType)pindices.size(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to build KDTree %d\n", i + 1); + SizeType iTreeSize = m_pTreeStart[i]; + DivideTree(data, pindices, 0, (SizeType)pindices.size() - 1, m_pTreeStart[i], + iTreeSize, abort); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d KDTree built, %d %zu\n", i + 1, + iTreeSize - m_pTreeStart[i], pindices.size()); + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); } + mythreads.clear(); } inline std::uint64_t BufferSize() const @@ -128,13 +154,13 @@ break; SizeType treeNodeSize = (SizeType)m_pTreeRoots.size(); IOBINARY(p_out, WriteBinary, sizeof(treeNodeSize), (char*)&treeNodeSize); IOBINARY(p_out, WriteBinary, sizeof(KDTNode) * treeNodeSize, (char*)m_pTreeRoots.data()); - LOG(Helper::LogLevel::LL_Info, "Save KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); return ErrorCode::Success; } ErrorCode SaveTrees(std::string sTreeFileName) const { - LOG(Helper::LogLevel::LL_Info, "Save KDT to %s\n", sTreeFileName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save KDT to %s\n", sTreeFileName.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sTreeFileName.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; return SaveTrees(ptr); @@ -152,7 +178,7 @@ break; pKDTMemFile += sizeof(SizeType); m_pTreeRoots.resize(treeNodeSize); memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize); - LOG(Helper::LogLevel::LL_Info, "Load KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); return 
ErrorCode::Success; } @@ -186,7 +212,7 @@ break; } treeNodeSize += iNodeSize; } - LOG(Helper::LogLevel::LL_Info, "Load KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); return ErrorCode::Success; } IOBINARY(p_input, ReadBinary, sizeof(m_iTreeNumber), (char*)&m_iTreeNumber); @@ -198,65 +224,40 @@ break; m_pTreeRoots.resize(treeNodeSize); IOBINARY(p_input, ReadBinary, sizeof(KDTNode) * treeNodeSize, (char*)m_pTreeRoots.data()); - LOG(Helper::LogLevel::LL_Info, "Load KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load KDT (%d,%d) Finish!\n", m_iTreeNumber, treeNodeSize); return ErrorCode::Success; } ErrorCode LoadTrees(std::string sTreeFileName) { - LOG(Helper::LogLevel::LL_Info, "Load KDT From %s\n", sTreeFileName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load KDT From %s\n", sTreeFileName.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sTreeFileName.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; return LoadTrees(ptr); } - template + template void InitSearchTrees(const Dataset& p_data, std::function fComputeDistance, COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const { for (int i = 0; i < m_iTreeNumber; i++) { - KDTSearch(p_data, fComputeDistance, p_query, p_space, m_pTreeStart[i], 0); + KDTSearch(p_data, fComputeDistance, p_query, p_space, m_pTreeStart[i], 0); } } - template + template void SearchTrees(const Dataset& p_data, std::function fComputeDistance, COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const int p_limits) const { while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits) { auto& tcell = p_space.m_SPTQueue.pop(); - KDTSearch(p_data, fComputeDistance, p_query, p_space, tcell.node, tcell.distance); + KDTSearch(p_data, fComputeDistance, p_query, p_space, tcell.node, 
tcell.distance); } } private: - template - void KDTSearch(const Dataset& p_data, std::function fComputeDistance, COMMON::QueryResultSet& p_query, - COMMON::WorkSpace& p_space, const SizeType node, const float distBound) const - { - if (m_pQuantizer) - { - switch (m_pQuantizer->GetReconstructType()) - { -#define DefineVectorValueType(Name, Type) \ -case VectorValueType::Name: \ -return KDTSearchCore(p_data, fComputeDistance, p_query, p_space, node, distBound); - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } - } - else - { - return KDTSearchCore(p_data, fComputeDistance, p_query, p_space, node, distBound); - } - - } - template - void KDTSearchCore(const Dataset& p_data, std::function fComputeDistance, COMMON::QueryResultSet &p_query, + void KDTSearch(const Dataset& p_data, std::function fComputeDistance, COMMON::QueryResultSet &p_query, COMMON::WorkSpace& p_space, const SizeType node, const float distBound) const { if (node < 0) { @@ -292,7 +293,7 @@ return KDTSearchCore(p_data, fComputeDistance, p_query, p_space, node, } p_space.m_SPTQueue.insert(NodeDistPair(otherChild, distanceBound)); - KDTSearchCore(p_data, fComputeDistance, p_query, p_space, bestChild, distBound); + KDTSearch(p_data, fComputeDistance, p_query, p_space, bestChild, distBound); } diff --git a/AnnService/inc/Core/Common/Labelset.h b/AnnService/inc/Core/Common/Labelset.h index d310d0bf2..c34f563de 100644 --- a/AnnService/inc/Core/Common/Labelset.h +++ b/AnnService/inc/Core/Common/Labelset.h @@ -13,19 +13,28 @@ namespace SPTAG { class Labelset { + public: + enum class InvalidIDBehavior + { + Passthrough, + AlwaysContains, + AlwaysNotContains + }; private: std::atomic m_inserted; Dataset m_data; + InvalidIDBehavior m_invalidIDBehaviorSetting; public: - Labelset() + Labelset() { m_inserted = 0; m_data.SetName("DeleteID"); } - void Initialize(SizeType size, SizeType blockSize, SizeType capacity) + void Initialize(SizeType size, SizeType blockSize, SizeType 
capacity, InvalidIDBehavior invalidIDBehaviorSetting = InvalidIDBehavior::Passthrough) { + m_invalidIDBehaviorSetting = invalidIDBehaviorSetting; m_data.Initialize(size, 1, blockSize, capacity); } @@ -33,17 +42,64 @@ namespace SPTAG inline bool Contains(const SizeType& key) const { + if (key >= R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LabelSet::Contains: key %lld out of range %lld\n", (long long)key, (long long)R()); + switch (m_invalidIDBehaviorSetting) + { + case InvalidIDBehavior::AlwaysContains: + return true; + case InvalidIDBehavior::AlwaysNotContains: + return false; + default: {} + } + } + return *m_data[key] == 1; } inline bool Insert(const SizeType& key) { + if (key >= R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LabelSet::Insert: key %lld out of range %lld\n", (long long)key, (long long)R()); + switch (m_invalidIDBehaviorSetting) + { + case InvalidIDBehavior::AlwaysContains: + return true; + case InvalidIDBehavior::AlwaysNotContains: + return false; + default: {} + } + } + char oldvalue = InterlockedExchange8((char*)m_data[key], 1); if (oldvalue == 1) return false; m_inserted++; return true; } + inline bool Reset(const SizeType& key) + { + if (key >= R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LabelSet::Insert: key %lld out of range %lld\n", + (long long)key, (long long)R()); + switch (m_invalidIDBehaviorSetting) + { + case InvalidIDBehavior::AlwaysContains: + return true; + case InvalidIDBehavior::AlwaysNotContains: + return false; + default: {} + } + } + char oldvalue = InterlockedExchange8((char *)m_data[key], -1); + if (oldvalue == -1) return false; + m_inserted--; + return true; + } + inline ErrorCode Save(std::shared_ptr output) { SizeType deleted = m_inserted.load(); @@ -53,30 +109,32 @@ namespace SPTAG inline ErrorCode Save(std::string filename) { - LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", m_data.Name().c_str(), filename.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", 
m_data.Name().c_str(), filename.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; return Save(ptr); } - inline ErrorCode Load(std::shared_ptr input, SizeType blockSize, SizeType capacity) + inline ErrorCode Load(std::shared_ptr input, SizeType blockSize, SizeType capacity, InvalidIDBehavior invalidIDBehaviorSetting = InvalidIDBehavior::Passthrough) { + m_invalidIDBehaviorSetting = invalidIDBehaviorSetting; SizeType deleted; IOBINARY(input, ReadBinary, sizeof(SizeType), (char*)&deleted); m_inserted = deleted; return m_data.Load(input, blockSize, capacity); } - inline ErrorCode Load(std::string filename, SizeType blockSize, SizeType capacity) + inline ErrorCode Load(std::string filename, SizeType blockSize, SizeType capacity, InvalidIDBehavior invalidIDBehaviorSetting = InvalidIDBehavior::Passthrough) { - LOG(Helper::LogLevel::LL_Info, "Load %s From %s\n", m_data.Name().c_str(), filename.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s From %s\n", m_data.Name().c_str(), filename.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; - return Load(ptr, blockSize, capacity); + return Load(ptr, blockSize, capacity, invalidIDBehaviorSetting); } - inline ErrorCode Load(char* pmemoryFile, SizeType blockSize, SizeType capacity) + inline ErrorCode Load(char* pmemoryFile, SizeType blockSize, SizeType capacity, InvalidIDBehavior invalidIDBehaviorSetting = InvalidIDBehavior::Passthrough) { + m_invalidIDBehaviorSetting = invalidIDBehaviorSetting; m_inserted = *((SizeType*)pmemoryFile); return m_data.Load(pmemoryFile + sizeof(SizeType), blockSize, capacity); } @@ -104,4 +162,4 @@ namespace SPTAG } } -#endif // _SPTAG_COMMON_LABELSET_H_ +#endif // _SPTAG_COMMON_LABELSET_H_ \ No newline at end of file diff --git 
a/AnnService/inc/Core/Common/NeighborhoodGraph.h b/AnnService/inc/Core/Common/NeighborhoodGraph.h index 991045072..92ff358b2 100644 --- a/AnnService/inc/Core/Common/NeighborhoodGraph.h +++ b/AnnService/inc/Core/Common/NeighborhoodGraph.h @@ -34,7 +34,8 @@ namespace SPTAG class NeighborhoodGraph { public: - NeighborhoodGraph() : m_iTPTNumber(32), + NeighborhoodGraph() : m_iGraphSize(0), + m_iTPTNumber(32), m_iTPTLeafSize(2000), m_iSamples(1000), m_numTopDimensionTPTSplit(5), @@ -52,7 +53,7 @@ namespace SPTAG m_iGPULeafSize(500), m_iheadNumGPUs(1), m_iTPTBalanceFactor(2), - m_rebuild(0) + m_rebuild(0), m_iThreadNum(1) {} ~NeighborhoodGraph() {} @@ -65,36 +66,62 @@ namespace SPTAG { DimensionType* correct = new DimensionType[samples]; -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < samples; i++) + std::vector mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) { - SizeType x = COMMON::Utils::rand(m_iGraphSize); - //int x = i; - COMMON::QueryResultSet query(nullptr, m_iCEF); - for (SizeType y = 0; y < m_iGraphSize; y++) - { - if ((idmap != nullptr && idmap->find(y) != idmap->end())) continue; - float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y)); - query.AddPoint(y, dist); - } - query.SortResult(); - SizeType* exact_rng = new SizeType[m_iNeighborhoodSize]; - RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF); - - correct[i] = 0; - for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) { - if (exact_rng[j] == -1) { - correct[i] += m_iNeighborhoodSize - j; - break; - } - for (DimensionType k = 0; k < m_iNeighborhoodSize; k++) - if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) { - correct[i]++; - break; + mythreads.emplace_back([&, tid](){ + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < samples) + { + SizeType x = COMMON::Utils::rand(m_iGraphSize); + // int x = i; + COMMON::QueryResultSet query(nullptr, m_iCEF); + 
for (SizeType y = 0; y < m_iGraphSize; y++) + { + if ((idmap != nullptr && idmap->find(y) != idmap->end())) + continue; + float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y)); + query.AddPoint(y, dist); + } + query.SortResult(); + SizeType *exact_rng = new SizeType[m_iNeighborhoodSize]; + RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF); + + correct[i] = 0; + for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) + { + if (exact_rng[j] == -1) + { + correct[i] += m_iNeighborhoodSize - j; + break; + } + for (DimensionType k = 0; k < m_iNeighborhoodSize; k++) + if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) + { + correct[i]++; + break; + } + } + delete[] exact_rng; } - } - delete[] exact_rng; + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); } + mythreads.clear(); + float acc = 0; for (SizeType i = 0; i < samples; i++) acc += float(correct[i]); acc = acc / samples / m_iNeighborhoodSize; @@ -109,13 +136,8 @@ namespace SPTAG SizeType initSize; SPTAG::Helper::Convert::ConvertStringTo(index->GetParameter("NumberOfInitialDynamicPivots").c_str(), initSize); - if (index->m_pQuantizer) { - buildGraph(index, m_iGraphSize, m_iNeighborhoodSize, m_iTPTNumber, (int*)m_pNeighborhoodGraph[0], m_iGPURefineSteps, m_iGPURefineDepth, m_iGPUGraphType, m_iGPULeafSize, initSize, m_iheadNumGPUs, m_iTPTBalanceFactor); - } - else { - // Build the entire RNG graph, both builds the KNN and refines it to RNG - buildGraph(index, m_iGraphSize, m_iNeighborhoodSize, m_iTPTNumber, (int*)m_pNeighborhoodGraph[0], m_iGPURefineSteps, m_iGPURefineDepth, m_iGPUGraphType, m_iGPULeafSize, initSize, m_iheadNumGPUs, m_iTPTBalanceFactor); - } + // Build the entire RNG graph, both builds the KNN and refines it to RNG + buildGraph(index, m_iGraphSize, m_iNeighborhoodSize, m_iTPTNumber, (int*)m_pNeighborhoodGraph[0], m_iGPURefineSteps, m_iGPURefineDepth, m_iGPUGraphType, m_iGPULeafSize, initSize, m_iheadNumGPUs, m_iTPTBalanceFactor); if 
(idmap != nullptr) { std::unordered_map::const_iterator iter; @@ -149,7 +171,6 @@ break; } else { - printf("No quantizer!\n"); PartitionByTptreeCore(index, indices, first, last, leaves); } } @@ -317,51 +338,99 @@ break; (NeighborhoodDists)[i][j] = MaxDist; auto t1 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Parallel TpTree Partition begin\n"); -#pragma omp parallel for schedule(dynamic) - for (int i = 0; i < m_iTPTNumber; i++) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Parallel TpTree Partition begin\n"); { - Sleep(i * 100); std::srand(clock()); - for (SizeType j = 0; j < m_iGraphSize; j++) TptreeDataIndices[i][j] = j; - std::shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end(), rg); - PartitionByTptree(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, TptreeLeafNodes[i]); - LOG(Helper::LogLevel::LL_Info, "Finish Getting Leaves for Tree %d\n", i); + std::vector mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + std::mt19937 rg; + while (true) + { + i = sent.fetch_add(1); + if (i < m_iTPTNumber) + { + Sleep(i * 100); + std::srand(clock()); + for (SizeType j = 0; j < m_iGraphSize; j++) + TptreeDataIndices[i][j] = j; + std::shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end(), rg); + PartitionByTptree(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, + TptreeLeafNodes[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finish Getting Leaves for Tree %d\n", i); + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); } - LOG(Helper::LogLevel::LL_Info, "Parallel TpTree Partition done\n"); - auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Build TPTree time (s): %lld\n", std::chrono::duration_cast(t2 - t1).count()); - for(int i=0; i<10; i++) { - for(int j=0; j<20; j++) { - std::cout << 
static_cast(((uint8_t*)index->GetSample(i))[j]) << ", "; - } - std::cout << std::endl; - } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Parallel TpTree Partition done\n"); + auto t2 = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build TPTree time (s): %lld\n", std::chrono::duration_cast(t2 - t1).count()); for (int i = 0; i < m_iTPTNumber; i++) { -#pragma omp parallel for schedule(dynamic) - for (SizeType j = 0; j < (SizeType)TptreeLeafNodes[i].size(); j++) + std::vector mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) { - SizeType start_index = TptreeLeafNodes[i][j].first; - SizeType end_index = TptreeLeafNodes[i][j].second; - if ((j * 5) % TptreeLeafNodes[i].size() == 0) LOG(Helper::LogLevel::LL_Info, "Processing Tree %d %d%%\n", i, static_cast(j * 1.0 / TptreeLeafNodes[i].size() * 100)); - for (SizeType x = start_index; x < end_index; x++) - { - for (SizeType y = x + 1; y <= end_index; y++) + mythreads.emplace_back([&, tid]() { + size_t j = 0; + while (true) { - SizeType p1 = TptreeDataIndices[i][x]; - SizeType p2 = TptreeDataIndices[i][y]; - float dist = index->ComputeDistance(index->GetSample(p1), index->GetSample(p2)); - if (idmap != nullptr) { - p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1); - p2 = (idmap->find(p2) == idmap->end()) ? 
p2 : idmap->at(p2); + j = sent.fetch_add(1); + if (j < TptreeLeafNodes[i].size()) + { + SizeType start_index = TptreeLeafNodes[i][j].first; + SizeType end_index = TptreeLeafNodes[i][j].second; + if ((j * 5) % TptreeLeafNodes[i].size() == 0) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Processing Tree %d %d%%\n", i, + static_cast(j * 1.0 / TptreeLeafNodes[i].size() * 100)); + for (SizeType x = start_index; x < end_index; x++) + { + for (SizeType y = x + 1; y <= end_index; y++) + { + SizeType p1 = TptreeDataIndices[i][x]; + SizeType p2 = TptreeDataIndices[i][y]; + float dist = + index->ComputeDistance(index->GetSample(p1), index->GetSample(p2)); + if (idmap != nullptr) + { + p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1); + p2 = (idmap->find(p2) == idmap->end()) ? p2 : idmap->at(p2); + } + COMMON::Utils::AddNeighbor(p2, dist, (m_pNeighborhoodGraph)[p1], + (NeighborhoodDists)[p1], m_iNeighborhoodSize); + COMMON::Utils::AddNeighbor(p1, dist, (m_pNeighborhoodGraph)[p2], + (NeighborhoodDists)[p2], m_iNeighborhoodSize); + } + } + } + else + { + return; } - COMMON::Utils::AddNeighbor(p2, dist, (m_pNeighborhoodGraph)[p1], (NeighborhoodDists)[p1], m_iNeighborhoodSize); - COMMON::Utils::AddNeighbor(p1, dist, (m_pNeighborhoodGraph)[p2], (NeighborhoodDists)[p2], m_iNeighborhoodSize); } - } + }); } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); TptreeDataIndices[i].clear(); TptreeLeafNodes[i].clear(); } @@ -369,14 +438,14 @@ break; TptreeLeafNodes.clear(); auto t3 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Process TPTree time (s): %lld\n", std::chrono::duration_cast(t3 - t2).count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Process TPTree time (s): %lld\n", std::chrono::duration_cast(t3 - t2).count()); } #endif template void BuildGraph(VectorIndex* index, const std::unordered_map* idmap = nullptr) { - LOG(Helper::LogLevel::LL_Info, "build RNG graph!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "build RNG 
graph!\n"); m_iGraphSize = index->GetNumSamples(); m_iNeighborhoodSize = (DimensionType)(ceil(m_iNeighborhoodSize * m_fNeighborhoodScale) * (m_rebuild + 1)); @@ -384,25 +453,25 @@ break; if (m_iGraphSize < 1000) { RefineGraph(index, idmap); - LOG(Helper::LogLevel::LL_Info, "Build RNG Graph end!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build RNG Graph end!\n"); return; } auto t1 = std::chrono::high_resolution_clock::now(); BuildInitKNNGraph(index, idmap); auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "BuildInitKNNGraph time (s): %lld\n", std::chrono::duration_cast(t2 - t1).count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "BuildInitKNNGraph time (s): %lld\n", std::chrono::duration_cast(t2 - t1).count()); RefineGraph(index, idmap); auto t3 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "BuildGraph time (s): %lld\n", std::chrono::duration_cast(t3 - t1).count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "BuildGraph time (s): %lld\n", std::chrono::duration_cast(t3 - t1).count()); if (m_rebuild) { m_iNeighborhoodSize = m_iNeighborhoodSize / 2; RebuildGraph(index, idmap); auto t4 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "ReBuildGraph time (s): %lld\n", std::chrono::duration_cast(t4 - t3).count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ReBuildGraph time (s): %lld\n", std::chrono::duration_cast(t4 - t3).count()); } if (idmap != nullptr) { @@ -417,57 +486,111 @@ break; template void RebuildGraph(VectorIndex* index, const std::unordered_map* idmap = nullptr) { - std::vector indegree(m_iGraphSize); - -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < m_iGraphSize; i++) indegree[i] = 0; + std::vector indegree(m_iGraphSize, 0); auto t0 = std::chrono::high_resolution_clock::now(); -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < m_iGraphSize; i++) { - SizeType* outnodes = m_pNeighborhoodGraph[i]; - for (SizeType j = 
0; j < m_iNeighborhoodSize; j++) + std::vector mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) { - int node = outnodes[j]; - if (node >= 0) { - indegree[node]++; - } + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < m_iGraphSize) + { + SizeType *outnodes = m_pNeighborhoodGraph[i]; + for (SizeType j = 0; j < m_iNeighborhoodSize; j++) + { + int node = outnodes[j]; + if (node >= 0) + { + indegree[node]++; + } + } + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); } + mythreads.clear(); } auto t1 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Calculate Indegree time (s): %lld\n", std::chrono::duration_cast(t1 - t0).count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Calculate Indegree time (s): %lld\n", std::chrono::duration_cast(t1 - t0).count()); int rebuild_threshold = m_iNeighborhoodSize / 2; int rebuildstart = m_iNeighborhoodSize / 2; -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < m_iGraphSize; i++) + + std::vector mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) { - SizeType* outnodes = m_pNeighborhoodGraph[i]; - std::vector reserve(2 * m_iNeighborhoodSize, false); - int total = 0; - for (SizeType j = rebuildstart; j < m_iNeighborhoodSize * 2; j++) - if ( outnodes[j] >= 0 && indegree[outnodes[j]] < rebuild_threshold) { - reserve[j] = true; - total++; - } + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < m_iGraphSize) + { + SizeType *outnodes = m_pNeighborhoodGraph[i]; + std::vector reserve(2 * m_iNeighborhoodSize, false); + int total = 0; + for (SizeType j = rebuildstart; j < m_iNeighborhoodSize * 2; j++) + if (outnodes[j] >= 0 && indegree[outnodes[j]] < rebuild_threshold) + { + reserve[j] = true; + total++; + } 
+ + for (SizeType j = rebuildstart; + j < m_iNeighborhoodSize * 2 && total < m_iNeighborhoodSize - rebuildstart; j++) + { + if (!reserve[j]) + { + reserve[j] = true; + total++; + } + } + for (SizeType j = rebuildstart, z = rebuildstart; j < m_iNeighborhoodSize; j++) + { + while (!reserve[z]) + z++; + if (outnodes[j] >= 0) + indegree[outnodes[j]] = indegree[outnodes[j]] - 1; + if (outnodes[z] >= 0) + indegree[outnodes[z]] = indegree[outnodes[z]] + 1; + outnodes[j] = outnodes[z]; + z++; + } + if ((i * 5) % m_iGraphSize == 0) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Rebuild %d%%\n", + static_cast(i * 1.0 / m_iGraphSize * 100)); - for (SizeType j = rebuildstart; j < m_iNeighborhoodSize * 2 && total < m_iNeighborhoodSize - rebuildstart; j++) { - if (!reserve[j]) { - reserve[j] = true; - total++; + } + else + { + return; + } } - } - for (SizeType j = rebuildstart, z = rebuildstart; j < m_iNeighborhoodSize; j++) { - while (!reserve[z]) z++; - if(outnodes[j] >= 0) indegree[outnodes[j]] = indegree[outnodes[j]] - 1; - if(outnodes[z] >= 0) indegree[outnodes[z]] = indegree[outnodes[z]] + 1; - outnodes[j] = outnodes[z]; - z++; - } - if ((i * 5) % m_iGraphSize == 0) LOG(Helper::LogLevel::LL_Info, "Rebuild %d%%\n", static_cast(i * 1.0 / m_iGraphSize * 100)); + }); + } + for (auto &t : mythreads) + { + t.join(); } + mythreads.clear(); auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Rebuild RNG time (s): %lld Graph Acc: %f\n", std::chrono::duration_cast(t2 - t1).count(), GraphAccuracyEstimation(index, 100, idmap)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Rebuild RNG time (s): %lld Graph Acc: %f\n", std::chrono::duration_cast(t2 - t1).count(), GraphAccuracyEstimation(index, 100, idmap)); } template @@ -476,31 +599,79 @@ break; for (int iter = 0; iter < m_iRefineIter - 1; iter++) { auto t1 = std::chrono::high_resolution_clock::now(); -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < m_iGraphSize; i++) + std::vector 
mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < m_iGraphSize) + { + RefineNode(index, i, false, false, (int)(m_iCEF * m_fCEFScale)); + if ((i * 5) % m_iGraphSize == 0) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine %d %d%%\n", iter, + static_cast(i * 1.0 / m_iGraphSize * 100)); + + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) { - RefineNode(index, i, false, false, (int)(m_iCEF * m_fCEFScale)); - if ((i * 5) % m_iGraphSize == 0) LOG(Helper::LogLevel::LL_Info, "Refine %d %d%%\n", iter, static_cast(i * 1.0 / m_iGraphSize * 100)); + t.join(); } + mythreads.clear(); auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Refine RNG time (s): %lld Graph Acc: %f\n", std::chrono::duration_cast(t2 - t1).count(), GraphAccuracyEstimation(index, 100, idmap)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine RNG time (s): %lld Graph Acc: %f\n", std::chrono::duration_cast(t2 - t1).count(), GraphAccuracyEstimation(index, 100, idmap)); } m_iNeighborhoodSize = (DimensionType)(m_iNeighborhoodSize / m_fNeighborhoodScale); if (m_iRefineIter > 0) { auto t1 = std::chrono::high_resolution_clock::now(); -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < m_iGraphSize; i++) + std::vector mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) { - RefineNode(index, i, false, false, m_iCEF); - if ((i * 5) % m_iGraphSize == 0) LOG(Helper::LogLevel::LL_Info, "Refine %d %d%%\n", m_iRefineIter - 1, static_cast(i * 1.0 / m_iGraphSize * 100)); + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < m_iGraphSize) + { + RefineNode(index, i, false, false, m_iCEF); + if ((i * 5) % m_iGraphSize == 0) + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine %d %d%%\n", m_iRefineIter - 1, + static_cast(i * 1.0 / m_iGraphSize * 100)); + + } + else + { + return; + } + } + }); } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Refine RNG time (s): %lld Graph Acc: %f\n", std::chrono::duration_cast(t2 - t1).count(), GraphAccuracyEstimation(index, 100, idmap)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine RNG time (s): %lld Graph Acc: %f\n", std::chrono::duration_cast(t2 - t1).count(), GraphAccuracyEstimation(index, 100, idmap)); } else { - LOG(Helper::LogLevel::LL_Info, "Graph Acc: %f\n", GraphAccuracyEstimation(index, 100, idmap)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Graph Acc: %f\n", GraphAccuracyEstimation(index, 100, idmap)); } } @@ -519,27 +690,52 @@ break; newGraph->m_iGraphSize = R; newGraph->m_iNeighborhoodSize = m_iNeighborhoodSize; -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < R; i++) + std::vector mythreads; + mythreads.reserve(m_iThreadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iThreadNum; tid++) { - if ((i * 5) % R == 0) LOG(Helper::LogLevel::LL_Info, "Refine %d%%\n", static_cast(i * 1.0 / R * 100)); + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < R) + { + if ((i * 5) % R == 0) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine %d%%\n", + static_cast(i * 1.0 / R * 100)); - SizeType* outnodes = newGraph->m_pNeighborhoodGraph[i]; + SizeType *outnodes = newGraph->m_pNeighborhoodGraph[i]; - COMMON::QueryResultSet query((const T*)index->GetSample(indices[i]), m_iCEF + 1); - index->RefineSearchIndex(query, false); - RebuildNeighbors(index, indices[i], outnodes, query.GetResults(), m_iCEF + 1); + COMMON::QueryResultSet query((const T *)index->GetSample(indices[i]), m_iCEF + 1); + index->RefineSearchIndex(query, false); + RebuildNeighbors(index, 
indices[i], outnodes, query.GetResults(), m_iCEF + 1); - std::unordered_map::const_iterator iter; - for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) - { - if (outnodes[j] >= 0 && outnodes[j] < reverseIndices.size()) outnodes[j] = reverseIndices[outnodes[j]]; - if (idmap != nullptr && (iter = idmap->find(outnodes[j])) != idmap->end()) outnodes[j] = iter->second; - } - if (idmap != nullptr && (iter = idmap->find(-1 - i)) != idmap->end()) - outnodes[m_iNeighborhoodSize - 1] = -2 - iter->second; - } + std::unordered_map::const_iterator iter; + for (DimensionType j = 0; j < m_iNeighborhoodSize; j++) + { + if (outnodes[j] >= 0 && outnodes[j] < reverseIndices.size()) + outnodes[j] = reverseIndices[outnodes[j]]; + if (idmap != nullptr && (iter = idmap->find(outnodes[j])) != idmap->end()) + outnodes[j] = iter->second; + } + if (idmap != nullptr && (iter = idmap->find(-1 - i)) != idmap->end()) + outnodes[m_iNeighborhoodSize - 1] = -2 - iter->second; + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); if (output != nullptr) newGraph->SaveGraph(output); return ErrorCode::Success; } @@ -610,7 +806,7 @@ break; ErrorCode SaveGraph(std::string sGraphFilename) const { - LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", m_pNeighborhoodGraph.Name().c_str(), sGraphFilename.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", m_pNeighborhoodGraph.Name().c_str(), sGraphFilename.c_str()); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(sGraphFilename.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; return SaveGraph(ptr); @@ -623,7 +819,7 @@ break; for (int i = 0; i < m_iGraphSize; i++) IOBINARY(output, WriteBinary, sizeof(SizeType) * m_iNeighborhoodSize, (char*)m_pNeighborhoodGraph[i]); - LOG(Helper::LogLevel::LL_Info, "Save %s (%d,%d) Finish!\n", m_pNeighborhoodGraph.Name().c_str(), m_iGraphSize, m_iNeighborhoodSize); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s (%d,%d) Finish!\n", m_pNeighborhoodGraph.Name().c_str(), m_iGraphSize, m_iNeighborhoodSize); return ErrorCode::Success; } @@ -665,7 +861,8 @@ break; int m_iTPTNumber, m_iTPTLeafSize, m_iSamples, m_numTopDimensionTPTSplit; DimensionType m_iNeighborhoodSize; float m_fNeighborhoodScale, m_fCEFScale, m_fRNGFactor; - int m_iRefineIter, m_iCEF, m_iAddCEF, m_iMaxCheckForRefineGraph, m_iGPUGraphType, m_iGPURefineSteps, m_iGPURefineDepth, m_iGPULeafSize, m_iheadNumGPUs, m_iTPTBalanceFactor, m_rebuild; + int m_iRefineIter, m_iCEF, m_iAddCEF, m_iMaxCheckForRefineGraph, m_iGPUGraphType, m_iGPURefineSteps, + m_iGPURefineDepth, m_iGPULeafSize, m_iheadNumGPUs, m_iTPTBalanceFactor, m_rebuild, m_iThreadNum; }; } } diff --git a/AnnService/inc/Core/Common/OPQQuantizer.h b/AnnService/inc/Core/Common/OPQQuantizer.h index dae464c5a..533fb3ea7 100644 --- a/AnnService/inc/Core/Common/OPQQuantizer.h +++ b/AnnService/inc/Core/Common/OPQQuantizer.h @@ -26,7 +26,7 @@ namespace SPTAG OPQQuantizer(DimensionType NumSubvectors, SizeType KsPerSubvector, DimensionType DimPerSubvector, bool EnableADC, std::unique_ptr&& Codebooks, std::unique_ptr&& OPQMatrix); - virtual void QuantizeVector(const void* vec, std::uint8_t* vecout) const; + virtual void QuantizeVector(const void* vec, std::uint8_t* vecout, bool ADC = true) const; void ReconstructVector(const std::uint8_t* qvec, void* vecout) const; @@ -94,7 +94,7 @@ namespace SPTAG } template - void OPQQuantizer::QuantizeVector(const void* vec, std::uint8_t* vecout) const + void OPQQuantizer::QuantizeVector(const void* vec, std::uint8_t* vecout, bool ADC) const { OPQMatrixType* mat_vec = (OPQMatrixType*) ALIGN_ALLOC(sizeof(OPQMatrixType) * m_matrixDim); OPQMatrixType* typed_vec; @@ -111,7 +111,7 @@ namespace SPTAG } m_VectorMatrixMultiply(m_OPQMatrix_T.get(), typed_vec, mat_vec); - PQQuantizer::QuantizeVector(mat_vec, vecout); + PQQuantizer::QuantizeVector(mat_vec, vecout, ADC); ALIGN_FREE(mat_vec); 
ISNOTSAME(T, OPQMatrixType) @@ -142,7 +142,7 @@ namespace SPTAG IOBINARY(p_out, WriteBinary, sizeof(DimensionType), (char*)&m_DimPerSubvector); IOBINARY(p_out, WriteBinary, sizeof(OPQMatrixType) * m_NumSubvectors * m_KsPerSubvector * m_DimPerSubvector, (char*)m_codebooks.get()); IOBINARY(p_out, WriteBinary, sizeof(OPQMatrixType) * m_matrixDim * m_matrixDim, (char*)m_OPQMatrix.get()); - LOG(Helper::LogLevel::LL_Info, "Saving quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); return ErrorCode::Success; } @@ -157,7 +157,7 @@ namespace SPTAG m_matrixDim = m_NumSubvectors * m_DimPerSubvector; m_OPQMatrix = std::make_unique(m_matrixDim * m_matrixDim); IOBINARY(p_in, ReadBinary, sizeof(OPQMatrixType) * m_matrixDim * m_matrixDim, (char*)m_OPQMatrix.get()); - LOG(Helper::LogLevel::LL_Info, "After read OPQ Matrix.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read OPQ Matrix.\n"); m_InitMatrixTranspose(); return ErrorCode::Success; @@ -207,4 +207,4 @@ namespace SPTAG } } -#endif _SPTAG_COMMON_OPQQUANTIZER_H_ +#endif // _SPTAG_COMMON_OPQQUANTIZER_H_ diff --git a/AnnService/inc/Core/Common/PQQuantizer.h b/AnnService/inc/Core/Common/PQQuantizer.h index 4dc0aaf62..5bece0225 100644 --- a/AnnService/inc/Core/Common/PQQuantizer.h +++ b/AnnService/inc/Core/Common/PQQuantizer.h @@ -33,7 +33,7 @@ namespace SPTAG virtual float CosineDistance(const std::uint8_t* pX, const std::uint8_t* pY) const; - virtual void QuantizeVector(const void* vec, std::uint8_t* vecout) const; + virtual void QuantizeVector(const void* vec, std::uint8_t* vecout, bool ADC = true) const; virtual SizeType QuantizeSize() const; @@ -76,6 +76,8 @@ namespace SPTAG float* GetL2DistanceTables(); + T* GetCodebooks(); + protected: DimensionType m_NumSubvectors; SizeType 
m_KsPerSubvector; @@ -129,14 +131,14 @@ namespace SPTAG float PQQuantizer::CosineDistance(const std::uint8_t* pX, const std::uint8_t* pY) const // pX must be query distance table for ADC { - LOG(Helper::LogLevel::LL_Error, "Quantizer does not support CosineDistance!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Quantizer does not support CosineDistance!\n"); return 0; } template - void PQQuantizer::QuantizeVector(const void* vec, std::uint8_t* vecout) const + void PQQuantizer::QuantizeVector(const void* vec, std::uint8_t* vecout, bool ADC) const { - if (GetEnableADC()) + if (ADC && GetEnableADC()) { auto distCalc = DistanceCalcSelector(DistCalcMethod::L2); float* ADCtable = (float*) vecout; @@ -171,7 +173,7 @@ namespace SPTAG subcodebooks += m_DimPerSubvector; } assert(bestIndex != -1); - vecout[i] = bestIndex; + vecout[i] = (std::uint8_t)bestIndex; subvec += m_DimPerSubvector; } } @@ -232,52 +234,52 @@ namespace SPTAG IOBINARY(p_out, WriteBinary, sizeof(SizeType), (char*)&m_KsPerSubvector); IOBINARY(p_out, WriteBinary, sizeof(DimensionType), (char*)&m_DimPerSubvector); IOBINARY(p_out, WriteBinary, sizeof(T) * m_NumSubvectors * m_KsPerSubvector * m_DimPerSubvector, (char*)m_codebooks.get()); - LOG(Helper::LogLevel::LL_Info, "Saving quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); return ErrorCode::Success; } template ErrorCode PQQuantizer::LoadQuantizer(std::shared_ptr p_in) { - LOG(Helper::LogLevel::LL_Info, "Loading Quantizer.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading Quantizer.\n"); IOBINARY(p_in, ReadBinary, sizeof(DimensionType), (char*)&m_NumSubvectors); - LOG(Helper::LogLevel::LL_Info, "After read subvecs: %s.\n", std::to_string(m_NumSubvectors).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After 
read subvecs: %s.\n", std::to_string(m_NumSubvectors).c_str()); IOBINARY(p_in, ReadBinary, sizeof(SizeType), (char*)&m_KsPerSubvector); - LOG(Helper::LogLevel::LL_Info, "After read ks: %s.\n", std::to_string(m_KsPerSubvector).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read ks: %s.\n", std::to_string(m_KsPerSubvector).c_str()); IOBINARY(p_in, ReadBinary, sizeof(DimensionType), (char*)&m_DimPerSubvector); - LOG(Helper::LogLevel::LL_Info, "After read dim: %s.\n", std::to_string(m_DimPerSubvector).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read dim: %s.\n", std::to_string(m_DimPerSubvector).c_str()); m_codebooks = std::make_unique(m_NumSubvectors * m_KsPerSubvector * m_DimPerSubvector); - LOG(Helper::LogLevel::LL_Info, "sizeof(T): %s.\n", std::to_string(sizeof(T)).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "sizeof(T): %s.\n", std::to_string(sizeof(T)).c_str()); IOBINARY(p_in, ReadBinary, sizeof(T) * m_NumSubvectors * m_KsPerSubvector * m_DimPerSubvector, (char*)m_codebooks.get()); - LOG(Helper::LogLevel::LL_Info, "After read codebooks.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read codebooks.\n"); m_BlockSize = m_KsPerSubvector * m_KsPerSubvector; InitializeDistanceTables(); - LOG(Helper::LogLevel::LL_Info, "Loaded quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loaded quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); return ErrorCode::Success; } template ErrorCode PQQuantizer::LoadQuantizer(std::uint8_t* raw_bytes) { - LOG(Helper::LogLevel::LL_Info, "Loading Quantizer.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading Quantizer.\n"); m_NumSubvectors = *(DimensionType*)raw_bytes; raw_bytes += sizeof(DimensionType); - LOG(Helper::LogLevel::LL_Info, "After read subvecs: %s.\n", std::to_string(m_NumSubvectors).c_str()); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read subvecs: %s.\n", std::to_string(m_NumSubvectors).c_str()); m_KsPerSubvector = *(SizeType*)raw_bytes; raw_bytes += sizeof(SizeType); - LOG(Helper::LogLevel::LL_Info, "After read ks: %s.\n", std::to_string(m_KsPerSubvector).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read ks: %s.\n", std::to_string(m_KsPerSubvector).c_str()); m_DimPerSubvector = *(DimensionType*)raw_bytes; raw_bytes += sizeof(DimensionType); - LOG(Helper::LogLevel::LL_Info, "After read dim: %s.\n", std::to_string(m_DimPerSubvector).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read dim: %s.\n", std::to_string(m_DimPerSubvector).c_str()); m_codebooks = std::make_unique(m_NumSubvectors * m_KsPerSubvector * m_DimPerSubvector); - LOG(Helper::LogLevel::LL_Info, "sizeof(T): %s.\n", std::to_string(sizeof(T)).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "sizeof(T): %s.\n", std::to_string(sizeof(T)).c_str()); std::memcpy(m_codebooks.get(), raw_bytes, sizeof(T) * m_NumSubvectors * m_KsPerSubvector * m_DimPerSubvector); - LOG(Helper::LogLevel::LL_Info, "After read codebooks.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After read codebooks.\n"); m_BlockSize = m_KsPerSubvector * m_KsPerSubvector; InitializeDistanceTables(); - LOG(Helper::LogLevel::LL_Info, "Loaded quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loaded quantizer: Subvectors:%d KsPerSubvector:%d DimPerSubvector:%d\n", m_NumSubvectors, m_KsPerSubvector, m_DimPerSubvector); return ErrorCode::Success; } @@ -349,6 +351,11 @@ namespace SPTAG float* PQQuantizer::GetL2DistanceTables() { return (float*)(m_L2DistanceTables.get()); } + + template + T* PQQuantizer::GetCodebooks() { + return (T*)(m_codebooks.get()); + } } } diff --git a/AnnService/inc/Core/Common/PostingSizeRecord.h b/AnnService/inc/Core/Common/PostingSizeRecord.h new file mode 100644 index 
000000000..8577a64f6 --- /dev/null +++ b/AnnService/inc/Core/Common/PostingSizeRecord.h @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_COMMON_POSTINGSIZERECORD_H_ +#define _SPTAG_COMMON_POSTINGSIZERECORD_H_ + +#include +#include "Dataset.h" + +namespace SPTAG +{ + namespace COMMON + { + class PostingSizeRecord + { + private: + Dataset m_data; + + public: + PostingSizeRecord() + { + m_data.SetName("PostingSizeRecord"); + } + + void Initialize(SizeType size, SizeType blockSize, SizeType capacity) + { + m_data.Initialize(size, 1, blockSize, capacity); + } + + inline int GetSize(const SizeType& headID) + { + return *m_data[headID]; + } + + inline bool UpdateSize(const SizeType& headID, int newSize) + { + while (true) { + int oldSize = GetSize(headID); + if (InterlockedCompareExchange((unsigned*)m_data[headID], (unsigned)newSize, (unsigned)oldSize) == oldSize) { + return true; + } + } + } + + inline bool IncSize(const SizeType& headID, int appendNum) + { + while (true) { + int oldSize = GetSize(headID); + int newSize = oldSize + appendNum; + if (InterlockedCompareExchange((unsigned*)m_data[headID], (unsigned)newSize, (unsigned)oldSize) == oldSize) { + return true; + } + } + } + + inline SizeType GetPostingNum() + { + return m_data.R(); + } + + inline ErrorCode Save(std::shared_ptr output) + { + return m_data.Save(output); + } + + inline ErrorCode Save(const std::string& filename) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", m_data.Name().c_str(), filename.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; + return Save(ptr); + } + + inline ErrorCode Load(std::shared_ptr input, SizeType blockSize, SizeType capacity) + { + return m_data.Load(input, blockSize, capacity); + } + + inline ErrorCode Load(const std::string& filename, SizeType blockSize, 
SizeType capacity) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s From %s\n", m_data.Name().c_str(), filename.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; + return Load(ptr, blockSize, capacity); + } + + inline ErrorCode Load(char* pmemoryFile, SizeType blockSize, SizeType capacity) + { + return m_data.Load(pmemoryFile + sizeof(SizeType), blockSize, capacity); + } + + inline ErrorCode AddBatch(SizeType num) + { + return m_data.AddBatch(num); + } + + inline std::uint64_t BufferSize() const + { + return m_data.BufferSize() + sizeof(SizeType); + } + + inline void SetR(SizeType num) + { + m_data.SetR(num); + } + }; + } +} + +#endif // _SPTAG_COMMON_LABELSET_H_ diff --git a/AnnService/inc/Core/Common/QueryResultSet.h b/AnnService/inc/Core/Common/QueryResultSet.h index 3193b97ad..35fc9b554 100644 --- a/AnnService/inc/Core/Common/QueryResultSet.h +++ b/AnnService/inc/Core/Common/QueryResultSet.h @@ -43,42 +43,19 @@ class QueryResultSet : public QueryResult { } - inline void SetTarget(const T *p_target) - { - m_target = p_target; - if (m_quantizedTarget) - { - ALIGN_FREE(m_quantizedTarget); - } - m_quantizedTarget = nullptr; - m_quantizedSize = 0; - } - inline void SetTarget(const T* p_target, const std::shared_ptr& quantizer) { - m_target = p_target; - if (quantizer) - { - if (m_quantizedTarget && (m_quantizedSize == quantizer->QuantizeSize())) - { - quantizer->QuantizeVector((void*)p_target, (uint8_t*)m_quantizedTarget); - } - else - { - if (m_quantizedTarget) ALIGN_FREE(m_quantizedTarget); - m_quantizedSize = quantizer->QuantizeSize(); - m_quantizedTarget = ALIGN_ALLOC(m_quantizedSize); - quantizer->QuantizeVector((void*)p_target, (uint8_t*)m_quantizedTarget); - } - } + if (quantizer == nullptr) QueryResult::SetTarget((const void*)p_target); else { - if (m_quantizedTarget) + if (m_target == m_quantizedTarget || (m_quantizedSize != 
quantizer->QuantizeSize())) { - ALIGN_FREE(m_quantizedTarget); + if (m_target != m_quantizedTarget) ALIGN_FREE(m_quantizedTarget); + m_quantizedTarget = ALIGN_ALLOC(quantizer->QuantizeSize()); + m_quantizedSize = quantizer->QuantizeSize(); } - m_quantizedTarget = nullptr; - m_quantizedSize = 0; + m_target = p_target; + quantizer->QuantizeVector((void*)p_target, (uint8_t*)m_quantizedTarget); } } @@ -89,19 +66,7 @@ class QueryResultSet : public QueryResult T* GetQuantizedTarget() { - if (m_quantizedTarget) - { - return reinterpret_cast(m_quantizedTarget); - } - else - { - return (T*)reinterpret_cast(m_target); - } - } - - bool HasQuantizedTarget() - { - return m_quantizedTarget; + return reinterpret_cast(m_quantizedTarget); } inline float worstDist() const diff --git a/AnnService/inc/Core/Common/RelativeNeighborhoodGraph.h b/AnnService/inc/Core/Common/RelativeNeighborhoodGraph.h index 99adae3c0..28e9245f8 100644 --- a/AnnService/inc/Core/Common/RelativeNeighborhoodGraph.h +++ b/AnnService/inc/Core/Common/RelativeNeighborhoodGraph.h @@ -4,7 +4,6 @@ #ifndef _SPTAG_COMMON_RNG_H_ #define _SPTAG_COMMON_RNG_H_ -#include #include "NeighborhoodGraph.h" namespace SPTAG @@ -47,7 +46,9 @@ namespace SPTAG _mm_prefetch((const char*)(nodeVec), _MM_HINT_T0); _mm_prefetch((const char*)(insertVec), _MM_HINT_T0); for (DimensionType i = 0; i < m_iNeighborhoodSize; i++) { - _mm_prefetch((const char*)(index->GetSample(nodes[i])), _MM_HINT_T0); + auto futureNode = nodes[i]; + if (futureNode < 0) break; + _mm_prefetch((const char*)(index->GetSample(futureNode)), _MM_HINT_T0); } SizeType tmpNode; diff --git a/AnnService/inc/Core/Common/SIMDUtils.h b/AnnService/inc/Core/Common/SIMDUtils.h index 73e5e2334..9baf24412 100644 --- a/AnnService/inc/Core/Common/SIMDUtils.h +++ b/AnnService/inc/Core/Common/SIMDUtils.h @@ -4,7 +4,6 @@ #ifndef _SPTAG_COMMON_SIMDUTILS_H_ #define _SPTAG_COMMON_SIMDUTILS_H_ -#include #include #include diff --git a/AnnService/inc/Core/Common/TruthSet.h 
b/AnnService/inc/Core/Common/TruthSet.h index 08fb613f5..946b61845 100644 --- a/AnnService/inc/Core/Common/TruthSet.h +++ b/AnnService/inc/Core/Common/TruthSet.h @@ -23,14 +23,14 @@ namespace SPTAG { truth[i].clear(); if (ptr->ReadString(lineBufferSize, currentLine, '\n') == 0) { - LOG(Helper::LogLevel::LL_Error, "Truth number(%d) and query number(%d) are not match!\n", i, p_iTruthNumber); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number(%d) and query number(%d) are not match!\n", i, p_iTruthNumber); exit(1); } char* tmp = strtok(currentLine.get(), " "); for (int j = 0; j < K; ++j) { if (tmp == nullptr) { - LOG(Helper::LogLevel::LL_Error, "Truth number(%d, %d) and query number(%d) are not match!\n", i, j, p_iTruthNumber); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number(%d, %d) and query number(%d) are not match!\n", i, j, p_iTruthNumber); exit(1); } int vid = std::atoi(tmp); @@ -47,12 +47,12 @@ namespace SPTAG std::vector vec(K); for (int i = 0; i < p_iTruthNumber; i++) { if (ptr->ReadBinary(4, (char*)&originalK) != 4 || originalK < K) { - LOG(Helper::LogLevel::LL_Error, "Error: Xvec file has No.%d vector whose dims are fewer than expected. Expected: %d, Fact: %d\n", i, K, originalK); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Xvec file has No.%d vector whose dims are fewer than expected. 
Expected: %d, Fact: %d\n", i, K, originalK); exit(1); } if (originalK > K) vec.resize(originalK); if (ptr->ReadBinary(originalK * 4, (char*)vec.data()) != originalK * 4) { - LOG(Helper::LogLevel::LL_Error, "Truth number(%d) and query number(%d) are not match!\n", i, p_iTruthNumber); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number(%d) and query number(%d) are not match!\n", i, p_iTruthNumber); exit(1); } truth[i].insert(vec.begin(), vec.begin() + K); @@ -63,7 +63,7 @@ namespace SPTAG if (ptr->TellP() == 0) { int row; if (ptr->ReadBinary(4, (char*)&row) != 4 || ptr->ReadBinary(4, (char*)&originalK) != 4) { - LOG(Helper::LogLevel::LL_Error, "Fail to read truth file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read truth file!\n"); exit(1); } } @@ -73,7 +73,7 @@ namespace SPTAG for (int i = 0; i < p_iTruthNumber; i++) { if (ptr->ReadBinary(4 * originalK, (char*)vec.data()) != 4 * originalK) { - LOG(Helper::LogLevel::LL_Error, "Truth number(%d) and query number(%d) are not match!\n", i, p_iTruthNumber); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number(%d) and query number(%d) are not match!\n", i, p_iTruthNumber); exit(1); } truth[i].insert(vec.begin(), vec.begin() + K); @@ -95,7 +95,7 @@ namespace SPTAG } else { - LOG(Helper::LogLevel::LL_Error, "TruthFileType Unsupported.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "TruthFileType Unsupported.\n"); exit(1); } } @@ -103,7 +103,7 @@ namespace SPTAG static void writeTruthFile(const std::string truthFile, SizeType queryNumber, const int K, std::vector>& truthset, std::vector>& distset, SPTAG::TruthFileType TFT) { auto ptr = SPTAG::f_createIO(); if (ptr == nullptr || !ptr->Initialize(truthFile.c_str(), std::ios::out | std::ios::binary)) { - LOG(Helper::LogLevel::LL_Error, "Fail to create the file:%s\n", truthFile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to create the file:%s\n", truthFile.c_str()); exit(1); } @@ -114,12 +114,12 @@ namespace SPTAG for (int k = 0; k < K; 
k++) { if (ptr->WriteString((std::to_string(truthset[i][k]) + " ").c_str()) == 0) { - LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); exit(1); } } if (ptr->WriteString("\n") == 0) { - LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); exit(1); } } @@ -129,7 +129,7 @@ namespace SPTAG for (SizeType i = 0; i < queryNumber; i++) { if (ptr->WriteBinary(sizeof(K), (char*)&K) != sizeof(K) || ptr->WriteBinary(K * 4, (char*)(truthset[i].data())) != K * 4) { - LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); exit(1); } } @@ -141,83 +141,27 @@ namespace SPTAG for (SizeType i = 0; i < queryNumber; i++) { if (ptr->WriteBinary(K * 4, (char*)(truthset[i].data())) != K * 4) { - LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); exit(1); } } for (SizeType i = 0; i < queryNumber; i++) { if (ptr->WriteBinary(K * 4, (char*)(distset[i].data())) != K * 4) { - LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); exit(1); } } } else { - LOG(Helper::LogLevel::LL_Error, "Found unsupported file type for generating truth."); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Found unsupported file type for generating truth."); exit(-1); } } template static void GenerateTruth(std::shared_ptr querySet, std::shared_ptr vectorSet, const std::string truthFile, - const SPTAG::DistCalcMethod distMethod, const int K, const SPTAG::TruthFileType p_truthFileType, const std::shared_ptr& quantizer) { - if (querySet->Dimension() != vectorSet->Dimension() && !quantizer) - { - LOG(Helper::LogLevel::LL_Error, "query and vector have different 
dimensions."); - exit(1); - } - - LOG(Helper::LogLevel::LL_Info, "Begin to generate truth for query(%d,%d) and doc(%d,%d)...\n", querySet->Count(), querySet->Dimension(), vectorSet->Count(), vectorSet->Dimension()); - std::vector< std::vector > truthset(querySet->Count(), std::vector(K, 0)); - std::vector< std::vector > distset(querySet->Count(), std::vector(K, 0)); - auto fComputeDistance = quantizer ? quantizer->DistanceCalcSelector(distMethod) : COMMON::DistanceCalcSelector(distMethod); -#pragma omp parallel for - for (int i = 0; i < querySet->Count(); ++i) - { - SPTAG::COMMON::QueryResultSet query((const T*)(querySet->GetVector(i)), K); - query.SetTarget((const T*)(querySet->GetVector(i)), quantizer); - for (SPTAG::SizeType j = 0; j < vectorSet->Count(); j++) - { - float dist = fComputeDistance(query.GetQuantizedTarget(), reinterpret_cast(vectorSet->GetVector(j)), vectorSet->Dimension()); - query.AddPoint(j, dist); - } - query.SortResult(); - - for (int k = 0; k < K; k++) - { - truthset[i][k] = (query.GetResult(k))->VID; - distset[i][k] = (query.GetResult(k))->Dist; - } - - } - LOG(Helper::LogLevel::LL_Info, "Start to write truth file...\n"); - writeTruthFile(truthFile, querySet->Count(), K, truthset, distset, p_truthFileType); - - auto ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize((truthFile + ".dist.bin").c_str(), std::ios::out | std::ios::binary)) { - LOG(Helper::LogLevel::LL_Error, "Fail to create the file:%s\n", (truthFile + ".dist.bin").c_str()); - exit(1); - } - - int int32_queryNumber = (int)querySet->Count(); - ptr->WriteBinary(4, (char*)&int32_queryNumber); - ptr->WriteBinary(4, (char*)&K); - - for (size_t i = 0; i < int32_queryNumber; i++) - { - for (int k = 0; k < K; k++) { - if (ptr->WriteBinary(4, (char*)(&(truthset[i][k]))) != 4) { - LOG(Helper::LogLevel::LL_Error, "Fail to write the truth dist file!\n"); - exit(1); - } - if (ptr->WriteBinary(4, (char*)(&(distset[i][k]))) != 4) { - LOG(Helper::LogLevel::LL_Error, "Fail to 
write the truth dist file!\n"); - exit(1); - } - } - } - } + const SPTAG::DistCalcMethod distMethod, const int K, const SPTAG::TruthFileType p_truthFileType, const std::shared_ptr& quantizer); template static float CalculateRecall(VectorIndex* index, std::vector& results, const std::vector>& truth, int K, int truthK, std::shared_ptr querySet, std::shared_ptr vectorSet, SizeType NumQuerys, std::ofstream* log = nullptr, bool debug = false, float* MRR = nullptr) @@ -277,11 +221,11 @@ namespace SPTAG std::sort(truthvec.begin(), truthvec.end()); for (int j = 0; j < truthvec.size(); j++) ll += std::to_string(truthvec[j].node) + "@" + std::to_string(truthvec[j].distance) + ","; - LOG(Helper::LogLevel::LL_Info, "%s\n", ll.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%s\n", ll.c_str()); ll = "ann:"; for (int j = 0; j < K; j++) ll += std::to_string(results[i].GetResult(j)->VID) + "@" + std::to_string(results[i].GetResult(j)->Dist) + ","; - LOG(Helper::LogLevel::LL_Info, "%s\n", ll.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%s\n", ll.c_str()); } } meanrecall /= NumQuerys; @@ -343,4 +287,5 @@ namespace SPTAG } } + #endif // _SPTAG_COMMON_TRUTHSET_H_ diff --git a/AnnService/inc/Core/Common/VersionLabel.h b/AnnService/inc/Core/Common/VersionLabel.h new file mode 100644 index 000000000..4305d244a --- /dev/null +++ b/AnnService/inc/Core/Common/VersionLabel.h @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifndef _SPTAG_COMMON_VERSIONLABEL_H_ +#define _SPTAG_COMMON_VERSIONLABEL_H_ + +#include +#include "Dataset.h" + +namespace SPTAG +{ + namespace COMMON + { + class VersionLabel + { + private: + std::atomic m_deleted; + Dataset m_data; + + public: + VersionLabel() + { + m_deleted = 0; + m_data.SetName("versionLabelID"); + } + + void Initialize(SizeType size, SizeType blockSize, SizeType capacity) + { + m_data.Initialize(size, 1, blockSize, capacity); + } + + inline size_t Count() const { return m_data.R(); } + + inline size_t GetDeleteCount() const { return m_deleted.load();} + + inline bool Deleted(const SizeType& key) const + { + if (key < 0 || key >= m_data.R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error vid in VersionLabel Delete check: %d. Max Allowed: %d\n", key, m_data.R()); + return true; + } + return *m_data[key] == 0xfe; + } + + inline bool Delete(const SizeType& key) + { + if (key < 0 || key >= m_data.R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error vid in VersionLabel Delete operation: %d. Max Allowed: %d\n", key, m_data.R()); + return true; + } + uint8_t oldvalue = (uint8_t)InterlockedExchange8((char*)(m_data[key]), (char)0xfe); + if (oldvalue == 0xfe) return false; + m_deleted++; + return true; + } + + inline uint8_t GetVersion(const SizeType& key) + { + if (key < 0 || key >= m_data.R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error vid in VersionLabel GetVersion operation: %d. Max Allowed: %d\n", key, m_data.R()); + return 0xfe; + } + return *m_data[key]; + } + + inline void SetVersion(const SizeType& key, const uint8_t& version) + { + if (key < 0 || key >= m_data.R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error vid in VersionLabel SetVersion operation: %d. 
Max Allowed: %d\n", key, m_data.R()); + return; + } + (*m_data[key]) = version; + } + + inline bool IncVersion(const SizeType& key, uint8_t* newVersion) + { + if (key < 0 || key >= m_data.R()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error vid in VersionLabel IncVersion operation: %d. Max Allowed: %d\n", key, m_data.R()); + return false; + } + + while (true) { + if (Deleted(key)) return false; + uint8_t oldVersion = GetVersion(key); + *newVersion = (oldVersion+1) & 0x7f; + if (((uint8_t)InterlockedCompareExchange((char*)m_data[key], (char)*newVersion, (char)oldVersion)) == oldVersion) { + return true; + } + } + } + + inline SizeType GetVectorNum() + { + return m_data.R(); + } + + inline ErrorCode Save(std::shared_ptr output) + { + SizeType deleted = m_deleted.load(); + IOBINARY(output, WriteBinary, sizeof(SizeType), (char*)&deleted); + return m_data.Save(output); + } + + inline ErrorCode Save(const std::string& filename) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save %s To %s\n", m_data.Name().c_str(), filename.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; + return Save(ptr); + } + + inline ErrorCode Load(std::shared_ptr input, SizeType blockSize, SizeType capacity) + { + SizeType deleted; + IOBINARY(input, ReadBinary, sizeof(SizeType), (char*)&deleted); + m_deleted = deleted; + return m_data.Load(input, blockSize, capacity); + } + + inline ErrorCode Load(const std::string& filename, SizeType blockSize, SizeType capacity) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load %s From %s\n", m_data.Name().c_str(), filename.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; + return Load(ptr, blockSize, capacity); + } + + inline ErrorCode Load(char* pmemoryFile, SizeType blockSize, SizeType capacity) + { + m_deleted = 
*((SizeType*)pmemoryFile); + return m_data.Load(pmemoryFile + sizeof(SizeType), blockSize, capacity); + } + + inline ErrorCode AddBatch(SizeType num) + { + return m_data.AddBatch(num); + } + + inline std::uint64_t BufferSize() const + { + return m_data.BufferSize() + sizeof(SizeType); + } + + inline void SetR(SizeType num) + { + m_data.SetR(num); + } + }; + } +} + +#endif // _SPTAG_COMMON_LABELSET_H_ \ No newline at end of file diff --git a/AnnService/inc/Core/Common/WorkSpace.h b/AnnService/inc/Core/Common/WorkSpace.h index 2689095f9..0eca10ef8 100644 --- a/AnnService/inc/Core/Common/WorkSpace.h +++ b/AnnService/inc/Core/Common/WorkSpace.h @@ -9,11 +9,38 @@ #include "Heap.h" #include +#include namespace SPTAG { namespace COMMON { + template + class IWorkSpaceFactory + { + public: + virtual std::unique_ptr GetWorkSpace() = 0; + virtual void ReturnWorkSpace(std::unique_ptr ws) = 0; + }; + + template + class ThreadLocalWorkSpaceFactory : public IWorkSpaceFactory + { + public: + static thread_local std::unique_ptr m_workspace; + + virtual std::unique_ptr< WorkSpaceType> GetWorkSpace() override + { + return std::move(m_workspace); + } + + virtual void ReturnWorkSpace(std::unique_ptr ws) override + { + m_workspace = std::move(ws); + } + + }; + class OptHashPosVector { protected: @@ -48,7 +75,7 @@ namespace SPTAG public: OptHashPosVector(): m_secondHash(false), m_exp(2), m_poolSize(8191) {} - ~OptHashPosVector() {} + ~OptHashPosVector() { m_hashTable.reset(); } void Init(SizeType size, int exp) @@ -133,7 +160,7 @@ namespace SPTAG } DoubleSize(); - LOG(Helper::LogLevel::LL_Error, "Hash table is full! Set HashTableExponent to larger value (default is 2). NewHashTableExponent=%d NewPoolSize=%d\n", m_exp, m_poolSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Hash table is full! Set HashTableExponent to larger value (default is 2). 
NewHashTableExponent=%d NewPoolSize=%d\n", m_exp, m_poolSize); return _CheckAndSet(m_hashTable.get(), m_poolSize, true, idx); } }; @@ -147,6 +174,8 @@ namespace SPTAG public: DistPriorityQueue(): m_size(0), m_length(0), m_count(0) {} + ~DistPriorityQueue() { m_data.reset(); } + void Resize(int size_) { m_size = size_; m_data.reset(new float[size_ + 1]); @@ -198,8 +227,10 @@ namespace SPTAG } }; + class IWorkSpace {}; + // Variables for each single NN search - struct WorkSpace + struct WorkSpace : public IWorkSpace { WorkSpace() {} @@ -208,6 +239,10 @@ namespace SPTAG Initialize(other.m_iMaxCheck, other.nodeCheckStatus.HashTableExponent()); } + ~WorkSpace() { + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Delete workspace happens!\n"); + } + void Initialize(int maxCheck, int hashExp) { nodeCheckStatus.Init(maxCheck, hashExp); @@ -220,6 +255,7 @@ namespace SPTAG m_iNumberOfTreeCheckedLeaves = 0; m_iNumberOfCheckedLeaves = 0; m_iMaxCheck = maxCheck; + m_relaxedMono = false; } void Initialize(va_list& arg) @@ -232,8 +268,8 @@ namespace SPTAG void Reset(int maxCheck, int resultNum) { nodeCheckStatus.clear(); - m_SPTQueue.clear(); - m_NGQueue.clear(); + m_SPTQueue.clear(maxCheck * 10); + m_NGQueue.clear(maxCheck * 30); m_Results.clear(max(maxCheck / 16, resultNum)); m_iNumOfContinuousNoBetterPropagation = 0; @@ -241,6 +277,15 @@ namespace SPTAG m_iNumberOfTreeCheckedLeaves = 0; m_iNumberOfCheckedLeaves = 0; m_iMaxCheck = maxCheck; + m_relaxedMono = false; + } + + void ResetResult(int maxCheck, int resultNum) + { + m_Results.clear(max(maxCheck / 16, resultNum)); + m_iNumOfContinuousNoBetterPropagation = 0; + m_iNumberOfTreeCheckedLeaves = 0; + m_iNumberOfCheckedLeaves = 0; } inline bool CheckAndSet(SizeType idx) @@ -258,11 +303,12 @@ namespace SPTAG OptHashPosVector nodeCheckStatus; // counter for dynamic pivoting - int m_iNumOfContinuousNoBetterPropagation; - int m_iContinuousLimit; - int m_iNumberOfTreeCheckedLeaves; - int m_iNumberOfCheckedLeaves; - int m_iMaxCheck; + 
int m_iNumOfContinuousNoBetterPropagation = 0; + int m_iContinuousLimit = 128; + int m_iNumberOfTreeCheckedLeaves = 0; + int m_iNumberOfCheckedLeaves = 0; + int m_iMaxCheck = 8192; + bool m_relaxedMono = false; // Prioriy queue used for neighborhood graph Heap m_NGQueue; @@ -274,6 +320,7 @@ namespace SPTAG Heap m_nextBSPTQueue; DistPriorityQueue m_Results; + std::function m_filterFunc; }; } } diff --git a/AnnService/inc/Core/Common/cuda/Distance.hxx b/AnnService/inc/Core/Common/cuda/Distance.hxx index eaa182ec6..793ad70dd 100644 --- a/AnnService/inc/Core/Common/cuda/Distance.hxx +++ b/AnnService/inc/Core/Common/cuda/Distance.hxx @@ -32,6 +32,7 @@ #include #include #include +#include #include "params.h" #include "inc/Core/VectorIndex.h" @@ -45,14 +46,18 @@ using namespace SPTAG; // Templated infinity value template __host__ __device__ T INFTY() {} template<> __forceinline__ __host__ __device__ int INFTY() {return INT_MAX;} +template<> __forceinline__ __host__ __device__ uint32_t INFTY() {return UINT_MAX;} template<> __forceinline__ __host__ __device__ long long int INFTY() {return LLONG_MAX;} template<> __forceinline__ __host__ __device__ float INFTY() {return FLT_MAX;} //template<> __forceinline__ __host__ __device__ __half INFTY<__half>() {return FLT_MAX;} +template<> __forceinline__ __host__ __device__ int8_t INFTY() {return 127;} template<> __forceinline__ __host__ __device__ uint8_t INFTY() {return 255;} template __device__ T BASE() {} template<> __forceinline__ __device__ float BASE() {return 1;} -template<> __forceinline__ __device__ uint32_t BASE() {return 16384;} +template<> __forceinline__ __device__ int BASE() {return 16384;} +template<> __forceinline__ __device__ uint32_t BASE() {return 65536;} + template __forceinline__ __device__ SUMTYPE cosine(T* a, T* b) { @@ -64,6 +69,43 @@ __forceinline__ __device__ SUMTYPE cosine(T* a, T* b) { return BASE()-(total[0]+total[1]); } + +template +__device__ int32_t cosine_int8(int8_t* a, int8_t* b) { + int32_t prod=0; 
+ int32_t src=0; + int32_t target=0; + + uint32_t* newA = reinterpret_cast(a); + uint32_t* newB = reinterpret_cast(b); + + for(int i=0; i() - prod; +} + +template +__device__ float cosine_int8_rfloat(int8_t* a, int8_t* b) { + float prod=0; + float src=0; + float target=0; + + uint32_t* newA = reinterpret_cast(a); + uint32_t* newB = reinterpret_cast(b); + + for(int i=0; i()) - prod; +} + template __forceinline__ __device__ SUMTYPE l2(T* aVec, T* bVec) { SUMTYPE total[2]={0,0}; @@ -78,6 +120,12 @@ __forceinline__ __device__ SUMTYPE l2(T* aVec, T* bVec) { template __device__ SUMTYPE dist(T* a, T* b) { if(metric == (int)DistMetric::Cosine) { + if(::cuda::std::is_same::value) { +// if(::cuda::std::is_same::value) { +// return cosine_int8_rfloat((int8_t*)a, (int8_t*)b); +// } + return cosine_int8((int8_t*)a, (int8_t*)b); + } return cosine(a, b); } else { @@ -133,10 +181,6 @@ __host__ void copyRawDataToMultiGPU(SPTAG::VectorIndex* index, T** d_data, size_ delete temp; } -/********************************************************************* -* DEPRICATED CODE - Old data structures used before code restructure -*********************************************************************/ - /********************************************************************* * Object representing a Dim-dimensional point, with each coordinate * represented by a element of datatype T @@ -172,6 +216,10 @@ class Point { return *this; } + __host__ __device__ Point& operator>(const Point& other) { + return id > other.id; + } + // Computes euclidean dist. 
Uses 2 registers to increase pipeline efficiency and ILP __device__ __host__ SUMTYPE l2(Point* other) { SUMTYPE total[2]={0,0}; @@ -255,7 +303,6 @@ class Point { } } - __host__ __device__ Point& operator=( const Point& other ) { for(int i=0; i +#include +#include +#include +#include +#include + +#include "inc/Core/VectorIndex.h" + +using namespace std; + +using namespace SPTAG; + +// Templated infinity value +/********************************************************************* +* Object representing a Dim-dimensional point, with each coordinate +* represented by a element of datatype T +* NOTE: Dim must be templated so that we can store coordinate values in registers +*********************************************************************/ +template +class Point { +public: + int id; + T coords[Dim]; + + __host__ void load(vector data) { + for (int i = 0; i < Dim; i++) { + coords[i] = data[i]; + } + } + + __host__ __device__ void loadChunk(T* data, int exact_dim) { + for (int i = 0; i < exact_dim; i++) { + coords[i] = data[i]; + } + for (int i = exact_dim; i < Dim; i++) { + coords[i] = 0; + } + } + + __host__ __device__ Point& operator=(const Point& other) { + for (int i = 0; i < Dim; i++) { + coords[i] = other.coords[i]; + } + id = other.id; + return *this; + } + + __host__ __device__ Point& operator>(const Point& other) { + return id > other.id; + } + + // Computes euclidean dist. 
Uses 2 registers to increase pipeline efficiency and ILP + __device__ __host__ SUMTYPE l2(Point* other) { + SUMTYPE total[2] = { 0,0 }; + + for (int i = 0; i < Dim; i += 2) { + total[0] += (coords[i] - other->coords[i]) * (coords[i] - other->coords[i]); + total[1] += (coords[i + 1] - other->coords[i + 1]) * (coords[i + 1] - other->coords[i + 1]); + } + return total[0] + total[1]; + } + + __device__ SUMTYPE l2_block(Point* other) { + SUMTYPE total = 0; + __shared__ SUMTYPE final_val; + final_val = 0; + __syncthreads(); + + for (int i = threadIdx.x; i < Dim; i += blockDim.x) { + total += (coords[i] - other->coords[i]) * (coords[i] - other->coords[i]); + } + atomicAdd(&final_val, total); + __syncthreads(); + return final_val; + } + + + // Computes Cosine dist. Uses 2 registers to increase pipeline efficiency and ILP + // Assumes coordinates are normalized so each vector is of unit length. This lets us + // perform a dot-product instead of the full cosine distance computation. + // __device__ SUMTYPE cosine(Point* other, bool test) {return NULL;} + __device__ SUMTYPE cosine(Point* other) { + SUMTYPE total[2] = { 0,0 }; + + for (int i = 0; i < Dim; i += 2) { + total[0] += ((SUMTYPE)((SUMTYPE)coords[i] * (SUMTYPE)other->coords[i])); + total[1] += ((SUMTYPE)((SUMTYPE)coords[i + 1] * (SUMTYPE)other->coords[i + 1])); + } + return (SUMTYPE)Dim - (total[0] + total[1]); + } + + __forceinline__ __device__ SUMTYPE dist(Point* other, int metric) { + if (metric == 0) return l2(other); + else return cosine(other); + } + +}; + +// Less-than operator between two points. 
+template +__host__ __device__ bool operator<(const Point& first, const Point& other) { + return first.id < other.id; +} + + +// Specialized version of Point structure for 1-byte datatype (int8) +// Packs coordinates into Dim/4 total integer values, and functions un-pack as needed +template +class Point { +public: + int id; + uint32_t coords[Dim / 4]; + + __host__ void load(vector data) { + for (int i = 0; i < Dim / 4; i++) { + coords[i] = 0; + for (int j = 0; j < 4; j++) { + coords[i] += (data[i * 4 + j] << (j * 8)); + } + } + } + + __host__ __device__ void loadChunk(uint8_t* data, int exact_dims) { + for (int i = 0; i < exact_dims / 4; i++) { + coords[i] = 0; + for (int j = 0; j < 4; j++) { + coords[i] += (data[i * 4 + j] << (j * 8)); + } + } + for (int i = exact_dims / 4; i < Dim / 4; i++) { + coords[i] = 0; + } + } + + __host__ __device__ Point& operator=(const Point& other) { + for (int i = 0; i < Dim / 4; i++) { + coords[i] = other.coords[i]; + } + id = other.id; + return *this; + } + + __device__ __host__ SUMTYPE l2_block(Point* other) { return 0; } + __device__ __host__ SUMTYPE l2(Point* other) { + + SUMTYPE totals[4] = { 0,0,0,0 }; + SUMTYPE temp[4]; + SUMTYPE temp_other[4]; + + for (int i = 0; i < Dim / 4; ++i) { + temp[0] = (coords[i] & 0x000000FF); + temp_other[0] = (other->coords[i] & 0x000000FF); + + temp[1] = (coords[i] & 0x0000FF00) >> 8; + temp_other[1] = (other->coords[i] & 0x0000FF00) >> 8; + + temp[2] = (coords[i] & 0x00FF0000) >> 16; + temp_other[2] = (other->coords[i] & 0x00FF0000) >> 16; + + temp[3] = (coords[i]) >> 24; + temp_other[3] = (other->coords[i]) >> 24; + + totals[0] += (temp[0] - temp_other[0]) * (temp[0] - temp_other[0]); + totals[1] += (temp[1] - temp_other[1]) * (temp[1] - temp_other[1]); + totals[2] += (temp[2] - temp_other[2]) * (temp[2] - temp_other[2]); + totals[3] += (temp[3] - temp_other[3]) * (temp[3] - temp_other[3]); + } + return totals[0] + totals[1] + totals[2] + totals[3]; + } + +#if __CUDA_ARCH__ > 610 // Use 
intrinsics if available for GPU being compiled for + + // With int8 datatype, values are packed into integers so they need to be + // unpacked while computing distance + __device__ SUMTYPE cosine(Point* other) { + uint32_t prod = 0; + uint32_t src = 0; + uint32_t target = 0; + + for (int i = 0; i < Dim / 4; ++i) { + src = coords[i]; + target = other->coords[i]; + prod = __dp4a(src, target, prod); + } + + return ((SUMTYPE)65536) - prod; + } + +#else + __device__ SUMTYPE cosine(Point* other) { + SUMTYPE prod[4]; + prod[0] = 0; + + for (int i = 0; i < Dim / 4; ++i) { + prod[0] += (coords[i] & 0x000000FF) * (other->coords[i] & 0x000000FF); + prod[1] = ((coords[i] & 0x0000FF00) >> 8) * ((other->coords[i] & 0x0000FF00) >> 8); + prod[2] = ((coords[i] & 0x00FF0000) >> 16) * ((other->coords[i] & 0x00FF0000) >> 16); + prod[3] = ((coords[i]) >> 24) * ((other->coords[i]) >> 24); + + prod[0] += prod[1] + prod[2] + prod[3]; + } + + return ((SUMTYPE)65536) - prod[0]; + } +#endif + + __forceinline__ __device__ SUMTYPE dist(Point* other, int metric) { + if (metric == 0) return l2(other); + else return cosine(other); + } + +}; + +// Specialized version of Point structure for SIGNED 1-byte datatype (int8) +// Packs coordinates into Dim/4 total integer values, and functions un-pack as needed +template +class Point { +public: + int id; + uint32_t coords[Dim / 4]; + + __host__ void load(vector data) { + + uint8_t* test = reinterpret_cast(data.data()); + for (int i = 0; i < Dim / 4; i++) { + coords[i] = 0; + for (int j = 0; j < 4; j++) { + coords[i] += ((test[i * 4 + j]) << (j * 8)); + } + } + } + + __host__ __device__ void loadChunk(int8_t* data, int exact_dims) { + + uint8_t* test = reinterpret_cast(data); + for (int i = 0; i < exact_dims / 4; i++) { + coords[i] = 0; + for (int j = 0; j < 4; j++) { + coords[i] += (test[i * 4 + j] << (j * 8)); + } + } + for (int i = exact_dims / 4; i < Dim / 4; i++) { + coords[i] = 0; + } + } + + __host__ int8_t getVal(int idx) { + if (idx % 4 == 0) { + 
return (int8_t)(coords[idx / 4] & 0x000000FF); + } + else if (idx % 4 == 1) { + return (int8_t)((coords[idx / 4] & 0x0000FF00) >> 8); + } + else if (idx % 4 == 2) { + return (int8_t)((coords[idx / 4] & 0x00FF0000) >> 16); + } + else if (idx % 4 == 3) { + return (int8_t)((coords[idx / 4]) >> 24); + } + return 0; + } + + __host__ __device__ Point& operator=(const Point& other) { + for (int i = 0; i < Dim / 4; i++) { + coords[i] = other.coords[i]; + } + id = other.id; + return *this; + } + __device__ __host__ SUMTYPE l2_block(Point* other) { return 0; } + __device__ __host__ SUMTYPE l2(Point* other) { + + SUMTYPE totals[4] = { 0,0,0,0 }; + int32_t temp[4]; + int32_t temp_other[4]; + + for (int i = 0; i < Dim / 4; ++i) { + temp[0] = (int8_t)(coords[i] & 0x000000FF); + temp_other[0] = (int8_t)(other->coords[i] & 0x000000FF); + + temp[1] = (int8_t)((coords[i] & 0x0000FF00) >> 8); + temp_other[1] = (int8_t)((other->coords[i] & 0x0000FF00) >> 8); + + temp[2] = (int8_t)((coords[i] & 0x00FF0000) >> 16); + temp_other[2] = (int8_t)((other->coords[i] & 0x00FF0000) >> 16); + + temp[3] = (int8_t)((coords[i]) >> 24); + temp_other[3] = (int8_t)((other->coords[i]) >> 24); + + totals[0] += (temp[0] - temp_other[0]) * (temp[0] - temp_other[0]); + totals[1] += (temp[1] - temp_other[1]) * (temp[1] - temp_other[1]); + totals[2] += (temp[2] - temp_other[2]) * (temp[2] - temp_other[2]); + totals[3] += (temp[3] - temp_other[3]) * (temp[3] - temp_other[3]); + } + return totals[0] + totals[1] + totals[2] + totals[3]; + } + + +#if __CUDA_ARCH__ > 610 // Use intrinsics if available for GPU being compiled for + // With int8 datatype, values are packed into integers so they need to be + // unpacked while computing distance + __device__ SUMTYPE cosine(Point* other) { + int32_t prod = 0; + int32_t src = 0; + int32_t target = 0; + + for (int i = 0; i < Dim / 4; ++i) { + src = coords[i]; + target = other->coords[i]; + prod = __dp4a(src, target, prod); + } + + return ((SUMTYPE)16384) - (SUMTYPE)prod; 
+ } + +#else + __device__ SUMTYPE cosine(Point* other) { + SUMTYPE prod[4]; + prod[0] = 0; + + for (int i = 0; i < Dim / 4; ++i) { + prod[0] += ((int8_t)(coords[i] & 0x000000FF)) * ((int8_t)(other->coords[i] & 0x000000FF)); + prod[1] = ((int8_t)((coords[i] & 0x0000FF00) >> 8)) * ((int8_t)((other->coords[i] & 0x0000FF00) >> 8)); + prod[2] = ((int8_t)((coords[i] & 0x00FF0000) >> 16)) * ((int8_t)((other->coords[i] & 0x00FF0000) >> 16)); + prod[3] = ((int8_t)((coords[i]) >> 24)) * ((int8_t)((other->coords[i]) >> 24)); + prod[0] += prod[1] + prod[2] + prod[3]; + } + + return ((SUMTYPE)1) - prod[0]; + } +#endif + + __forceinline__ __device__ SUMTYPE dist(Point* other, int metric) { + if (metric == 0) return l2(other); + else return cosine(other); + } + +}; + +/********************************************************************* + * Create an array of Point structures out of an input array + ********************************************************************/ +template +__host__ Point* convertMatrix(T* data, size_t rows, int exact_dim) { + Point* pointArray = (Point*)malloc(rows * sizeof(Point)); + for (size_t i = 0; i < rows; i++) { + pointArray[i].loadChunk(&data[i * exact_dim], exact_dim); + } + return pointArray; +} + +template +__host__ Point* convertMatrix(SPTAG::VectorIndex* index, size_t rows, int exact_dim) { + Point* pointArray = (Point*)malloc(rows * sizeof(Point)); + + T* data; + + for (int i = 0; i < rows; i++) { + data = (T*)index->GetSample(i); + pointArray[i].loadChunk(data, exact_dim); + } + return pointArray; +} + +template +__host__ void extractHeadPoints(T* data, Point* headPoints, size_t totalRows, std::unordered_set headVectorIDS, int exact_dim) { + int headIdx = 0; + for (size_t i = 0; i < totalRows; i++) { + if (headVectorIDS.count(i) != 0) { + headPoints[headIdx].loadChunk(&data[i * exact_dim], exact_dim); + headPoints[headIdx].id = i; + headIdx++; + } + } +} + +template +__host__ void extractTailPoints(T* data, Point* tailPoints, int totalRows, 
std::unordered_set headVectorIDS, int exact_dim) { + int tailIdx = 0; + for (size_t i = 0; i < totalRows; i++) { + if (headVectorIDS.count(i) == 0) { + tailPoints[tailIdx].loadChunk(&data[i * exact_dim], exact_dim); + tailPoints[tailIdx].id = i; + tailIdx++; + } + } +} + +template +__host__ void extractFullVectorPoints(T* data, Point* tailPoints, size_t totalRows, int exact_dim) { + + for (size_t i = 0; i < totalRows; ++i) { + tailPoints[i].loadChunk(&data[i * exact_dim], exact_dim); + tailPoints[i].id = i; + } +} + +template +__host__ void extractHeadPointsFromIndex(T* data, SPTAG::VectorIndex* headIndex, Point* headPoints, int exact_dim) { + size_t headRows = headIndex->GetNumSamples(); + + for (size_t i = 0; i < headRows; ++i) { + headPoints[i].loadChunk((T*)headIndex->GetSample(i), exact_dim); + headPoints[i].id = i; + } +} + +/************************************************************************ + * Wrapper around shared memory holding transposed points to avoid + * bank conflicts. + ************************************************************************/ +template +class TransposePoint { +public: + T* dataPtr; + + __device__ void setMem(T* ptr) { + dataPtr = ptr; + } + + // Load regular point into memory transposed + __device__ void loadPoint(Point p) { + for (int i = 0; i < Dim; i++) { + dataPtr[i * Stride] = p.coords[i]; + } + } + + // Return the idx-th coordinate of the transposed vector + __forceinline__ __device__ T getCoord(int idx) { + return dataPtr[idx * Stride]; + } + + /****************************************************************************************** + * Main L2 distance metric used by approx-KNN application. 
+ ******************************************************************************************/ + __forceinline__ __device__ __host__ SUMTYPE l2(Point* other) { + SUMTYPE total = 0; +#pragma unroll + for (int i = 0; i < Dim; ++i) { + total += ((getCoord(i) - other->coords[i]) * (getCoord(i) - other->coords[i])); + } + + return total; + } + + /****************************************************************************************** + * Cosine distance metric comparison operation. Requires that the SUMTYPE is floating point, + * regardless of the datatype T, because it requires squareroot. + ******************************************************************************************/ + __forceinline__ __device__ __host__ SUMTYPE cosine(Point < T, SUMTYPE, Dim >* other) { + SUMTYPE prod = 0; + SUMTYPE a = 0; + SUMTYPE b = 0; +#pragma unroll + for (int i = 0; i < Dim; ++i) { + a += getCoord(i) * getCoord(i); + b += other->coords[i] * other->coords[i]; + prod += getCoord(i) * other->coords[i]; + } + + return 1 - (prod / (sqrt(a * b))); + } +}; + +template +class TransposePoint { +public: + uint32_t* dataPtr; + + __device__ void setMem(void* ptr) { + dataPtr = reinterpret_cast (ptr); + } + + // Load regular point into memory transposed + __device__ void loadPoint(Point p) { + for (int i = 0; i < Dim / 4; i++) { + dataPtr[i * Stride] = p.coords[i]; + } + } + + // Return the idx-th coordinate of the transposed vector + __forceinline__ __device__ uint32_t getCoord(int idx) { + return dataPtr[idx * Stride]; + } + + __forceinline__ __device__ __host__ SUMTYPE l2(Point* other) { + + SUMTYPE totals[4] = { 0,0,0,0 }; + int32_t temp[4]; + int32_t temp_other[4]; + + for (int i = 0; i < Dim / 4; ++i) { + temp[0] = (int8_t)(getCoord(i) & 0x000000FF); + temp_other[0] = (int8_t)(other->coords[i] & 0x000000FF); + + temp[1] = (int8_t)((getCoord(i) & 0x0000FF00) >> 8); + temp_other[1] = (int8_t)((other->coords[i] & 0x0000FF00) >> 8); + + temp[2] = (int8_t)((getCoord(i) & 0x00FF0000) >> 16); 
+ temp_other[2] = (int8_t)((other->coords[i] & 0x00FF0000) >> 16); + + temp[3] = (int8_t)((getCoord(i)) >> 24); + temp_other[3] = (int8_t)((other->coords[i]) >> 24); + + totals[0] += (temp[0] - temp_other[0]) * (temp[0] - temp_other[0]); + totals[1] += (temp[1] - temp_other[1]) * (temp[1] - temp_other[1]); + totals[2] += (temp[2] - temp_other[2]) * (temp[2] - temp_other[2]); + totals[3] += (temp[3] - temp_other[3]) * (temp[3] - temp_other[3]); + } + return totals[0] + totals[1] + totals[2] + totals[3]; + } + + __device__ __host__ SUMTYPE cosine(Point* other) { + SUMTYPE prod[4] = { 0,0,0,0 }; + int8_t temp[4]; + int8_t temp_other[4]; + + for (int i = 0; i < Dim / 4; ++i) { + temp[0] = ((int8_t)(getCoord(i) & 0x000000FF)); + temp_other[0] = ((int8_t)(other->coords[i] & 0x000000FF)); + + temp[1] = ((int8_t)((getCoord(i) & 0x0000FF00) >> 8)); + temp_other[1] = ((int8_t)((other->coords[i] & 0x0000FF00) >> 8)); + + temp[2] = ((int8_t)((getCoord(i) & 0x00FF0000) >> 16)); + temp_other[2] = ((int8_t)((other->coords[i] & 0x00FF0000) >> 16)); + + temp[3] = ((int8_t)((getCoord(i)) >> 24)); + temp_other[3] = ((int8_t)((other->coords[i]) >> 24)); + + prod[0] += temp[0] * temp_other[0]; + prod[1] += temp[1] * temp_other[1]; + prod[2] += temp[2] * temp_other[2]; + prod[3] += temp[3] * temp_other[3]; + //prod[0] += ((int8_t)(getCoord(i) & 0x000000FF)) * ((int8_t)(other->coords[i] & 0x000000FF)); + //prod[1] += ((int8_t)((getCoord(i) & 0x0000FF00) >> 8)) * ((int8_t)((other->coords[i] & 0x0000FF00) >> 8)); + //prod[2] += ((int8_t)((getCoord(i) & 0x00FF0000) >> 16)) * ((int8_t)((other->coords[i] & 0x00FF0000) >> 16)); + //prod[3] += ((int8_t)((getCoord(i)) >> 24)) * ((int8_t)((other->coords[i]) >> 24)); + + //prod[0] += prod[1] + prod[2] + prod[3]; + } + return ((SUMTYPE)65536) - prod[0] - prod[1] - prod[2] - prod[3]; + } +}; + +template +class TransposePoint { +public: + uint32_t* dataPtr; + + __device__ void setMem(void* ptr) { + dataPtr = reinterpret_cast (ptr); + } + + // Load 
regular point into memory transposed + __device__ void loadPoint(Point p) { + for (int i = 0; i < Dim / 4; i++) { + dataPtr[i * Stride] = p.coords[i]; + } + } + + // Return the idx-th coordinate of the transposed vector + __forceinline__ __device__ uint32_t getCoord(int idx) { + return dataPtr[idx * Stride]; + } + + __forceinline__ __device__ __host__ SUMTYPE l2(Point* other) { + + SUMTYPE totals[4] = { 0,0,0,0 }; + int32_t temp[4]; + int32_t temp_other[4]; + + for (int i = 0; i < Dim / 4; ++i) { + temp[0] = (uint8_t)(getCoord(i) & 0x000000FF); + temp_other[0] = (uint8_t)(other->coords[i] & 0x000000FF); + + temp[1] = (uint8_t)((getCoord(i) & 0x0000FF00) >> 8); + temp_other[1] = (uint8_t)((other->coords[i] & 0x0000FF00) >> 8); + + temp[2] = (uint8_t)((getCoord(i) & 0x00FF0000) >> 16); + temp_other[2] = (uint8_t)((other->coords[i] & 0x00FF0000) >> 16); + + temp[3] = (uint8_t)((getCoord(i)) >> 24); + temp_other[3] = (uint8_t)((other->coords[i]) >> 24); + + totals[0] += (temp[0] - temp_other[0]) * (temp[0] - temp_other[0]); + totals[1] += (temp[1] - temp_other[1]) * (temp[1] - temp_other[1]); + totals[2] += (temp[2] - temp_other[2]) * (temp[2] - temp_other[2]); + totals[3] += (temp[3] - temp_other[3]) * (temp[3] - temp_other[3]); + } + return totals[0] + totals[1] + totals[2] + totals[3]; + } + + __forceinline__ __device__ __host__ SUMTYPE cosine(Point* other) { + SUMTYPE prod[4]; + prod[0] = 0; + + for (int i = 0; i < Dim / 4; ++i) { + prod[0] += (uint8_t)(getCoord(i) & 0x000000FF) * (uint8_t)(other->coords[i] & 0x000000FF); + prod[1] = (uint8_t)((getCoord(i) & 0x0000FF00) >> 8) * (uint8_t)((other->coords[i] & 0x0000FF00) >> 8); + prod[2] = (uint8_t)((getCoord(i) & 0x00FF0000) >> 16) * (uint8_t)((other->coords[i] & 0x00FF0000) >> 16); + prod[3] = (uint8_t)((getCoord(i)) >> 24) * (uint8_t)((other->coords[i]) >> 24); + + prod[0] += prod[1] + prod[2] + prod[3]; + } + + return ((SUMTYPE)65536) - prod[0]; + } +}; +#endif \ No newline at end of file diff --git 
a/AnnService/inc/Core/Common/cuda/GPUQuantizer.hxx b/AnnService/inc/Core/Common/cuda/GPUQuantizer.hxx index 1fcfef4e1..4fcf06325 100644 --- a/AnnService/inc/Core/Common/cuda/GPUQuantizer.hxx +++ b/AnnService/inc/Core/Common/cuda/GPUQuantizer.hxx @@ -35,6 +35,7 @@ #include "params.h" #include "../PQQuantizer.h" +#include "../OPQQuantizer.h" #include "inc/Core/VectorIndex.h" @@ -50,48 +51,110 @@ using namespace SPTAG; enum DistMetric { L2, Cosine }; -class GPU_PQQuantizer { +enum RType {Int8, UInt8, Float, Int16}; + +namespace { + +__host__ int GetRTypeWidth(RType type) { + if(type == RType::Int8) return 1; + if(type == RType::UInt8) return 1; + if(type == RType::Int16) return 2; + if(type == RType::Float) return 4; +} + +class GPU_Quantizer { // private: public: int m_NumSubvectors; size_t m_KsPerSubvector; size_t m_BlockSize; + int m_DimPerSubvector; float* m_DistanceTables; // Don't need L2 and cosine tables, just copy from correct tables on host + RType m_type; + + void* m_codebooks; // Codebooks to reconstruct the original vectors + __device__ inline size_t m_DistIndexCalc(size_t i, size_t j, size_t k) { return m_BlockSize * i + j * m_KsPerSubvector + k; } public: - __host__ GPU_PQQuantizer() {} + __host__ GPU_Quantizer() {} - __host__ GPU_PQQuantizer(std::shared_ptr quantizer, DistMetric metric) { + __host__ GPU_Quantizer(std::shared_ptr quantizer, DistMetric metric) { + + // Make sure L2 is used, since other + if(metric != DistMetric::L2) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Only L2 distance currently supported for PQ or OPQ\n"); + exit(1); + } + + VectorValueType rType; - SPTAG::COMMON::PQQuantizer* pq_quantizer = (SPTAG::COMMON::PQQuantizer*)quantizer.get(); + if(quantizer->GetQuantizerType() == QuantizerType::PQQuantizer) { + SPTAG::COMMON::PQQuantizer* pq_quantizer = (SPTAG::COMMON::PQQuantizer*)quantizer.get(); - m_NumSubvectors = pq_quantizer->GetNumSubvectors(); - m_KsPerSubvector = pq_quantizer->GetKsPerSubvector(); - m_BlockSize = 
pq_quantizer->GetBlockSize(); + m_NumSubvectors = pq_quantizer->GetNumSubvectors(); + m_KsPerSubvector = pq_quantizer->GetKsPerSubvector(); + m_BlockSize = pq_quantizer->GetBlockSize(); + m_DimPerSubvector = pq_quantizer->GetDimPerSubvector(); - CUDA_CHECK(cudaMalloc(&m_DistanceTables, m_BlockSize * m_NumSubvectors * sizeof(float))); + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, "Using PQ - numSubVectors:%d, KsPerSub:%ld, BlockSize:%ld, DimPerSub:%d, total size of tables:%ld\n", m_NumSubvectors, m_KsPerSubvector, m_BlockSize, m_DimPerSubvector, m_BlockSize*m_NumSubvectors*sizeof(float)); - if(metric == DistMetric::Cosine) { - CUDA_CHECK(cudaMemcpy(m_DistanceTables, pq_quantizer->GetCosineDistanceTables(), m_BlockSize*m_NumSubvectors*sizeof(float), cudaMemcpyHostToDevice)); + rType = pq_quantizer->GetReconstructType(); + CUDA_CHECK(cudaMalloc(&m_DistanceTables, m_BlockSize * m_NumSubvectors * sizeof(float))); + CUDA_CHECK(cudaMemcpy(m_DistanceTables, pq_quantizer->GetL2DistanceTables(), m_BlockSize*m_NumSubvectors*sizeof(float), cudaMemcpyHostToDevice)); + } + else if(quantizer->GetQuantizerType() == QuantizerType::OPQQuantizer) { + SPTAG::COMMON::OPQQuantizer* opq_quantizer = (SPTAG::COMMON::OPQQuantizer*)quantizer.get(); + + m_NumSubvectors = opq_quantizer->GetNumSubvectors(); + m_KsPerSubvector = opq_quantizer->GetKsPerSubvector(); + m_BlockSize = opq_quantizer->GetBlockSize(); + m_DimPerSubvector = opq_quantizer->GetDimPerSubvector(); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, "Using OPQ - numSubVectors:%d, KsPerSub:%ld, BlockSize:%ld, DimPerSub:%d, total size of tables:%ld\n", m_NumSubvectors, m_KsPerSubvector, m_BlockSize, m_DimPerSubvector, m_BlockSize*m_NumSubvectors*sizeof(float)); + + rType = opq_quantizer->GetReconstructType(); + CUDA_CHECK(cudaMalloc(&m_DistanceTables, m_BlockSize * m_NumSubvectors * sizeof(float))); + CUDA_CHECK(cudaMemcpy(m_DistanceTables, opq_quantizer->GetL2DistanceTables(), m_BlockSize*m_NumSubvectors*sizeof(float), 
cudaMemcpyHostToDevice)); } else { - CUDA_CHECK(cudaMemcpy(m_DistanceTables, pq_quantizer->GetL2DistanceTables(), m_BlockSize*m_NumSubvectors*sizeof(float), cudaMemcpyHostToDevice)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Only PQ and OPQ quantizers are supported for GPU build\n"); + exit(1); } + if(rType == SPTAG::VectorValueType::Float) m_type = RType::Float; + else if(rType == SPTAG::VectorValueType::Int8) m_type = RType::Int8; + else if(rType == SPTAG::VectorValueType::UInt8) m_type = RType::UInt8; + else if(rType == SPTAG::VectorValueType::Int16) m_type = RType::Int16; + } - __host__ ~GPU_PQQuantizer() + __host__ ~GPU_Quantizer() { cudaFree(m_DistanceTables); } +/* + template + __device__ void ReconstructVector(uint8_t* qvec, R* vecout) + { + for (int i = 0; i < m_NumSubvectors; i++) { + SizeType codebook_idx = (i * m_KsPerSubvector * m_DimPerSubvector) + (qvec[i] * m_DimPerSubvector); + R* sub_vecout = &(vecout[i * m_DimPerSubvector]); +// for (int j = 0; j < m_DimPerSubvector; j++) { +// sub_vecout[j] = m_codebooks[codebook_idx + j]; +// } + memcpy(sub_vecout, m_codebooks+codebook_idx, m_DimPerSubvector*sizeof(R)); + } + } +*/ + // TODO - Optimize quantized distance comparator __forceinline__ __device__ float dist(uint8_t* pX, uint8_t* pY) { float totals[2]; totals[0]=0; totals[1]=0; @@ -100,10 +163,6 @@ class GPU_PQQuantizer { totals[1] += m_DistanceTables[m_DistIndexCalc(i+1, pX[i+1], pY[i+1])]; } -// if(metric == DistMetric::Cosine) { -// out = 1 - out; -// } -// return 1 - (totals[0]+totals[1]); return totals[0]+totals[1]; } @@ -112,7 +171,35 @@ class GPU_PQQuantizer { return between <= distance; } +// Attempt to improve perf by checking threshold during computation to short-circuit +// distance lookups +/* + __forceinline__ __device__ float dist(uint8_t* pX, uint8_t* pY, float target) { + float totals[2]; totals[0]=0; totals[1]=0; + float temp; + + for(int i = 0; i < m_NumSubvectors; i+=2) { + totals[0] += m_DistanceTables[m_DistIndexCalc(i, pX[i], 
pY[i])]; +// totals[1] += m_DistanceTables[m_DistIndexCalc(i+1, pX[i+1], pY[i+1])]; + temp = m_DistanceTables[m_DistIndexCalc(i+1, pX[i+1], pY[i+1])]; + + if(totals[0] >= target || totals[0] >= (target*(i+1))/(float)m_NumSubvectors) { + return target+1.0; + } + totals[0]+=temp; + } + + return totals[0]; + } + + __device__ bool violatesRNG_test(uint8_t* a, uint8_t* b, float distance) { + float between = dist(a, b, distance); + return between <= distance; + } +*/ }; +} + #endif // GPUQUANTIZER_H diff --git a/AnnService/inc/Core/Common/cuda/KNN.hxx b/AnnService/inc/Core/Common/cuda/KNN.hxx index 3fcd01203..abcb44c04 100644 --- a/AnnService/inc/Core/Common/cuda/KNN.hxx +++ b/AnnService/inc/Core/Common/cuda/KNN.hxx @@ -26,13 +26,16 @@ #define _SPTAG_COMMON_CUDA_KNN_H_ #include "Refine.hxx" +#include "log.hxx" #include "ThreadHeap.hxx" #include "TPtree.hxx" #include "GPUQuantizer.hxx" +#include +#include template -__device__ void findRNG_PQ(PointSet* ps, TPtree* tptree, int KVAL, int* results, size_t min_id, size_t max_id, DistPair* threadList, GPU_PQQuantizer* quantizer) { +__device__ void findRNG_PQ(PointSet* ps, TPtree* tptree, int KVAL, int* results, size_t min_id, size_t max_id, DistPair* threadList, GPU_Quantizer* quantizer) { uint8_t query[Dim]; uint8_t* candidate_vec; @@ -199,7 +202,7 @@ __device__ void findRNG(PointSet* ps, TPtree* tptree, int KVAL, int* results, candidate_vec = ps->getVec(candidate.idx); candidate.dist = dist_comp(query, candidate_vec); - if(candidate.dist < max_dist){ // If it is a candidate to be added to neighbor list + if(max_dist >= candidate.dist){ // If it is a candidate to be added to neighbor list for(read_id=0; candidate.dist > threadList[read_id].dist && good; read_id++) { if(violatesRNG(candidate_vec, ps->getVec(threadList[read_id].idx), candidate.dist, dist_comp)) { good = false; @@ -263,40 +266,47 @@ __device__ void findRNG(PointSet* ps, TPtree* tptree, int KVAL, int* results, return; \ } -#define MAX_SHAPE 1024 +#define MAX_SHAPE 
384 +#define MAX_PQ_SHAPE 100 template -__global__ void findRNG_selector(PointSet* ps, TPtree* tptree, int KVAL, int* results, size_t min_id, size_t max_id, int dim, GPU_PQQuantizer* quantizer) { +__global__ void findRNG_selector(PointSet* ps, TPtree* tptree, int KVAL, int* results, size_t min_id, size_t max_id, int dim, GPU_Quantizer* quantizer) { extern __shared__ char sharememory[]; +// Enable dimension of dataset that you will be using for maximum performance if(quantizer == NULL) { -// RUN_KERNEL(64); + // Add specific vector dimension needed if not already listed here + RUN_KERNEL(64); RUN_KERNEL(100); - RUN_KERNEL(200); -// RUN_KERNEL(MAX_SHAPE); + RUN_KERNEL(MAX_SHAPE); } else { -// RUN_KERNEL_QUANTIZED(25); RUN_KERNEL_QUANTIZED(50); -// RUN_KERNEL_QUANTIZED(64); - RUN_KERNEL_QUANTIZED(100); + RUN_KERNEL_QUANTIZED(MAX_PQ_SHAPE); } } template -void run_TPT_batch_multigpu(size_t dataSize, int** d_results, TPtree** tptrees, TPtree** d_tptrees, int iters, int levels, int NUM_GPUS, int KVAL, cudaStream_t* streams, std::vector batch_min, std::vector batch_max, int balanceFactor, PointSet** d_pointset, int dim, GPU_PQQuantizer* quantizer) +void run_TPT_batch_multigpu(size_t dataSize, int** d_results, TPtree** tptrees, TPtree** d_tptrees, int iters, int levels, int NUM_GPUS, int KVAL, cudaStream_t* streams, std::vector batch_min, std::vector batch_max, int balanceFactor, PointSet** d_pointset, int dim, GPU_Quantizer* quantizer, SPTAG::VectorIndex* index) { // Set num blocks for all GPU kernel calls int KNN_blocks= max(tptrees[0]->num_leaves, BLOCKS); + for(int tree_id=0; tree_id < iters; ++tree_id) { // number of TPTs used to create approx. KNN graph + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, 0)); // Get avil. 
memory + size_t freeMem, totalMem; + CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Starting TPtree number %d, Total memory of GPU 0:%ld, GPU 0 memory used:%ld\n", tree_id, totalMem, freeMem); + auto before_tpt = std::chrono::high_resolution_clock::now(); // Start timer for TPT building - create_tptree_multigpu(tptrees, d_pointset, dataSize, levels, NUM_GPUS, streams, balanceFactor); + create_tptree_multigpu(tptrees, d_pointset, dataSize, levels, NUM_GPUS, streams, balanceFactor, index); CUDA_CHECK(cudaDeviceSynchronize()); // Copy TPTs to each GPU @@ -308,6 +318,13 @@ void run_TPT_batch_multigpu(size_t dataSize, int** d_results, TPtree** tptrees, auto after_tpt = std::chrono::high_resolution_clock::now(); + if(quantizer == NULL && dim > MAX_PQ_SHAPE) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Input PQ dimension is %d, GPU index build with PQ dimension larger than %d not supported.\n", dim, MAX_SHAPE); + } + else if(dim > MAX_SHAPE) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Input vector dimension is %d, GPU index build of vector dimensions larger than %d not supported.\n", dim, MAX_SHAPE); + } + for(int gpuNum=0; gpuNum < NUM_GPUS; gpuNum++) { CUDA_CHECK(cudaSetDevice(gpuNum)); // Compute the STRICT RNG for each leaf node @@ -322,7 +339,7 @@ void run_TPT_batch_multigpu(size_t dataSize, int** d_results, TPtree** tptrees, double loop_tpt_time = GET_CHRONO_TIME(before_tpt, after_tpt); double loop_work_time = GET_CHRONO_TIME(after_tpt, after_work); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "All GPUs finished tree %d - tree build time:%.2lf, neighbor compute time:%.2lf\n", tree_id, loop_tpt_time, loop_work_time); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "All GPUs finished tree %d - tree build time:%.2lf, neighbor compute time:%.2lf\n", tree_id, loop_tpt_time, loop_work_time); } } @@ -341,6 +358,8 @@ void buildGraphGPU(SPTAG::VectorIndex* index, size_t dataSize, int KVAL, int tre size_t 
rawSize = dataSize*dim; cudaError_t resultErr; + printf("leafSize:%d, levels:%d\n", leafSize, levels); + // Timers double tree_time=0.0; double KNN_time=0.0; @@ -360,16 +379,13 @@ void buildGraphGPU(SPTAG::VectorIndex* index, size_t dataSize, int KVAL, int tre /**** Varibales if PQ is enabled ****/ - GPU_PQQuantizer* d_quantizer = NULL; // Only use if quantizer is enabled - GPU_PQQuantizer* h_quantizer = NULL; + GPU_Quantizer* d_quantizer = NULL; // Only use if quantizer is enabled + GPU_Quantizer* h_quantizer = NULL; if(use_q) { -printf("Using quantizer, and metric:%d (L2?:%d)\n", (metric), (DistMetric)metric == DistMetric::L2); - h_quantizer = new GPU_PQQuantizer(index->m_pQuantizer, (DistMetric)metric); - CUDA_CHECK(cudaMalloc(&d_quantizer, sizeof(GPU_PQQuantizer))); - CUDA_CHECK(cudaMemcpy(d_quantizer, h_quantizer, sizeof(GPU_PQQuantizer), cudaMemcpyHostToDevice)); - - + h_quantizer = new GPU_Quantizer(index->m_pQuantizer, (DistMetric)metric); + CUDA_CHECK(cudaMalloc(&d_quantizer, sizeof(GPU_Quantizer))); + CUDA_CHECK(cudaMemcpy(d_quantizer, h_quantizer, sizeof(GPU_Quantizer), cudaMemcpyHostToDevice)); } /********* Allocate and transfer data to each GPU ********/ @@ -378,7 +394,7 @@ printf("Using quantizer, and metric:%d (L2?:%d)\n", (metric), (DistMetric)metric CUDA_CHECK(cudaSetDevice(gpuNum)); CUDA_CHECK(cudaStreamCreate(&streams[gpuNum])); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Allocating raw coordinate data on GPU %d, total of %zu bytes\n", gpuNum, rawSize*sizeof(DTYPE)); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Allocating raw coordinate data on GPU %d, total of %zu bytes\n", gpuNum, rawSize*sizeof(DTYPE)); CUDA_CHECK(cudaMalloc(&d_data_raw[gpuNum], rawSize*sizeof(DTYPE))); } @@ -401,7 +417,7 @@ printf("Using quantizer, and metric:%d (L2?:%d)\n", (metric), (DistMetric)metric tptrees[gpuNum] = new TPtree; tptrees[gpuNum]->initialize(dataSize, levels, dim); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "TPT structure initialized for %lu points, %d 
levels, leaf size:%d\n", dataSize, levels, leafSize); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "TPT structure initialized for %lu points, %d levels, leaf size:%d\n", dataSize, levels, leafSize); // Allocate results buffer on each GPU CUDA_CHECK(cudaMalloc(&d_results[gpuNum], (size_t)batchSize[gpuNum]*KVAL*sizeof(int))); @@ -432,7 +448,7 @@ printf("Using quantizer, and metric:%d (L2?:%d)\n", (metric), (DistMetric)metric batch_min[gpuNum] = GPUOffset[gpuNum]+batchOffset[gpuNum]; batch_max[gpuNum] = batch_min[gpuNum] + curr_batch_size[gpuNum]; - LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d - starting batch with min_id:%ld, max_id:%ld\n", gpuNum, batch_min[gpuNum], batch_max[gpuNum]); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d - starting batch with min_id:%ld, max_id:%ld\n", gpuNum, batch_min[gpuNum], batch_max[gpuNum]); CUDA_CHECK(cudaSetDevice(gpuNum)); // Initialize results to all -1 (special value that is set to distance INFTY) @@ -442,10 +458,14 @@ printf("Using quantizer, and metric:%d (L2?:%d)\n", (metric), (DistMetric)metric /***** Run batch on GPU (all TPT iters) *****/ if(metric == (int)DistMetric::Cosine) { - run_TPT_batch_multigpu(dataSize, d_results, tptrees, d_tptrees, trees, levels, NUM_GPUS, KVAL, streams.data(), batch_min, batch_max, balanceFactor, d_pointset, dim, d_quantizer); + if(d_quantizer != NULL) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cosine distance not currently supported when using quantization.\n"); + exit(1); + } + run_TPT_batch_multigpu(dataSize, d_results, tptrees, d_tptrees, trees, levels, NUM_GPUS, KVAL, streams.data(), batch_min, batch_max, balanceFactor, d_pointset, dim, d_quantizer, index); } else { - run_TPT_batch_multigpu(dataSize, d_results, tptrees, d_tptrees, trees, levels, NUM_GPUS, KVAL, streams.data(), batch_min, batch_max, balanceFactor, d_pointset, dim, d_quantizer); + run_TPT_batch_multigpu(dataSize, d_results, tptrees, d_tptrees, trees, levels, NUM_GPUS, KVAL, streams.data(), batch_min, 
batch_max, balanceFactor, d_pointset, dim, d_quantizer, index); } auto before_copy = std::chrono::high_resolution_clock::now(); @@ -455,7 +475,7 @@ printf("Using quantizer, and metric:%d (L2?:%d)\n", (metric), (DistMetric)metric auto after_copy = std::chrono::high_resolution_clock::now(); double batch_copy_time = GET_CHRONO_TIME(before_copy, after_copy); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "All GPUs finished batch - time to copy result:%.2lf\n", batch_copy_time); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "All GPUs finished batch - time to copy result:%.2lf\n", batch_copy_time); D2H_time += batch_copy_time; // Update all batchOffsets and check if done @@ -470,13 +490,15 @@ printf("Using quantizer, and metric:%d (L2?:%d)\n", (metric), (DistMetric)metric auto end_t = std::chrono::high_resolution_clock::now(); - LOG(SPTAG::Helper::LogLevel::LL_Info, "Total times - prep time:%0.3lf, tree build:%0.3lf, neighbor compute:%0.3lf, Copy results:%0.3lf, Total runtime:%0.3lf\n", prep_time, tree_time, KNN_time, D2H_time, GET_CHRONO_TIME(start_t, end_t)); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Total times - prep time:%0.3lf, tree build:%0.3lf, neighbor compute:%0.3lf, Copy results:%0.3lf, Total runtime:%0.3lf\n", prep_time, tree_time, KNN_time, D2H_time, GET_CHRONO_TIME(start_t, end_t)); for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { tptrees[gpuNum]->destroy(); delete tptrees[gpuNum]; - cudaFree(d_tptrees[gpuNum]); - cudaFree(d_results[gpuNum]); + CUDA_CHECK(cudaFree(d_tptrees[gpuNum])); + CUDA_CHECK(cudaFree(d_results[gpuNum])); + CUDA_CHECK(cudaFree(d_data_raw[gpuNum])); + CUDA_CHECK(cudaFree(d_pointset[gpuNum])); } delete[] tptrees; delete[] d_results; @@ -485,6 +507,7 @@ auto end_t = std::chrono::high_resolution_clock::now(); delete[] d_pointset; if(use_q) { + CUDA_CHECK(cudaFree(d_quantizer)); delete h_quantizer; } @@ -516,11 +539,11 @@ void buildGraph(SPTAG::VectorIndex* index, int m_iGraphSize, int m_iNeighborhood int numDevicesOnHost; 
CUDA_CHECK(cudaGetDeviceCount(&numDevicesOnHost)); if(numDevicesOnHost < NUM_GPUS) { - LOG(SPTAG::Helper::LogLevel::LL_Error, "HeadNumGPUs parameter %d, but only %d devices available on system. Exiting.\n", NUM_GPUS, numDevicesOnHost); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "HeadNumGPUs parameter %d, but only %d devices available on system. Exiting.\n", NUM_GPUS, numDevicesOnHost); exit(1); } - LOG(SPTAG::Helper::LogLevel::LL_Info, "Building Head graph with %d GPUs...\n", NUM_GPUS); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Total of %d GPU devices on system, using %d of them.\n", numDevicesOnHost, NUM_GPUS); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Building Head graph with %d GPUs...\n", NUM_GPUS); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Total of %d GPU devices on system, using %d of them.\n", numDevicesOnHost, NUM_GPUS); /**** Compute result batch sizes for each GPU ****/ std::vector batchSize(NUM_GPUS); @@ -533,7 +556,7 @@ void buildGraph(SPTAG::VectorIndex* index, int m_iGraphSize, int m_iNeighborhood cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, gpuNum)); // Get avil. 
memory - LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name); size_t freeMem, totalMem; CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); @@ -546,22 +569,22 @@ void buildGraph(SPTAG::VectorIndex* index, int m_iGraphSize, int m_iNeighborhood resMemAvail = (freeMem*0.9) - (rawDataSize+pointSetSize+treeSize); // Only use 90% of total memory to be safe maxEltsPerBatch = resMemAvail / (dim*sizeof(T) + KVAL*sizeof(int)); - batchSize[gpuNum] = std::min(maxEltsPerBatch, resPerGPU[gpuNum]); + batchSize[gpuNum] = (std::min)(maxEltsPerBatch, resPerGPU[gpuNum]); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Memory for rawData:%lu MiB, pointSet structure:%lu MiB, Memory for TP trees:%lu MiB, Memory left for results:%lu MiB, total vectors:%lu, batch size:%d, total batches:%d\n", rawSize/1000000, pointSetSize/1000000, treeSize/1000000, resMemAvail/1000000, resPerGPU[gpuNum], batchSize[gpuNum], (((batchSize[gpuNum]-1)+resPerGPU[gpuNum]) / batchSize[gpuNum])); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Memory for rawData:%lu MiB, pointSet structure:%lu MiB, Memory for TP trees:%lu MiB, Memory left for results:%lu MiB, total vectors:%lu, batch size:%d, total batches:%d\n", rawSize/1000000, pointSetSize/1000000, treeSize/1000000, resMemAvail/1000000, resPerGPU[gpuNum], batchSize[gpuNum], (((batchSize[gpuNum]-1)+resPerGPU[gpuNum]) / batchSize[gpuNum])); // If GPU memory is insufficient or so limited that we need so many batches it becomes inefficient, return error if(batchSize[gpuNum] == 0 || ((int)resPerGPU[gpuNum]) / batchSize[gpuNum] > 10000) { - LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient GPU memory to build Head index on GPU %d. 
Available GPU memory:%lu MB, Points and tpt require:%lu MB, leaving a maximum batch size of %d results to be computed, which is too small to run efficiently.\n", gpuNum, (freeMem)/1000000, (rawDataSize+pointSetSize+treeSize)/1000000, maxEltsPerBatch); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient GPU memory to build Head index on GPU %d. Available GPU memory:%lu MB, Points and tpt require:%lu MB, leaving a maximum batch size of %d results to be computed, which is too small to run efficiently.\n", gpuNum, (freeMem)/1000000, (rawDataSize+pointSetSize+treeSize)/1000000, maxEltsPerBatch); exit(1); } } std::vector GPUOffset(NUM_GPUS); GPUOffset[0] = 0; - LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU 0: results:%lu, offset:%lu\n", resPerGPU[0], GPUOffset[0]); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU 0: results:%lu, offset:%lu\n", resPerGPU[0], GPUOffset[0]); for(int gpuNum=1; gpuNum < NUM_GPUS; ++gpuNum) { GPUOffset[gpuNum] = GPUOffset[gpuNum-1] + resPerGPU[gpuNum-1]; - LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d: results:%lu, offset:%lu\n", gpuNum, resPerGPU[gpuNum], GPUOffset[gpuNum]); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d: results:%lu, offset:%lu\n", gpuNum, resPerGPU[gpuNum], GPUOffset[gpuNum]); } if(index->m_pQuantizer != NULL) { @@ -574,7 +597,7 @@ void buildGraph(SPTAG::VectorIndex* index, int m_iGraphSize, int m_iNeighborhood buildGraphGPU(index, (size_t)m_iGraphSize, KVAL, trees, results, graph, leafSize, NUM_GPUS, balanceFactor, batchSize.data(), GPUOffset.data(), resPerGPU.data(), dim); } else { - LOG(SPTAG::Helper::LogLevel::LL_Error, "Selected datatype not currently supported.\n"); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Selected datatype not currently supported.\n"); exit(1); } } @@ -675,4 +698,376 @@ __global__ void findKNN_leaf_nodes(Point* data, TPtree >& batch_truth, std::vector< std::vector >& batch_dist, std::vector< std::vector >temp_truth, std::vector< std::vector >temp_dist, + int 
result_size, int K) { + //SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Entered updateFile.\n"); + std::vector< std::vector > result_truth = batch_truth; + std::vector< std::vector > result_dist = batch_dist; + for (int i = 0; i < result_size; i++) { + //For each result, we need to compare the 2K DistPairs and take top K. + int reader1 = 0; + int reader2 = 0; + for (int writer_ptr = 0; writer_ptr < K; writer_ptr++) { + if (batch_dist[i][reader1] <= temp_dist[i][reader2]) { + result_truth[i][writer_ptr] = batch_truth[i][reader1]; + result_dist[i][writer_ptr] = batch_dist[i][reader1++]; + } + else { + result_truth[i][writer_ptr] = temp_truth[i][reader2]; + result_dist[i][writer_ptr] = temp_dist[i][reader2++]; + } + } + } + batch_truth = result_truth; + batch_dist = result_dist; +} + +template +__global__ void query_KNN(Point* querySet, Point* data, int dataSize, int idx_offset, int numQueries, DistPair* results, int KVAL, int metric) { + extern __shared__ char sharememory[]; + __shared__ ThreadHeap heapMem[BLOCK_DIM]; + DistPair extra; // extra variable to store the largest distance/id for all KNN of the point + //TransposePoint query; // Stores in strided memory to avoid bank conflicts + //__shared__ DTYPE transpose_mem[Dim * BLOCK_DIM]; + Point query; + + //if (cuda::std::is_same::value || cuda::std::is_same::value) { + // query.setMem(&transpose_mem[threadIdx.x*4]); + //} + //else { + // query.setMem(&transpose_mem[threadIdx.x]); + //} + + heapMem[threadIdx.x].initialize(&((DistPair*)sharememory)[(KVAL - 1) * threadIdx.x], KVAL - 1); + + SUMTYPE dist; + // Loop through all query points + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numQueries; i += blockDim.x * gridDim.x) { + heapMem[threadIdx.x].reset(); + extra.dist = INFTY(); + query=querySet[i]; + //query.loadPoint(querySet[i]); // Load into shared memory + // Compare with all points in the dataset + for (int j = 0; j < dataSize; j++) { + if (metric == 0) { + dist = query.l2(&data[j]); + } + else 
if (metric == 1){ + dist = query.cosine(&data[j]); + } + if (dist < extra.dist) { + if (dist < heapMem[threadIdx.x].top()) { + extra.dist = heapMem[threadIdx.x].vals[0].dist; + extra.idx = heapMem[threadIdx.x].vals[0].idx + idx_offset; + //When recording the index, remember the off_set to discriminate different subvectorSets + heapMem[threadIdx.x].insert(dist, j + idx_offset); + } + else { + extra.dist = dist; + extra.idx = j + idx_offset; + } + } + } + // Write KNN to result list in sorted order + results[(i + 1) * KVAL - 1].idx = extra.idx; + results[(i + 1) * KVAL - 1].dist = extra.dist; + for (int j = KVAL - 2; j >= 0; j--) { + results[i * KVAL + j].idx = heapMem[threadIdx.x].vals[0].idx; + results[i * KVAL + j].dist = heapMem[threadIdx.x].vals[0].dist; + heapMem[threadIdx.x].vals[0].dist = -1; + heapMem[threadIdx.x].heapify(); + + } + } + +} + +template +__host__ void GenerateTruthGPUCore(std::shared_ptr querySet, std::shared_ptr vectorSet, const std::string truthFile, + const SPTAG::DistCalcMethod distMethod, const int K, const SPTAG::TruthFileType p_truthFileType, const std::shared_ptr& quantizer, + std::vector< std::vector >&truthset, std::vector< std::vector >&distset) { + int numDevicesOnHost; + CUDA_CHECK(cudaGetDeviceCount(&numDevicesOnHost)); + int NUM_GPUS = numDevicesOnHost; + int metric = distMethod == DistCalcMethod::Cosine ? 
1 : 0; + //const int NUM_GPUS = 4; + + std::vector< std::vector > temp_truth = truthset; + std::vector< std::vector > temp_dist = distset; + + int result_size = querySet->Count(); + int vector_size = vectorSet->Count(); + int dim = querySet->Dimension(); + auto vectors = (DTYPE*)vectorSet->GetData(); + int per_gpu_result = (result_size)*K; + + LOG_INFO("QueryDatatype: %s, Rows:%ld, Columns:%d\n", (STR(DTYPE)), result_size, dim); + LOG_INFO("Datatype: %s, Rows:%ld, Columns:%d\n", (STR(DTYPE)), vector_size, dim); + + //Assign vectors to every GPU + std::vector vectorsPerGPU(NUM_GPUS); + for (int gpuNum = 0; gpuNum < NUM_GPUS; ++gpuNum) { + vectorsPerGPU[gpuNum] = vector_size / NUM_GPUS; + if (vector_size % NUM_GPUS > gpuNum) vectorsPerGPU[gpuNum]++; + } + //record the offset in different GPU + std::vector GPUOffset(NUM_GPUS); + GPUOffset[0] = 0; + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU 0: results:%lu, offset:%lu\n", vectorsPerGPU[0], GPUOffset[0]); + for (int gpuNum = 1; gpuNum < NUM_GPUS; ++gpuNum) { + GPUOffset[gpuNum] = GPUOffset[gpuNum - 1] + vectorsPerGPU[gpuNum - 1]; + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d: results:%lu, offset:%lu\n", gpuNum, vectorsPerGPU[gpuNum], GPUOffset[gpuNum]); + } + + std::vector streams(NUM_GPUS); + std::vector batchSize(NUM_GPUS); + + cudaError_t resultErr; + + //KNN + Point* points = convertMatrix < DTYPE, SUMTYPE, MAX_DIM >((DTYPE*)querySet->GetData(), result_size, dim); + Point** d_points = new Point*[NUM_GPUS]; + Point** d_check_points = new Point*[NUM_GPUS]; + Point** sub_vectors_points = new Point*[NUM_GPUS]; + DistPair** d_results = new DistPair*[NUM_GPUS]; // malloc space for each gpu to save bf result + + //Point** d_points = new Point*[NUM_GPUS]; + //TPtree** tptrees = new TPtree*[NUM_GPUS]; + //TPtree** d_tptrees = new TPtree*[NUM_GPUS]; + //int** d_results = new int* [NUM_GPUS]; + // Calculate GPU memory usage on each device + for (int gpuNum = 0; gpuNum < NUM_GPUS; gpuNum++) { + 
CUDA_CHECK(cudaSetDevice(gpuNum)); + CUDA_CHECK(cudaStreamCreate(&streams[gpuNum])); + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, gpuNum)); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name); + + size_t freeMem, totalMem; + CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); + + size_t queryPointSize = querySet->Count() * sizeof(Point); + size_t resultSetSize = querySet->Count() * K * 4;//size(int) = 4 + size_t resMemAvail = (freeMem * 0.9) - (queryPointSize + resultSetSize); // Only use 90% of total memory to be safe + size_t cpuSaveMem = resMemAvail / NUM_GPUS;//using 90% of multi-GPUs might be to much for CPU. + int maxEltsPerBatch = cpuSaveMem / (sizeof(Point)); + batchSize[gpuNum] = min(maxEltsPerBatch, (int)(vectorsPerGPU[gpuNum])); + // If GPU memory is insufficient or so limited that we need so many batches it becomes inefficient, return error + if (batchSize[gpuNum] == 0 || ((int)vectorsPerGPU[gpuNum]) / batchSize[gpuNum] > 10000) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient GPU memory to build Head index on GPU %d. 
Available GPU memory:%lu MB, Points and resultsSet require: %lu MB, leaving a maximum batch size of %d results to be computed, which is too small to run efficiently.\n", gpuNum, (freeMem) / 1000000, (queryPointSize + resultSetSize) / 1000000, maxEltsPerBatch); + exit(1); + } + + size_t total_batches = ((batchSize[gpuNum] - 1) + vectorsPerGPU[gpuNum]) / batchSize[gpuNum]; + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Memory for query Points vectors:%lu MiB, Memory for resultSet:%lu MiB, Memory left for vectors:%lu MiB, total vectors:%lu, batch size:%d, total batches:%lu\n", queryPointSize / 1000000, resultSetSize / 1000000, resMemAvail / 1000000, vectorsPerGPU[gpuNum], batchSize[gpuNum], total_batches); + + } + + int KNN_blocks = querySet->Count() / KNN_THREADS; + std::vector curr_batch_size(NUM_GPUS); + std::vector batchOffset(NUM_GPUS); + for (int gpuNum = 0; gpuNum < NUM_GPUS; ++gpuNum) { + batchOffset[gpuNum] = 0; + } + + // Allocate space on gpu + for (int i = 0; i < NUM_GPUS; i++) { + cudaSetDevice(i); + cudaStreamCreate(&streams[i]); // Copy data over on a separate stream for each GPU + // Device result memory on each GPU + CUDA_CHECK(cudaMalloc(&d_results[i], per_gpu_result * sizeof(DistPair))); + + //Copy Queryvectors + CUDA_CHECK(cudaMalloc(&d_points[i], querySet->Count() * sizeof(Point))); + CUDA_CHECK(cudaMemcpyAsync(d_points[i], points, querySet->Count() * sizeof(Point), cudaMemcpyHostToDevice, streams[i])); + //Batchvectors + CUDA_CHECK(cudaMalloc(&d_check_points[i], batchSize[i] * sizeof(Point))); + } + + LOG_INFO("Starting KNN Kernel timer\n"); + auto t1 = std::chrono::high_resolution_clock::now(); + bool done = false; + bool update = false; + //RN, The querys are copied, and the space for result are malloced. + // Need to copy the result, and copy back to host to update. + // The vectors split, malloc, copy neeed to be done in here. 
+ while (!done) { // Continue until all GPUs have completed all of their batches + size_t freeMem, totalMem; + CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Avaliable Memory out of %lu MiB total Memory for GPU 0 is %lu MiB\n", totalMem / 1000000, freeMem / 1000000); + + // Prep next batch for each GPU + for (int gpuNum = 0; gpuNum < NUM_GPUS; ++gpuNum) { + curr_batch_size[gpuNum] = batchSize[gpuNum]; + // Check if final batch is smaller than previous + if (batchOffset[gpuNum] + batchSize[gpuNum] > vectorsPerGPU[gpuNum]) { + curr_batch_size[gpuNum] = vectorsPerGPU[gpuNum] - batchOffset[gpuNum]; + } + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - starting batch with offset:%lu, with GPU offset:%lu, size:%lu, TailRows:%lld\n", gpuNum, batchOffset[gpuNum], GPUOffset[gpuNum], curr_batch_size[gpuNum], vectorsPerGPU[gpuNum] -batchOffset[gpuNum]); + } + + CUDA_CHECK(cudaDeviceSynchronize()); + + for (int i = 0; i < NUM_GPUS; i++) { + CUDA_CHECK(cudaSetDevice(i)); + + size_t start = GPUOffset[i] + batchOffset[i]; + sub_vectors_points[i] = convertMatrix < DTYPE, SUMTYPE, MAX_DIM >(&vectors[start * dim], curr_batch_size[i], dim); + CUDA_CHECK(cudaMemcpyAsync(d_check_points[i], sub_vectors_points[i], curr_batch_size[i] * sizeof(Point), cudaMemcpyHostToDevice, streams[i])); + + // Perfrom brute-force KNN from the subsets assigned to the GPU for the querySets + int KNN_blocks = (KNN_THREADS - 1 + querySet->Count()) / KNN_THREADS; + size_t dynamicSharedMem = KNN_THREADS * sizeof(DistPair < SUMTYPE>) * (K - 1); + //64*9*8=4608 for K=10 , KNN_Threads = 64 + //64*99*8=50688 for K = 100, KNN_Threads = 64 + if (dynamicSharedMem > (1024 * 48)) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Cannot Launch CUDA kernel on %d, because of using to much shared memory size, %zu\n", i, dynamicSharedMem); + exit(1); + } + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Launching kernel on %d\n", i); + 
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Launching Parameters: KNN_blocks = %d, KNN_Thread = %d , dynamicSharedMem = %d, \n", KNN_blocks, KNN_THREADS, dynamicSharedMem); + + query_KNN << > > (d_points[i], d_check_points[i], curr_batch_size[i], start, result_size, d_results[i], K, metric); + cudaError_t c_ret = cudaGetLastError(); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Error: %s\n", cudaGetErrorString(c_ret)); + } + + //Copy back to result + for (int gpuNum = 0; gpuNum < NUM_GPUS; gpuNum++) { + cudaSetDevice(gpuNum); + cudaDeviceSynchronize(); + } + DistPair* results = (DistPair*)malloc(per_gpu_result * sizeof(DistPair) * NUM_GPUS); + for (int gpuNum = 0; gpuNum < NUM_GPUS; gpuNum++) { + cudaMemcpy(&results[(gpuNum * per_gpu_result)], d_results[gpuNum], per_gpu_result * sizeof(DistPair), cudaMemcpyDeviceToHost); + } + + // To combine the results, we need to compare value from different GPU + //Results [KNN from GPU0][KNN from GPU1][KNN from GPU2][KNN from GPU3] + // Every block will be [result_size*K] [K|K|K|...|K] + // The real KNN is selected from 4*KNN + size_t* reader_ptrs = (size_t*)malloc(NUM_GPUS * sizeof(size_t)); + for (size_t i = 0; i < result_size; i++) { + //For each result, there should be writer pointer, four reader pointers + // reader ptr is assigned to i+gpu_idx*(result_size*K)/4 + for (size_t gpu_idx = 0; gpu_idx < NUM_GPUS; gpu_idx++) { + reader_ptrs[gpu_idx] = i * K + gpu_idx * (result_size * K); + } + for (size_t writer_ptr = i * K; writer_ptr < (i + 1) * K; writer_ptr++) { + int selected = 0; + for (size_t gpu_idx = 1; gpu_idx < NUM_GPUS; gpu_idx++) { + //if the other gpu has shorter dist + if (results[reader_ptrs[selected]].dist > results[reader_ptrs[gpu_idx]].dist) { + selected = gpu_idx; + } + } + temp_truth[i][writer_ptr - (i * K)] = results[reader_ptrs[selected]].idx; + temp_dist[i][writer_ptr - (i * K)] = results[reader_ptrs[selected]++].dist; + } + } + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Collected all the 
result from GPUs.\n"); + //update results, just assign to truthset for first batch + if (update) { + updateKNNResults(truthset, distset, temp_truth, temp_dist, result_size, K); + } + else { + truthset = temp_truth; + distset = temp_dist; + update = true; + } + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Updated batch result.\n"); + // Update all batchOffsets and check if done + done = true; + for (int gpuNum = 0; gpuNum < NUM_GPUS; ++gpuNum) { + batchOffset[gpuNum] += curr_batch_size[gpuNum]; + if (batchOffset[gpuNum] < vectorsPerGPU[gpuNum]) { + done = false; + } + //avoid memory leak + free(sub_vectors_points[gpuNum]); + } + free(results); + } // Batches loop (while !done) + + auto t2 = std::chrono::high_resolution_clock::now(); + double gpuRunTime = std::chrono::duration_cast(t2 - t1).count(); + double gpuRunTimeMs = std::chrono::duration_cast(t2 - t1).count(); + + LOG_ALL("Total GPU time (sec): %ld.%ld\n", (int)gpuRunTime, (((int)gpuRunTimeMs)%1000)); +} + +template +void GenerateTruthGPU(std::shared_ptr querySet, std::shared_ptr vectorSet, const std::string truthFile, + const SPTAG::DistCalcMethod distMethod, const int K, const SPTAG::TruthFileType p_truthFileType, const std::shared_ptr& quantizer, + std::vector< std::vector > &truthset, std::vector< std::vector > &distset) { + using SUMTYPE = std::conditional_t::value, float, int32_t>; + int m_iFeatureDim = querySet->Dimension(); + if (typeid(T) == typeid(float)) { + if (m_iFeatureDim <= 64) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 100) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 128) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 184) { + GenerateTruthGPUCore(querySet, vectorSet, 
truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 384) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 768) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + //else if (m_iFeatureDim <= 1024) { + // GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + //} + //else if (m_iFeatureDim <= 2048) { + // GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + //} + //else if (m_iFeatureDim <= 4096) { + // GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + //} + else { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%d dimensions not currently supported for GPU generate Truth.\n"); + exit(1); + } + } + else if (typeid(T) == typeid(uint8_t) || typeid(T) == typeid(int8_t)) { + if (m_iFeatureDim <= 64) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 100) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 128) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 184) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 384) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + } + else if (m_iFeatureDim <= 768) { + GenerateTruthGPUCore(querySet, vectorSet, truthFile, distMethod, K, 
p_truthFileType, quantizer, truthset, distset); + } + else { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%d dimensions not currently supported for GPU generate Truth.\n"); + exit(1); + } + } + else { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Selected datatype not currently supported.\n"); + exit(1); + } +} +#define DefineVectorValueType(Name, Type) template void GenerateTruthGPU(std::shared_ptr querySet, std::shared_ptr vectorSet, const std::string truthFile, const SPTAG::DistCalcMethod distMethod, const int K, const SPTAG::TruthFileType p_truthFileType, const std::shared_ptr& quantizer, std::vector< std::vector > &truthset, std::vector< std::vector > &distset); +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + #endif diff --git a/AnnService/inc/Core/Common/cuda/PerfTestGPU.hxx b/AnnService/inc/Core/Common/cuda/PerfTestGPU.hxx new file mode 100644 index 000000000..52808d0d0 --- /dev/null +++ b/AnnService/inc/Core/Common/cuda/PerfTestGPU.hxx @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + * Licensed under the MIT License. + */ + +#ifndef _SPTAG_COMMON_CUDA_PERFTEST_H_ +#define _SPTAG_COMMON_CUDA_PERFTEST_H_ + +#include "Refine.hxx" +#include "log.hxx" +#include "ThreadHeap.hxx" +#include "TPtree.hxx" +#include "GPUQuantizer.hxx" + +#include +#include + + +/*************************************************************************************** + * Function called by SPTAG to create an initial graph on the GPU. + ***************************************************************************************/ +template +void benchmarkDist(SPTAG::VectorIndex* index, int m_iGraphSize, int m_iNeighborhoodSize, int trees, int* results, int refines, int refineDepth, int graph, int leafSize, int initSize, int NUM_GPUS, int balanceFactor) { + + int m_disttype = (int)index->GetDistCalcMethod(); + size_t dataSize = (size_t)m_iGraphSize; + int KVAL = m_iNeighborhoodSize; + int dim = index->GetFeatureDim(); + size_t rawSize = dataSize*dim; + + if(index->m_pQuantizer != NULL) { + SPTAG::COMMON::PQQuantizer* pq_quantizer = (SPTAG::COMMON::PQQuantizer*)index->m_pQuantizer.get(); + dim = pq_quantizer->GetNumSubvectors(); + } + +// srand(time(NULL)); // random number seed for TP tree random hyperplane partitions + srand(1); // random number seed for TP tree random hyperplane partitions + +/******************************* + * Error checking +********************************/ + int numDevicesOnHost; + CUDA_CHECK(cudaGetDeviceCount(&numDevicesOnHost)); + if(numDevicesOnHost < NUM_GPUS) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "HeadNumGPUs parameter %d, but only %d devices available on system. 
Exiting.\n", NUM_GPUS, numDevicesOnHost); + exit(1); + } + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Building Head graph with %d GPUs...\n", NUM_GPUS); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Total of %d GPU devices on system, using %d of them.\n", numDevicesOnHost, NUM_GPUS); + +/**** Compute result batch sizes for each GPU ****/ + std::vector batchSize(NUM_GPUS); + std::vector resPerGPU(NUM_GPUS); + for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { + CUDA_CHECK(cudaSetDevice(gpuNum)); + + resPerGPU[gpuNum] = dataSize / NUM_GPUS; // Results per GPU + if(dataSize % NUM_GPUS > gpuNum) resPerGPU[gpuNum]++; + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, gpuNum)); // Get avil. memory + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name); + + size_t freeMem, totalMem; + CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); + + size_t rawDataSize, pointSetSize, treeSize, resMemAvail, maxEltsPerBatch; + + rawDataSize = rawSize*sizeof(T); + pointSetSize = sizeof(PointSet); + treeSize = 20*dataSize; + resMemAvail = (freeMem*0.9) - (rawDataSize+pointSetSize+treeSize); // Only use 90% of total memory to be safe + maxEltsPerBatch = resMemAvail / (dim*sizeof(T) + KVAL*sizeof(int)); + + batchSize[gpuNum] = (std::min)(maxEltsPerBatch, resPerGPU[gpuNum]); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Memory for rawData:%lu MiB, pointSet structure:%lu MiB, Memory for TP trees:%lu MiB, Memory left for results:%lu MiB, total vectors:%lu, batch size:%d, total batches:%d\n", rawSize/1000000, pointSetSize/1000000, treeSize/1000000, resMemAvail/1000000, resPerGPU[gpuNum], batchSize[gpuNum], (((batchSize[gpuNum]-1)+resPerGPU[gpuNum]) / batchSize[gpuNum])); + + // If GPU memory is insufficient or so limited that we need so many batches it becomes inefficient, return error + if(batchSize[gpuNum] == 0 || ((int)resPerGPU[gpuNum]) / batchSize[gpuNum] > 10000) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient 
GPU memory to build Head index on GPU %d. Available GPU memory:%lu MB, Points and tpt require:%lu MB, leaving a maximum batch size of %d results to be computed, which is too small to run efficiently.\n", gpuNum, (freeMem)/1000000, (rawDataSize+pointSetSize+treeSize)/1000000, maxEltsPerBatch); + exit(1); + } + } + std::vector GPUOffset(NUM_GPUS); + GPUOffset[0] = 0; + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU 0: results:%lu, offset:%lu\n", resPerGPU[0], GPUOffset[0]); + for(int gpuNum=1; gpuNum < NUM_GPUS; ++gpuNum) { + GPUOffset[gpuNum] = GPUOffset[gpuNum-1] + resPerGPU[gpuNum-1]; + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d: results:%lu, offset:%lu\n", gpuNum, resPerGPU[gpuNum], GPUOffset[gpuNum]); + } + + if(index->m_pQuantizer != NULL) { + buildGraphGPU(index, (size_t)m_iGraphSize, KVAL, trees, results, graph, leafSize, NUM_GPUS, balanceFactor, batchSize.data(), GPUOffset.data(), resPerGPU.data(), dim); + } + else if(typeid(T) == typeid(float)) { + buildGraphGPU(index, (size_t)m_iGraphSize, KVAL, trees, results, graph, leafSize, NUM_GPUS, balanceFactor, batchSize.data(), GPUOffset.data(), resPerGPU.data(), dim); + } + else if(typeid(T) == typeid(uint8_t) || typeid(T) == typeid(int8_t)) { + buildGraphGPU(index, (size_t)m_iGraphSize, KVAL, trees, results, graph, leafSize, NUM_GPUS, balanceFactor, batchSize.data(), GPUOffset.data(), resPerGPU.data(), dim); + } + else { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Selected datatype not currently supported.\n"); + exit(1); + } +} + + +#endif diff --git a/AnnService/inc/Core/Common/cuda/Refine.hxx b/AnnService/inc/Core/Common/cuda/Refine.hxx index 7ed695559..753b57480 100644 --- a/AnnService/inc/Core/Common/cuda/Refine.hxx +++ b/AnnService/inc/Core/Common/cuda/Refine.hxx @@ -396,7 +396,7 @@ void refineGraphGPU(SPTAG::VectorIndex* index, Point* d_point cudaMemcpy(d_candidates, candidates, dataSize*candidatesPerVector*sizeof(int), cudaMemcpyHostToDevice); auto t2 = 
std::chrono::high_resolution_clock::now(); - LOG(SPTAG::Helper::LogLevel::LL_Info, "find candidates time (ms): %lld\n", std::chrono::duration_cast(t2 - t1).count()); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "find candidates time (ms): %lld\n", std::chrono::duration_cast(t2 - t1).count()); // Use a number of batches of refinement to overlap CPU and GPU work (TODO) int NUM_BATCHES = 1; @@ -414,13 +414,13 @@ void refineGraphGPU(SPTAG::VectorIndex* index, Point* d_point refineBatch_kernel<<>>(d_points, batch_size, i*batch_size, d_graph, d_candidates, listMem, candidatesPerVector, KVAL, refineDepth, metric); cudaError_t status = cudaDeviceSynchronize(); if(status != cudaSuccess) { - LOG(SPTAG::Helper::LogLevel::LL_Error, "Refine error code:%d\n", status); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Refine error code:%d\n", status); } } } t2 = std::chrono::high_resolution_clock::now(); - LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU refine time (ms): %lld\n", std::chrono::duration_cast(t2 - t1).count()); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU refine time (ms): %lld\n", std::chrono::duration_cast(t2 - t1).count()); cudaFree(listMem); cudaFree(d_candidates); diff --git a/AnnService/inc/Core/Common/cuda/TPtree.hxx b/AnnService/inc/Core/Common/cuda/TPtree.hxx index 4bbf2471d..e3be6c86e 100644 --- a/AnnService/inc/Core/Common/cuda/TPtree.hxx +++ b/AnnService/inc/Core/Common/cuda/TPtree.hxx @@ -33,8 +33,11 @@ #include #include "params.h" #include "Distance.hxx" +#include "GPUKNNDistance.hxx" class TPtree; +//template +//class TPtree; /************************************************************************************ * Structure that defines the memory locations where points/ids are for a leaf node @@ -80,7 +83,7 @@ __global__ void assign_leaf_points_out_batch(LeafNode* leafs, int* leaf_points, * Set of functions to compute mean to pick dividing hyperplanes ************************************************************************************/ 
template -__global__ void find_level_sum(PointSet* points, KEYTYPE* weights, int Dim, int* node_ids, KEYTYPE* split_keys, int* node_sizes, int N, int nodes_on_level, int level); +__global__ void find_level_sum(PointSet* points, KEYTYPE* weights, int Dim, int* node_ids, KEYTYPE* split_keys, int* node_sizes, int N, int nodes_on_level, int level, int size); /***************************************************************************************** @@ -203,21 +206,20 @@ class TPtree { // For debugging purposes ************************************************************************************/ __host__ void print_tree() { - printf("nodes:%d, leaves:%d, levels:%d\n", num_nodes, num_leaves, levels); int level_offset; print_level_device<<<1,1>>>(node_sizes, split_keys, 1, leafs, leaf_points); -/* + for(int i=0; i>>(node_sizes+(level_offset), split_keys+level_offset, level_offset+1, leafs, leaf_points); CUDA_CHECK(cudaDeviceSynchronize()); -*/ + // for(int j=0; j __host__ void construct_trees_multigpu(TPtree** d_trees, PointSet** ps, int N, int NUM_GPUS, cudaStream_t* streams, int balanceFactor) { int nodes_on_level=1; + int sample_size; const int RUN_BLOCKS = min(N/THREADS, BLOCKS); const int RAND_BLOCKS = min(N/THREADS, 1024); // Use fewer blocks for kernels using random numbers to cut down memory usage @@ -239,15 +242,18 @@ __host__ void construct_trees_multigpu(TPtree** d_trees, PointSet** ps, int N CUDA_CHECK(cudaMalloc(&states[gpuNum], RAND_BLOCKS*THREADS*sizeof(curandState))); initialize_rands<<>>(states[gpuNum], 0); } - + for(int i=0; ilevels; ++i) { + + sample_size = min(N, nodes_on_level*SAMPLES); // number of samples to use to compute level sums for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { cudaSetDevice(gpuNum); - find_level_sum<<>>(ps[gpuNum], d_trees[gpuNum]->weight_list, d_trees[gpuNum]->Dim, d_trees[gpuNum]->node_ids, d_trees[gpuNum]->split_keys, d_trees[gpuNum]->node_sizes, N, nodes_on_level, i); + find_level_sum<<>>(ps[gpuNum], 
d_trees[gpuNum]->weight_list, d_trees[gpuNum]->Dim, d_trees[gpuNum]->node_ids, d_trees[gpuNum]->split_keys, d_trees[gpuNum]->node_sizes, N, nodes_on_level, i, sample_size); } -/* TODO - fix rebalancing +// TODO - fix rebalancing +/* // Check and rebalance all levels beyond the first (first level has only 1 node) if(i > 0) { for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { @@ -266,9 +272,13 @@ __host__ void construct_trees_multigpu(TPtree** d_trees, PointSet** ps, int N cudaSetDevice(gpuNum); compute_mean<<>>(d_trees[gpuNum]->split_keys, d_trees[gpuNum]->node_sizes, d_trees[gpuNum]->num_nodes); +CUDA_CHECK(cudaDeviceSynchronize()); update_node_assignments<<>>(ps[gpuNum], d_trees[gpuNum]->weight_list, d_trees[gpuNum]->node_ids, d_trees[gpuNum]->split_keys, d_trees[gpuNum]->node_sizes, N, i, d_trees[gpuNum]->Dim); + +CUDA_CHECK(cudaDeviceSynchronize()); } + nodes_on_level*=2; } @@ -310,12 +320,135 @@ __host__ void construct_trees_multigpu(TPtree** d_trees, PointSet** ps, int N cudaSetDevice(gpuNum); assign_leaf_points_out_batch<<>>(d_trees[gpuNum]->leafs, d_trees[gpuNum]->leaf_points, d_trees[gpuNum]->node_ids, N, d_trees[gpuNum]->num_nodes - d_trees[gpuNum]->num_leaves, 0, N); } +} + +/* +template +__host__ void find_level_sum_batch_PQ(TPtree** d_trees, PointSet** ps, int N, int NUM_GPUS, int nodes_on_level, int level, SPTAG::VectorIndex* index, size_t recon_batch_size) { + + const int RUN_BLOCKS = min(N/THREADS, BLOCKS); + const int RAND_BLOCKS = min(N/THREADS, 1024); // Use fewer blocks for kernels using random numbers to cut down memory usage + + size_t reconDim = index->m_pQuantizer->ReconstructDim(); + R* recon_coords = new R[recon_batch_size*reconDim]; + + R** d_recon_coords = new R*[NUM_GPUS]; // raw reconstructed coordinates on GPU + PointSet** d_recon_ps = new PointSet*[NUM_GPUS]; + + for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { + cudaSetDevice(gpuNum); + + } + for(size_t i=0; im_pQuantizer->ReconstructVector((const uint8_t*)ps->getVec(i+j), 
&recon_coords[j]); + } + CUDA_CHECK(cudaMemcpy(&d_ + } + +// find_level_sum<<>>(ps[gpuNum], d_trees[gpuNum]->weight_list, d_trees[gpuNum]->Dim, d_trees[gpuNum]->node_ids, d_trees[gpuNum]->split_keys, d_trees[gpuNum]->node_sizes, N, nodes_on_level, level); + + delete recon_coords; + +} +*/ + +template +__host__ void construct_trees_PQ(TPtree** d_trees, PointSet** ps, int N, int NUM_GPUS, cudaStream_t* streams, SPTAG::VectorIndex* index) { + + size_t reconDim = index->m_pQuantizer->ReconstructDim(); + PointSet temp_ps; + temp_ps.dim = reconDim; + R* h_recon_raw = new R[N*reconDim]; + + for(int i=0; im_pQuantizer->ReconstructVector((const uint8_t*)(index->GetSample(i)), &h_recon_raw[i*reconDim]); + } + + R** d_recon_raw = new R*[NUM_GPUS]; + PointSet** d_recon_ps = new PointSet*[NUM_GPUS]; + + for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { + CUDA_CHECK(cudaSetDevice(gpuNum)); + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, gpuNum)); // Get avail. memory + size_t freeMem, totalMem; + CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); + size_t neededMem = N*reconDim*sizeof(R); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Memory needed for reconstructed vectors to build TPT: %ld, memory available: %ld\n", neededMem, totalMem); + if(freeMem*0.9 < neededMem) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient memory for reconstructed vectors to build TPTree.\n"); + exit(1); + } + + CUDA_CHECK(cudaMalloc(&d_recon_raw[gpuNum], N*reconDim*sizeof(R))); + CUDA_CHECK(cudaMemcpy(d_recon_raw[gpuNum], h_recon_raw, N*reconDim*sizeof(R), cudaMemcpyHostToDevice)); + temp_ps.data = d_recon_raw[gpuNum]; + CUDA_CHECK(cudaMalloc(&d_recon_ps[gpuNum], sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_recon_ps[gpuNum], &temp_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + } + + construct_trees_multigpu(d_trees, d_recon_ps, N, NUM_GPUS, streams, 0); + CUDA_CHECK(cudaDeviceSynchronize()); + + for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { + 
CUDA_CHECK(cudaFree(d_recon_raw[gpuNum])); + CUDA_CHECK(cudaFree(d_recon_ps[gpuNum])); + } + delete[] d_recon_raw; + delete[] d_recon_ps; + delete[] h_recon_raw; +/* +int min_size=99999999; +int max_size=0; +for(int i=0; inum_leaves; ++i) { + if(d_trees[0]->leafs[i].size > max_size) max_size = d_trees[0]->leafs[i].size; + if(d_trees[0]->leafs[i].size < min_size) min_size = d_trees[0]->leafs[i].size; +} +printf("Num leaves:%d, min leaf:%d, max leaf:%d\n", d_trees[0]->num_leaves, min_size, max_size); +*/ +/* + int nodes_on_level=1; + + const int RUN_BLOCKS = min(N/THREADS, BLOCKS); + const int RAND_BLOCKS = min(N/THREADS, 1024); // Use fewer blocks for kernels using random numbers to cut down memory usage + + float** frac_to_move = new float*[NUM_GPUS]; + curandState** states = new curandState*[NUM_GPUS]; + + recon_batch_size = N; + + // Prepare random number generator + for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { + CUDA_CHECK(cudaSetDevice(gpuNum)); + CUDA_CHECK(cudaMalloc(&frac_to_move[gpuNum], d_trees[0]->num_nodes*sizeof(float))); + CUDA_CHECK(cudaMalloc(&states[gpuNum], RAND_BLOCKS*THREADS*sizeof(curandState))); + initialize_rands<<>>(states[gpuNum], 0); + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, gpuNum)); // Get avail. 
memory + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name); + + size_t freeMem, totalMem; + CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); + + size_t gpu_batch_size = (freeMem*0.9) / sizeof(R); + if(gpu_batch_size < recon_batch_size) recon_batch_size = gpu_batch_size; // Use batch size of smallest GPU + } +*/ +// Get all reconstructed vectors on CPU +// for(int i=0; ilevels; ++i) { +// find_level_sum_batch_PQ(d_trees, ps, N, NUM_GPUS, nodes_on_level, i, index, recon_batch_size); +// } } template -__host__ void create_tptree_multigpu(TPtree** d_trees, PointSet** ps, int N, int MAX_LEVELS, int NUM_GPUS, cudaStream_t* streams, int balanceFactor) { +__host__ void create_tptree_multigpu(TPtree** d_trees, PointSet** ps, int N, int MAX_LEVELS, int NUM_GPUS, cudaStream_t* streams, int balanceFactor, SPTAG::VectorIndex* index) { KEYTYPE* h_weights = new KEYTYPE[d_trees[0]->levels*d_trees[0]->Dim]; for(int i=0; ilevels*d_trees[0]->Dim; ++i) { @@ -330,7 +463,20 @@ __host__ void create_tptree_multigpu(TPtree** d_trees, PointSet** ps, int N, } // Build TPT on each GPU - construct_trees_multigpu(d_trees, ps, N, NUM_GPUS, streams, balanceFactor); +// construct_trees_multigpu(d_trees, ps, N, NUM_GPUS, streams, balanceFactor); + + if(index == NULL || index->m_pQuantizer == NULL) { // Build directly if no quantizer + construct_trees_multigpu(d_trees, ps, N, NUM_GPUS, streams, balanceFactor); + } + else { + VectorValueType reconType = index->m_pQuantizer->GetReconstructType(); + if(reconType == SPTAG::VectorValueType::Float) { + construct_trees_PQ(d_trees, ps, N, NUM_GPUS, streams, index); + } + else if (reconType == SPTAG::VectorValueType::Int8) { + construct_trees_PQ(d_trees, ps, N, NUM_GPUS, streams, index); + } + } delete h_weights; } @@ -401,12 +547,11 @@ __device__ KEY_T weighted_val(T* data, KEYTYPE* weights, int Dims, bool print) { */ template -__global__ void find_level_sum(PointSet* ps, KEYTYPE* weights, int Dim, int* node_ids, KEYTYPE* 
split_keys, int* node_sizes, int N, int nodes_on_level, int level) { +__global__ void find_level_sum(PointSet* ps, KEYTYPE* weights, int Dim, int* node_ids, KEYTYPE* split_keys, int* node_sizes, int N, int nodes_on_level, int level, int size) { KEYTYPE val=0; - int size = min(N, nodes_on_level*SAMPLES); - int step = N/size; for(int i=blockIdx.x*blockDim.x+threadIdx.x; i(ps->getVec(i), &weights[level*Dim], Dim); + atomicAdd(&split_keys[node_ids[i]], val); atomicAdd(&node_sizes[node_ids[i]], 1); } diff --git a/AnnService/inc/Core/Common/cuda/TailNeighbors.hxx b/AnnService/inc/Core/Common/cuda/TailNeighbors.hxx index 3581588c7..a2eeec1f5 100644 --- a/AnnService/inc/Core/Common/cuda/TailNeighbors.hxx +++ b/AnnService/inc/Core/Common/cuda/TailNeighbors.hxx @@ -84,7 +84,7 @@ __device__ void findTailNeighbors(PointSet* headPS, PointSet* tailPS, TPtr candidate_vec = headPS->getVec(candidate.idx); candidate.dist = dist_comp(query, candidate_vec); - if(candidate.dist < max_dist && candidate.idx != tailId) { // If it is a candidate to be added to neighbor list + if(max_dist >= candidate.dist && candidate.idx != tailId) { // If it is a candidate to be added to neighbor list for(read_id=0; candidate.dist > threadList[read_id].dist && good; read_id++) { if(violatesRNG(candidate_vec, headPS->getVec(threadList[read_id].idx), candidate.dist, dist_comp)) { @@ -138,7 +138,7 @@ __device__ void findTailNeighbors(PointSet* headPS, PointSet* tailPS, TPtr * *** Uses Quantizer *** ***********************************************************************************/ template -__device__ void findTailNeighbors_PQ(PointSet* headPS, PointSet* tailPS, TPtree* tptree, int KVAL, DistPair* results, size_t numTails, int numHeads, QueryGroup* groups, DistPair* threadList, GPU_PQQuantizer* quantizer) { +__device__ void findTailNeighbors_PQ(PointSet* headPS, PointSet* tailPS, TPtree* tptree, int KVAL, DistPair* results, size_t numTails, int numHeads, QueryGroup* groups, DistPair* threadList, 
GPU_Quantizer* quantizer) { uint8_t query[Dim]; float max_dist = INFTY(); @@ -316,7 +316,6 @@ __host__ void get_query_groups(QueryGroup* groups, TPtree* tptree, PointSet* -#define MAX_SHAPE 1024 #define RUN_TAIL_KERNEL(size) \ if(dim <= size) { \ @@ -336,27 +335,31 @@ __host__ void get_query_groups(QueryGroup* groups, TPtree* tptree, PointSet* return; \ } + +#define MAX_SHAPE 1024 + template __global__ void findTailNeighbors_selector(PointSet* headPS, PointSet* tailPS, TPtree* tptree, int KVAL, DistPair* results, size_t curr_batch_size, size_t numHeads, QueryGroup* groups, int dim) { extern __shared__ char sharememory[]; DistPair* threadList = (&((DistPair*)sharememory)[KVAL*threadIdx.x]); - + RUN_TAIL_KERNEL(64) RUN_TAIL_KERNEL(100) - RUN_TAIL_KERNEL(200) - RUN_TAIL_KERNEL(768) -// RUN_TAIL_KERNEL(MAX_SHAPE) + RUN_TAIL_KERNEL(384) + RUN_TAIL_KERNEL(MAX_SHAPE) } -__global__ void findTailNeighbors_PQ_selector(PointSet* headPS, PointSet* tailPS, TPtree* tptree, int KVAL, DistPair* results, size_t curr_batch_size, size_t numHeads, QueryGroup* groups, int dim, GPU_PQQuantizer* quantizer) { +#define MAX_PQ_SHAPE 100 + +__global__ void findTailNeighbors_PQ_selector(PointSet* headPS, PointSet* tailPS, TPtree* tptree, int KVAL, DistPair* results, size_t curr_batch_size, size_t numHeads, QueryGroup* groups, int dim, GPU_Quantizer* quantizer) { extern __shared__ char sharememory[]; DistPair* threadList = (&((DistPair*)sharememory)[KVAL*threadIdx.x]); - RUN_TAIL_KERNEL_PQ(50) - RUN_TAIL_KERNEL_PQ(100) + RUN_TAIL_KERNEL_PQ(50); + RUN_TAIL_KERNEL_PQ(MAX_PQ_SHAPE); } @@ -421,12 +424,12 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head CUDA_CHECK(cudaGetDeviceCount(&numDevicesOnHost)); if(numDevicesOnHost < NUM_GPUS) { - LOG(SPTAG::Helper::LogLevel::LL_Error, "NumGPUs parameter %d, but only %d devices available on system. 
Exiting.\n", NUM_GPUS, numDevicesOnHost); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "NumGPUs parameter %d, but only %d devices available on system. Exiting.\n", NUM_GPUS, numDevicesOnHost); exit(1); } - LOG(SPTAG::Helper::LogLevel::LL_Info, "Building SSD index with %d GPUs...\n", NUM_GPUS); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Total of %d GPU devices on system, using %d of them.\n", numDevicesOnHost, NUM_GPUS); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Building SSD index with %d GPUs...\n", NUM_GPUS); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Total of %d GPU devices on system, using %d of them.\n", numDevicesOnHost, NUM_GPUS); /***** General variables *****/ bool use_q = (headIndex->m_pQuantizer != NULL); // Using quantization? @@ -454,10 +457,10 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head if(N % NUM_GPUS > gpuNum) pointsPerGPU[gpuNum]++; } GPUPointOffset[0] = 0; - LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU 0: points:%lu, offset:%lu\n", pointsPerGPU[0], GPUPointOffset[0]); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU 0: points:%lu, offset:%lu\n", pointsPerGPU[0], GPUPointOffset[0]); for(int gpuNum=1; gpuNum < NUM_GPUS; ++gpuNum) { GPUPointOffset[gpuNum] = GPUPointOffset[gpuNum-1] + pointsPerGPU[gpuNum-1]; - LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d: points:%lu, offset:%lu\n", gpuNum, pointsPerGPU[gpuNum], GPUPointOffset[gpuNum]); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d: points:%lu, offset:%lu\n", gpuNum, pointsPerGPU[gpuNum], GPUPointOffset[gpuNum]); } // Streams and memory pointers for each GPU @@ -479,16 +482,16 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head std::vector d_queryGroups(NUM_GPUS); // Quantizer structures only used if quantization is enabled - GPU_PQQuantizer* d_quantizer = NULL; - GPU_PQQuantizer* h_quantizer = NULL; + GPU_Quantizer* d_quantizer = NULL; + GPU_Quantizer* h_quantizer = NULL; if(use_q) { - 
h_quantizer = new GPU_PQQuantizer(headIndex->m_pQuantizer, (DistMetric)metric); - CUDA_CHECK(cudaMalloc(&d_quantizer, sizeof(GPU_PQQuantizer))); - CUDA_CHECK(cudaMemcpy(d_quantizer, h_quantizer, sizeof(GPU_PQQuantizer), cudaMemcpyHostToDevice)); + h_quantizer = new GPU_Quantizer(headIndex->m_pQuantizer, (DistMetric)metric); + CUDA_CHECK(cudaMalloc(&d_quantizer, sizeof(GPU_Quantizer))); + CUDA_CHECK(cudaMemcpy(d_quantizer, h_quantizer, sizeof(GPU_Quantizer), cudaMemcpyHostToDevice)); } - LOG(SPTAG::Helper::LogLevel::LL_Info, "Setting up each of the %d GPUs...\n", NUM_GPUS); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Setting up each of the %d GPUs...\n", NUM_GPUS); /*** Compute batch sizes and allocate all data on GPU ***/ for(int gpuNum=0; gpuNum < NUM_GPUS; gpuNum++) { @@ -497,13 +500,13 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head debug_warm_up_gpu<<<1,32,0,streams[gpuNum]>>>(); resultErr = cudaStreamSynchronize(streams[gpuNum]); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU test/warmup complete - kernel status:%d\n", resultErr); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU test/warmup complete - kernel status:%d\n", resultErr); // Get GPU info to compute batch sizes cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, gpuNum)); - LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name); size_t freeMem, totalMem; CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem)); @@ -521,11 +524,11 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head * Batch size check and Debug information printing *****************************************************/ if(BATCH_SIZE[gpuNum] == 0 || ((int)pointsPerGPU[gpuNum]) / BATCH_SIZE[gpuNum] > 10000) { - LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient GPU memory to build SSD index on GPU %d. 
Available GPU memory:%lu MB, Head index requires:%lu MB, leaving a maximum batch size of %d elements, which is too small to run efficiently.\n", gpuNum, (freeMem)/1000000, (headVecSize+treeSize)/1000000, maxEltsPerBatch); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient GPU memory to build SSD index on GPU %d. Available GPU memory:%lu MB, Head index requires:%lu MB, leaving a maximum batch size of %d elements, which is too small to run efficiently.\n", gpuNum, (freeMem)/1000000, (headVecSize+treeSize)/1000000, maxEltsPerBatch); exit(1); } - LOG(SPTAG::Helper::LogLevel::LL_Info, "Memory for head vectors:%lu MiB, Memory for TP trees:%lu MiB, Memory left for tail vectors:%lu MiB, total tail vectors:%lu, batch size:%d, total batches:%d\n", headVecSize/1000000, treeSize/1000000, tailMemAvail/1000000, pointsPerGPU[gpuNum], BATCH_SIZE[gpuNum], (((BATCH_SIZE[gpuNum]-1)+pointsPerGPU[gpuNum]) / BATCH_SIZE[gpuNum])); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Allocating GPU memory: tail points:%lu MiB, head points:%lu MiB, results:%lu MiB, TPT:%lu MiB, Total:%lu MiB\n", (BATCH_SIZE[gpuNum]*sizeof(T))/1000000, (headRows*sizeof(T))/1000000, (BATCH_SIZE[gpuNum]*RNG_SIZE*sizeof(DistPair))/1000000, (sizeof(TPtree))/1000000, ((BATCH_SIZE[gpuNum]+headRows)*sizeof(T) + (BATCH_SIZE[gpuNum]*RNG_SIZE*sizeof(DistPair)) + (sizeof(TPtree)))/1000000); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Memory for head vectors:%lu MiB, Memory for TP trees:%lu MiB, Memory left for tail vectors:%lu MiB, total tail vectors:%lu, batch size:%d, total batches:%d\n", headVecSize/1000000, treeSize/1000000, tailMemAvail/1000000, pointsPerGPU[gpuNum], BATCH_SIZE[gpuNum], (((BATCH_SIZE[gpuNum]-1)+pointsPerGPU[gpuNum]) / BATCH_SIZE[gpuNum])); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Allocating GPU memory: tail points:%lu MiB, head points:%lu MiB, results:%lu MiB, TPT:%lu MiB, Total:%lu MiB\n", (BATCH_SIZE[gpuNum]*sizeof(T))/1000000, (headRows*sizeof(T))/1000000, 
(BATCH_SIZE[gpuNum]*RNG_SIZE*sizeof(DistPair))/1000000, (sizeof(TPtree))/1000000, ((BATCH_SIZE[gpuNum]+headRows)*sizeof(T) + (BATCH_SIZE[gpuNum]*RNG_SIZE*sizeof(DistPair)) + (sizeof(TPtree)))/1000000); // Allocate needed memory on the GPU CUDA_CHECK(cudaMalloc(&d_headRaw[gpuNum], headRows*dim*sizeof(T))); @@ -539,7 +542,7 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head tptree[gpuNum] = new TPtree; tptree[gpuNum]->initialize(headRows, TPTlevels, dim); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "tpt structure initialized for %lu head vectors, %d levels, leaf size:%d\n", headRows, TPTlevels, LEAF_SIZE); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "tpt structure initialized for %lu head vectors, %d levels, leaf size:%d\n", headRows, TPTlevels, LEAF_SIZE); // Alloc memory for QuerySet structure CUDA_CHECK(cudaMalloc(&d_queryMem[gpuNum], BATCH_SIZE[gpuNum]*sizeof(int) + 2*tptree[gpuNum]->num_leaves*sizeof(int))); @@ -578,7 +581,7 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head if(offset[gpuNum]+BATCH_SIZE[gpuNum] > pointsPerGPU[gpuNum]) { curr_batch_size[gpuNum] = pointsPerGPU[gpuNum]-offset[gpuNum]; } - LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - starting batch with GPUOffset:%lu, offset:%lu, total offset:%lu, size:%lu, TailRows:%lld\n", gpuNum, GPUPointOffset[gpuNum], offset[gpuNum], GPUPointOffset[gpuNum]+offset[gpuNum], curr_batch_size[gpuNum], pointsPerGPU[gpuNum]); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - starting batch with GPUOffset:%lu, offset:%lu, total offset:%lu, size:%lu, TailRows:%lld\n", gpuNum, GPUPointOffset[gpuNum], offset[gpuNum], GPUPointOffset[gpuNum]+offset[gpuNum], curr_batch_size[gpuNum], pointsPerGPU[gpuNum]); cudaSetDevice(gpuNum); @@ -589,7 +592,7 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head temp_ps.data = d_tailRaw[gpuNum]; CUDA_CHECK(cudaMemcpy(d_tailPS[gpuNum], &temp_ps, sizeof(PointSet), 
cudaMemcpyHostToDevice)); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Copied %lu tail points to GPU - kernel status:%d\n", curr_batch_size[gpuNum], resultErr); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Copied %lu tail points to GPU - kernel status:%d\n", curr_batch_size[gpuNum], resultErr); int copy_size = COPY_BUFF_SIZE*RNG_SIZE; for(int i=0; i), cudaMemcpyHostToDevice, streams[gpuNum]); } - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Copying initialized result list to GPU - batch size:%lu - replica count:%d, copy bytes:%lu, GPU offset:%lu, total result offset:%lu, - kernel status:%d\n", curr_batch_size[gpuNum], RNG_SIZE, curr_batch_size[gpuNum]*RNG_SIZE*sizeof(DistPair), GPUPointOffset[gpuNum], (GPUPointOffset[gpuNum]+offset[gpuNum])*RNG_SIZE, resultErr); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Copying initialized result list to GPU - batch size:%lu - replica count:%d, copy bytes:%lu, GPU offset:%lu, total result offset:%lu, - kernel status:%d\n", curr_batch_size[gpuNum], RNG_SIZE, curr_batch_size[gpuNum]*RNG_SIZE*sizeof(DistPair), GPUPointOffset[gpuNum], (GPUPointOffset[gpuNum]+offset[gpuNum])*RNG_SIZE, resultErr); } // OFFSET for batch of vector IDS: GPUPointOffset[gpuNum]+offset[gpuNum]; @@ -611,9 +614,9 @@ void getTailNeighborsTPT(T* vectors, SPTAG::SizeType N, SPTAG::VectorIndex* head auto t1 = std::chrono::high_resolution_clock::now(); // Create TPT on each GPU - create_tptree_multigpu(tptree, d_headPS, headRows, TPTlevels, NUM_GPUS, streams.data(), 2); + create_tptree_multigpu(tptree, d_headPS, headRows, TPTlevels, NUM_GPUS, streams.data(), 2, headIndex); CUDA_CHECK(cudaDeviceSynchronize()); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "TPT %d created on all GPUs\n", tree_id); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "TPT %d created on all GPUs\n", tree_id); // Copy TPTs to each GPU for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { @@ -635,6 +638,10 @@ auto t2 = std::chrono::high_resolution_clock::now(); CUDA_CHECK(cudaSetDevice(gpuNum)); 
if(!use_q) { + if(dim > MAX_SHAPE) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Input vector dimension is %d, GPU index build of vector dimensions larger than %d not supported.\n", dim, MAX_SHAPE); + } + if(metric == (int)DistMetric::Cosine) { findTailNeighbors_selector<<)*RNG_SIZE*NUM_THREADS, streams[gpuNum]>>>(d_headPS[gpuNum], d_tailPS[gpuNum], d_tptree[gpuNum], RNG_SIZE, d_results[gpuNum], curr_batch_size[gpuNum], headRows, d_queryGroups[gpuNum], dim); } @@ -643,26 +650,29 @@ } } else { + if(dim > MAX_PQ_SHAPE) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Input PQ dimension is %d, GPU index build with PQ dimension larger than %d not supported.\n", dim, MAX_PQ_SHAPE); + } findTailNeighbors_PQ_selector<<)*RNG_SIZE*NUM_THREADS, streams[gpuNum]>>>((PointSet*)d_headPS[gpuNum], (PointSet*)d_tailPS[gpuNum], d_tptree[gpuNum], RNG_SIZE, (DistPair*)d_results[gpuNum], curr_batch_size[gpuNum], headRows, d_queryGroups[gpuNum], dim, d_quantizer); } } CUDA_CHECK(cudaDeviceSynchronize()); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "All GPUs finished finding neighbors of tails using TPT %d\n", tree_id); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "All GPUs finished finding neighbors of tails using TPT %d\n", tree_id); auto t3 = std::chrono::high_resolution_clock::now(); -LOG(SPTAG::Helper::LogLevel::LL_Debug, "Tree %d complete, time to build tree:%.2lf, time to compute tail neighbors:%.2lf\n", tree_id, GET_CHRONO_TIME(t1, t2), GET_CHRONO_TIME(t2, t3)); +SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Tree %d complete, time to build tree:%.2lf, time to compute tail neighbors:%.2lf\n", tree_id, GET_CHRONO_TIME(t1, t2), GET_CHRONO_TIME(t2, t3)); } // TPT loop - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Batch complete on all GPUs, copying results back to CPU...\n"); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Batch complete on all GPUs, copying results back to CPU...\n"); // Copy results of batch from each
GPU to CPU result set for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) { CUDA_CHECK(cudaSetDevice(gpuNum)); -LOG(SPTAG::Helper::LogLevel::LL_Debug, "gpu:%d, copying results size %lu to results offset:%d*%d=%d\n", gpuNum, curr_batch_size[gpuNum], GPUPointOffset[gpuNum]+offset[gpuNum],RNG_SIZE, (GPUPointOffset[gpuNum]+offset[gpuNum])*RNG_SIZE); +SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "gpu:%d, copying results size %lu to results offset:%d*%d=%d\n", gpuNum, curr_batch_size[gpuNum], GPUPointOffset[gpuNum]+offset[gpuNum],RNG_SIZE, (GPUPointOffset[gpuNum]+offset[gpuNum])*RNG_SIZE); size_t copy_size=COPY_BUFF_SIZE; for(int i=0; i(ssd_t1-premem_t).count()) + ((double)std::chrono::duration_cast(ssd_t1-premem_t).count())/1000, ((double)std::chrono::duration_cast(ssd_t2-ssd_t1).count()) + ((double)std::chrono::duration_cast(ssd_t2-ssd_t1).count())/1000, ((double)std::chrono::duration_cast(ssd_t3-ssd_t2).count()) + ((double)std::chrono::duration_cast(ssd_t3-ssd_t2).count())/1000); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Mam alloc time:%0.2lf, GPU time to build index:%.2lf, Memory free time:%.2lf\n", ((double)std::chrono::duration_cast(ssd_t1-premem_t).count()) + ((double)std::chrono::duration_cast(ssd_t1-premem_t).count())/1000, ((double)std::chrono::duration_cast(ssd_t2-ssd_t1).count()) + ((double)std::chrono::duration_cast(ssd_t2-ssd_t1).count())/1000, ((double)std::chrono::duration_cast(ssd_t3-ssd_t2).count()) + ((double)std::chrono::duration_cast(ssd_t3-ssd_t2).count())/1000); } @@ -773,7 +783,7 @@ void GPU_SortSelections(std::vector* selections) { int num_batches = (N + (sortBatchSize-1)) / sortBatchSize; - LOG(SPTAG::Helper::LogLevel::LL_Info, "Sorting final results. Size of result:%ld elements = %0.2lf GB, Available GPU memory:%0.2lf, sorting in %d batches\n", N, ((double)(N*sizeof(GPUEdge))/1000000000.0), ((double)freeMem)/1000000000.0, num_batches); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Sorting final results. 
Size of result:%ld elements = %0.2lf GB, Available GPU memory:%0.2lf, sorting in %d batches\n", N, ((double)(N*sizeof(GPUEdge))/1000000000.0), ((double)freeMem)/1000000000.0, num_batches); int batchNum=0; @@ -782,7 +792,7 @@ void GPU_SortSelections(std::vector* selections) { merge_mem = new GPUEdge[N]; } - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Allocating %ld bytes on GPU for sorting\n", sortBatchSize*sizeof(GPUEdge)); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Allocating %ld bytes on GPU for sorting\n", sortBatchSize*sizeof(GPUEdge)); GPUEdge* d_selections; CUDA_CHECK(cudaMalloc(&d_selections, sortBatchSize*sizeof(GPUEdge))); @@ -794,7 +804,7 @@ void GPU_SortSelections(std::vector* selections) { if(startIdx + batchSize > N) { batchSize = N - startIdx; } - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Sorting batch id:%ld, size:%ld\n", startIdx, batchSize); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Sorting batch id:%ld, size:%ld\n", startIdx, batchSize); GPUEdge* batchPtr = &(new_selections->data()[startIdx]); @@ -803,7 +813,7 @@ void GPU_SortSelections(std::vector* selections) { thrust::sort(thrust::device, d_selections, d_selections+batchSize, gpu_edgeComparer); } catch (thrust::system_error &e){ - LOG(SPTAG::Helper::LogLevel::LL_Info, "Error: %s \n",e.what()); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Error: %s \n",e.what()); } CUDA_CHECK(cudaMemcpy(batchPtr, d_selections, batchSize*sizeof(GPUEdge), cudaMemcpyDeviceToHost)); @@ -822,7 +832,7 @@ void GPU_SortSelections(std::vector* selections) { auto t3 = std::chrono::high_resolution_clock::now(); - LOG(SPTAG::Helper::LogLevel::LL_Debug, "Sort batch %d - GPU transfer/sort time:%0.2lf, CPU merge time:%.2lf\n", batchNum, ((double)std::chrono::duration_cast(t2-t1).count()) + ((double)std::chrono::duration_cast(t2-t1).count())/1000, ((double)std::chrono::duration_cast(t3-t2).count()) + ((double)std::chrono::duration_cast(t3-t2).count())/1000); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, 
"Sort batch %d - GPU transfer/sort time:%0.2lf, CPU merge time:%.2lf\n", batchNum, ((double)std::chrono::duration_cast(t2-t1).count()) + ((double)std::chrono::duration_cast(t2-t1).count())/1000, ((double)std::chrono::duration_cast(t3-t2).count()) + ((double)std::chrono::duration_cast(t3-t2).count())/1000); batchNum++; } if(num_batches > 1) { diff --git a/AnnService/inc/Core/Common/cuda/ThreadHeap.hxx b/AnnService/inc/Core/Common/cuda/ThreadHeap.hxx index 97a67c8c9..28f8ac310 100644 --- a/AnnService/inc/Core/Common/cuda/ThreadHeap.hxx +++ b/AnnService/inc/Core/Common/cuda/ThreadHeap.hxx @@ -27,6 +27,7 @@ #include #include "Distance.hxx" +#include "GPUKNNDistance.hxx" // Object used to store id,distance combination for KNN graph @@ -73,8 +74,8 @@ class ThreadHeap { while(2*i+2 < KVAL) { swapDest = 2*i; - swapDest += (vals[i].dist < vals[2*i+1].dist && vals[2*i+2].dist <= vals[2*i+1].dist); - swapDest += 2*(vals[i].dist < vals[2*i+2].dist && vals[2*i+1].dist < vals[2*i+2].dist); + swapDest += (vals[2*i+1].dist >= vals[i].dist && vals[2*i+1].dist > vals[2*i+2].dist); + swapDest += 2*(vals[2*i+2].dist >= vals[i].dist && vals[2*i+2].dist >= vals[2*i+1].dist); if(swapDest == 2*i) return; @@ -89,10 +90,10 @@ class ThreadHeap { int i=idx; int swapDest=0; - while(2*i+2 < KVAL) { + while(KVAL >= 2*i+2) { swapDest = 2*i; - swapDest += (vals[i].dist < vals[2*i+1].dist && vals[2*i+2].dist <= vals[2*i+1].dist); - swapDest += 2*(vals[i].dist < vals[2*i+2].dist && vals[2*i+1].dist < vals[2*i+2].dist); + swapDest += (vals[2*i+1].dist >= vals[i].dist && vals[2*i+1].dist > vals[2*i+2].dist); + swapDest += 2*(vals[2*i+2].dist >= vals[i].dist && vals[2*i+2].dist >= vals[2*i+1].dist); if(swapDest == 2*i) return; @@ -164,7 +165,7 @@ class ThreadHeap { heapifyAt(idx); } -/* + // Load a sorted set of vectors into the heap __device__ void load_mem_sorted(Point* data, int* mem, Point query, int metric) { for(int i=0; i<=KVAL-1; i++) { @@ -187,7 +188,7 @@ class ThreadHeap { __device__ SUMTYPE 
top() { return vals[0].dist; } -*/ + }; #endif diff --git a/AnnService/inc/Core/Common/cuda/log.hxx b/AnnService/inc/Core/Common/cuda/log.hxx new file mode 100644 index 000000000..fc9d26e8d --- /dev/null +++ b/AnnService/inc/Core/Common/cuda/log.hxx @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#pragma once + +#include +//#include +#include +#include +#include +using namespace std; + +#include "params.h" + +enum log_level_t { + LOG_NOTHING, + LOG_CRITICAL, + LOG_ERROR, + LOG_WARNING, + LOG_INFO, + LOG_DEBUG +}; + + +#define LOG_ALL(f_, ...) printf((f_), ##__VA_ARGS__) +#define DLOG_ALL(f_, ...) {if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0) printf((f_), ##__VA_ARGS__); } + +#if LOG_LEVEL >= 1 +#define LOG_CRIT(f_, ...) printf((f_), ##__VA_ARGS__) +#define DLOG_CRIT(f_, ...) 
{if(threadIdx.x==0 && blockIdx.x==0) printf((f_), ##__VA_ARGS__); } +#else +#define LOG_CRIT(f_, ...) {} +#define DLOG_CRIT(f_, ...) {} +#endif + +#if LOG_LEVEL >= 2 +#define LOG_ERR(f_, ...) printf((f_), ##__VA_ARGS__) +#define DLOG_ERR(f_, ...) {if(threadIdx.x==0 && blockIdx.x==0) printf((f_), ##__VA_ARGS__); } +#else +#define LOG_ERR(f_, ...) {} +#define DLOG_ERR(f_, ...) {} +#endif + +#if LOG_LEVEL >= 3 +#define LOG_WARN(f_, ...) printf((f_), ##__VA_ARGS__) +#define DLOG_WARN(f_, ...) {if(threadIdx.x==0 && blockIdx.x==0) printf((f_), ##__VA_ARGS__); } +#else +#define LOG_WARN(f_, ...) {} +#define DLOG_WARN(f_, ...) {} +#endif + +#if LOG_LEVEL >= 4 +#define LOG_INFO(f_, ...) printf((f_), ##__VA_ARGS__) +#define DLOG_INFO(f_, ...) {if(threadIdx.x==0 && blockIdx.x==0) printf((f_), ##__VA_ARGS__); } +#else +#define LOG_INFO(f_, ...) {} +#define DLOG_INFO(f_, ...) {} +#endif + +#if LOG_LEVEL >= 5 +#define LOG_DEBUG(f_, ...) printf((f_), ##__VA_ARGS__) +#define DLOG_DEBUG(f_, ...) {if(threadIdx.x==0 && blockIdx.x==0) printf((f_), ##__VA_ARGS__); } +#else +#define LOG_DEBUG(f_, ...) {} +#define DLOG_DEBUG(f_, ...) 
{} +#endif + +#define STR_EXPAND(arg) #arg +#define STR(arg) STR_EXPAND(arg) diff --git a/AnnService/inc/Core/Common/cuda/params.h b/AnnService/inc/Core/Common/cuda/params.h index d79ef9e86..ce34f6300 100644 --- a/AnnService/inc/Core/Common/cuda/params.h +++ b/AnnService/inc/Core/Common/cuda/params.h @@ -28,13 +28,19 @@ #include #include + /****************************************************** * Parameters that have been optimized experimentally ******************************************************/ #define THREADS 64 // Number of threads per block +#define KNN_THREADS 32 // Number of threads per block #define BLOCKS 10240 // Total blocks used #define SAMPLES 5000 // number of samples used to determine median for TPT construction #define KEYTYPE float // Keys used to divide TPTs at each node +#define ILP 1 // Increase of ILP using registers in distance calculations +#define TPT_ITERS 1 // Number of random sets of weights tried for each level +#define TPT_PART_DIMS D // Number of dimensions used to create hyperplane +#define REFINE_DEPTH 1 // Depth of refinement step. 
No refinement if 0 #define REORDER 1 // Option to re-order queries for perf improvement (1 = reorder, 0 = no reorder) diff --git a/AnnService/inc/Core/Common/cuda/tests/Makefile b/AnnService/inc/Core/Common/cuda/tests/Makefile deleted file mode 100644 index a1580e994..000000000 --- a/AnnService/inc/Core/Common/cuda/tests/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -ARCH=sm_70 - -CUDA_HOME=/usr/local/cuda -FLAGS=-std=c++14 -Xptxas -O3 - -all: test_dists - -test_dists: cuda_unit_tests.cu ../KNN.hxx ../Distance.hxx - nvcc ${FLAGS} -lineinfo --use_fast_math --gpu-architecture=${ARCH} -I${CUDA_HOME}/include -I ../../../../../ -o test_dists cuda_unit_tests.cu diff --git a/AnnService/inc/Core/Common/cuda/tests/cuda_unit_tests.cu b/AnnService/inc/Core/Common/cuda/tests/cuda_unit_tests.cu deleted file mode 100644 index 8e66df0f6..000000000 --- a/AnnService/inc/Core/Common/cuda/tests/cuda_unit_tests.cu +++ /dev/null @@ -1,138 +0,0 @@ -#include "../Distance.hxx" -#include "../KNN.hxx" -#include "../params.h" -#include -#include - - -float* create_dataset(size_t rows, int dim) { - - srand(0); - float* data = new float[rows*dim]; - for(size_t i=0; i (rand()) / static_cast (RAND_MAX); - } - return data; -} - -template -__global__ void test_KNN(PointSet* ps, int* results, int rows, int K) { - - extern __shared__ char sharememory[]; - DistPair* threadList = (&((DistPair*)sharememory)[K*threadIdx.x]); - - T query[Dim]; - T candidate_vec[Dim]; - - DistPair target; - DistPair candidate; - DistPair temp; - - bool good; - float max_dist = INFTY(); - int read_id, write_id; - - float (*dist_comp)(T*,T*) = cosine; - - for(size_t i=blockIdx.x*blockDim.x + threadIdx.x; igetVec(i)[j]; - - for(int k=0; k(); - } - } - - for(size_t j=0; jgetVec(i)); - - - if(candidate.dist < max_dist) { - for(read_id=0; candidate.dist > threadList[read_id].dist && good; read_id++) { - if(violatesRNG(candidate_vec, ps->getVec(threadList[read_id].idx), candidate.dist, dist_comp)) { - good = false; - } - } - if(good) { 
- target = threadList[read_id]; - threadList[read_id] = candidate; - read_id++; - for(write_id = read_id; read_id < K && threadList[read_id].idx != -1; read_id++) { - if(!violatesRNG(ps->getVec(threadList[read_id].idx), candidate_vec, threadList[read_id].dist, dist_comp)) { - if(read_id == write_id) { - temp = threadList[read_id]; - threadList[write_id] = target; - target = temp; - } - else { - threadList[write_id] = target; - target = threadList[read_id]; - } - write_id++; - } - } - if(write_id < K) { - threadList[write_id] = target; - write_id++; - } - for(int k=write_id; k(); - threadList[k].idx = -1; - } - max_dist = threadList[K-1].dist; - } - - } - - } - for(size_t j=0; j h_ps; - h_ps.dim = dim; - h_ps.data = d_data; - - PointSet* d_ps; - - CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet))); - CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); - - auto start_t = std::chrono::high_resolution_clock::now(); - - test_KNN<<<1024, 64, K*THREADS*sizeof(DistPair)>>>(d_ps, d_results, rows, K); - CUDA_CHECK(cudaDeviceSynchronize()); - - auto end_t = std::chrono::high_resolution_clock::now(); - - printf("Total time: %0.2lf\n", GET_CHRONO_TIME(start_t, end_t)); - - int* h_results = new int[K*rows]; - CUDA_CHECK(cudaMemcpy(h_results, d_results, rows*K*sizeof(int), cudaMemcpyDeviceToHost)); - for(int i=0; i<100; ++i) printf("%d, ", h_results[i*100]); - printf("%d, %d, ..., %d\n", h_results[0], h_results[1], h_results[rows*K-1]); - - return 0; -} diff --git a/AnnService/inc/Core/DefinitionList.h b/AnnService/inc/Core/DefinitionList.h index 812cc3cd0..b4919681c 100644 --- a/AnnService/inc/Core/DefinitionList.h +++ b/AnnService/inc/Core/DefinitionList.h @@ -82,6 +82,12 @@ DefineErrorCode(ReadIni_DuplicatedParam, 0x3003) DefineErrorCode(Socket_FailedResolveEndPoint, 0x4000) DefineErrorCode(Socket_FailedConnectToEndPoint, 0x4001) +// 0x5000 ~ -0x5FFF Block File Status +DefineErrorCode(Key_OverFlow, 0x5000) +DefineErrorCode(Key_NotFound, 0x5001) 
+DefineErrorCode(Posting_OverFlow, 0x5002) +DefineErrorCode(Posting_SizeError, 0x5003) +DefineErrorCode(Block_IDError, 0x5004) #endif // DefineErrorCode @@ -125,4 +131,27 @@ DefineTruthFileType(XVEC) // row(int32_t), column(int32_t), data... DefineTruthFileType(DEFAULT) -#endif // DefineTruthFileType \ No newline at end of file +#endif // DefineTruthFileType + +#ifdef DefineNumaStrategy + +DefineNumaStrategy(LOCAL) +DefineNumaStrategy(SCATTER) + +#endif // DefineNumaStrategy + +#ifdef DefineOrderStrategy + +DefineOrderStrategy(ASC) +DefineOrderStrategy(DESC) + +#endif // DefineOrderStrategy + +#ifdef DefineStorage + +DefineStorage(STATIC) +DefineStorage(FILEIO) +DefineStorage(SPDKIO) +DefineStorage(ROCKSDBIO) + +#endif // DefineStorage diff --git a/AnnService/inc/Core/KDT/Index.h b/AnnService/inc/Core/KDT/Index.h index 8286f48bb..9d11dac43 100644 --- a/AnnService/inc/Core/KDT/Index.h +++ b/AnnService/inc/Core/KDT/Index.h @@ -40,6 +40,9 @@ namespace SPTAG class RebuildJob : public Helper::ThreadPool::Job { public: RebuildJob(COMMON::Dataset* p_data, COMMON::KDTree* p_tree, COMMON::RelativeNeighborhoodGraph* p_graph) : m_data(p_data), m_tree(p_tree), m_graph(p_graph) {} + void exec(void* p_workSpace, IAbortOperation* p_abort) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot support job.exec(workspace, abort)!\n"); + } void exec(IAbortOperation* p_abort) { m_tree->Rebuild(*m_data, p_abort); } @@ -70,7 +73,6 @@ namespace SPTAG std::shared_timed_mutex m_dataDeleteLock; COMMON::Labelset m_deletedID; - std::unique_ptr> m_workSpacePool; Helper::ThreadPool m_threadPool; int m_iNumberOfThreads; @@ -83,6 +85,7 @@ namespace SPTAG int m_iNumberOfInitialDynamicPivots; int m_iNumberOfOtherDynamicPivots; int m_iHashTableExp; + std::unique_ptr> m_workSpaceFactory; public: Index() @@ -96,6 +99,7 @@ namespace SPTAG m_pSamples.SetName("Vector"); m_fComputeDistance = std::function(COMMON::DistanceCalcSelector(m_iDistCalcMethod)); m_iBaseSquare = (m_iDistCalcMethod == 
DistCalcMethod::Cosine) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() : 1; + m_workSpaceFactory = std::make_unique>(); } ~Index() {} @@ -120,6 +124,9 @@ namespace SPTAG return 1.0f - xy / (sqrt(xx) * sqrt(yy)); } inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); } + inline float GetDistance(const void* target, const SizeType idx) const { + return ComputeDistance(target, m_pSamples.At(idx)); + } inline const void* GetSample(const SizeType idx) const { return (void*)m_pSamples[idx]; } inline bool ContainSample(const SizeType idx) const { return idx >= 0 && idx < m_deletedID.R() && !m_deletedID.Contains(idx); } inline bool NeedRefine() const { return m_deletedID.Count() > (size_t)(GetNumSamples() * m_fDeletePercentageForRefine); } @@ -152,9 +159,19 @@ namespace SPTAG ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized = false, bool p_shareOwnership = false); ErrorCode SearchIndex(QueryResult &p_query, bool p_searchDeleted = false) const; + + std::shared_ptr GetIterator(const void* p_target, bool p_searchDeleted = false, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const; + ErrorCode SearchIndexIterativeNext(QueryResult& p_query, COMMON::WorkSpace* workSpace, int p_batch, int& resultCount, bool p_isFirst, bool p_searchDeleted) const; + ErrorCode SearchIndexIterativeEnd(std::unique_ptr workSpace) const; + bool SearchIndexIterativeFromNeareast(QueryResult& p_query, COMMON::WorkSpace* p_space, bool p_isFirst, bool p_searchDeleted = false) const; + std::unique_ptr RentWorkSpace(int batch, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const; + + ErrorCode SearchIndexWithFilter(QueryResult& p_query, std::function filterFunc, int maxCheck = 0, bool p_searchDeleted = false) const; ErrorCode RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted = false) const; ErrorCode SearchTree(QueryResult 
&p_query) const; ErrorCode AddIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, std::shared_ptr p_metadataSet, bool p_withMetaIndex = false, bool p_normalized = false); + ErrorCode AddIndexIdx(SizeType begin, SizeType end); + ErrorCode AddIndexId(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, int& beginHead, int& endHead); ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum); ErrorCode DeleteIndex(const SizeType& p_id); @@ -162,12 +179,33 @@ namespace SPTAG std::string GetParameter(const char* p_param, const char* p_section = nullptr) const; ErrorCode UpdateIndex(); - ErrorCode RefineIndex(const std::vector>& p_indexStreams, IAbortOperation* p_abort); + ErrorCode RefineIndex(const std::vector> &p_indexStreams, + IAbortOperation *p_abort, std::vector *p_mapping); ErrorCode RefineIndex(std::shared_ptr& p_newIndex); + ErrorCode SetWorkSpaceFactory(std::unique_ptr> up_workSpaceFactory) + { + SPTAG::COMMON::IWorkSpaceFactory* raw_generic_ptr = up_workSpaceFactory.release(); + if (!raw_generic_ptr) return ErrorCode::Fail; + + + SPTAG::COMMON::IWorkSpaceFactory* raw_specialized_ptr = dynamic_cast*>(raw_generic_ptr); + if (!raw_specialized_ptr) + { + delete raw_generic_ptr; + return ErrorCode::Fail; + } + else + { + m_workSpaceFactory = std::unique_ptr>(raw_specialized_ptr); + return ErrorCode::Success; + } + } private: - void SearchIndexWithDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const; - void SearchIndexWithoutDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const; + template + void SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, bool p_searchDeleted) const; + template + void Search(COMMON::QueryResultSet& p_query, COMMON::WorkSpace& p_space) const; }; } // namespace KDT } // namespace SPTAG diff --git a/AnnService/inc/Core/MetadataSet.h b/AnnService/inc/Core/MetadataSet.h index 0e67b9c2c..d16c037ff 100644 --- 
a/AnnService/inc/Core/MetadataSet.h +++ b/AnnService/inc/Core/MetadataSet.h @@ -4,6 +4,7 @@ #ifndef _SPTAG_METADATASET_H_ #define _SPTAG_METADATASET_H_ +#include "inc/Helper/DiskIO.h" #include "CommonDataStructure.h" namespace SPTAG diff --git a/AnnService/inc/Core/MultiIndexScan.h b/AnnService/inc/Core/MultiIndexScan.h new file mode 100644 index 000000000..71dbeb012 --- /dev/null +++ b/AnnService/inc/Core/MultiIndexScan.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_MULTI_INDEX_SCAN_H +#define _SPTAG_MULTI_INDEX_SCAN_H + +#include + +#include +#include +#include +#include +#include +#include + +#include "ResultIterator.h" +#include "VectorIndex.h" +#include +namespace SPTAG +{ + class MultiIndexScan + { + public: + MultiIndexScan(); + MultiIndexScan(std::vector> vecIndices, + std::vector p_targets, + unsigned int k, + float (*rankFunction)(std::vector), + bool useTimer, + int termCondVal, + int searchLimit + ); + ~MultiIndexScan(); + void Init(std::vector> vecIndices, + std::vector p_targets, + std::vector weight, + unsigned int k, + bool useTimer, + int termCondVal, + int searchLimit); + bool Next(BasicResult& result); + void Close(); + + private: + std::vector> indexIters; + std::vector> fwdLUTs; + std::unordered_set seenSet; + std::vector p_data_array; + std::vector weight; + + unsigned int k; + + + + bool useTimer; + unsigned int termCondVal; + int searchLimit; + std::chrono::time_point t_start; + + float (*func)(std::vector); + + unsigned int consecutive_drops; + + bool terminate; + using pq_item = std::pair; + class pq_item_compare + { + public: + bool operator()(const pq_item& lhs, const pq_item& rhs) + { + return lhs.first < rhs.first; + } + }; + std::priority_queue, pq_item_compare> pq; + std::stack outputStk; + float WeightedRankFunc(std::vector); + + }; +} // namespace SPTAG +#endif diff --git a/AnnService/inc/Core/ResultIterator.h 
b/AnnService/inc/Core/ResultIterator.h new file mode 100644 index 000000000..7c0c12142 --- /dev/null +++ b/AnnService/inc/Core/ResultIterator.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_RESULT_ITERATOR_H +#define _SPTAG_RESULT_ITERATOR_H + +#include + +#include "VectorIndex.h" +#include "SearchQuery.h" + +typedef SPTAG::VectorIndex VectorIndex; +typedef SPTAG::ByteArray ByteArray; +typedef SPTAG::QueryResult QueryResult; + +class ResultIterator +{ +public: + ResultIterator(const void* p_index, const void* p_target, bool p_searchDeleted, int p_workspaceBatch, std::function p_filterFunc, int p_maxCheck); + + ~ResultIterator(); + + void* GetWorkSpace(); + + virtual std::shared_ptr Next(int batch); + + virtual bool GetRelaxedMono(); + + virtual SPTAG::ErrorCode GetErrorCode(); + + virtual void Close(); + + const void* GetTarget(); + +protected: + const VectorIndex* m_index; + const void* m_target; + ByteArray m_byte_target; + std::shared_ptr m_queryResult; + void* m_workspace; + bool m_searchDeleted; + bool m_isFirstResult; + int m_batch = 1; +}; + +#endif \ No newline at end of file diff --git a/AnnService/inc/Core/SPANN/Compressor.h b/AnnService/inc/Core/SPANN/Compressor.h index 47d740ef2..21f53a6b8 100644 --- a/AnnService/inc/Core/SPANN/Compressor.h +++ b/AnnService/inc/Core/SPANN/Compressor.h @@ -5,14 +5,19 @@ #define _SPTAG_SPANN_COMPRESSOR_H_ #include -#include "zstd.h" -#include "zdict.h" +//#include "zstd.h" +//#include "zdict.h" +#ifdef ZSTD +#include +#include +#endif #include "inc/Core/Common.h" namespace SPTAG { namespace SPANN { + #ifdef ZSTD class Compressor { private: @@ -21,7 +26,7 @@ namespace SPTAG cdict = ZSTD_createCDict((void *)dictBuffer.data(), dictBuffer.size(), compress_level); if (cdict == NULL) { - LOG(Helper::LogLevel::LL_Error, "ZSTD_createCDict() failed! \n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD_createCDict() failed! 
\n"); throw std::runtime_error("ZSTD_createCDict() failed!"); } } @@ -31,7 +36,7 @@ namespace SPTAG ddict = ZSTD_createDDict((void *)dictBuffer.data(), dictBuffer.size()); if (ddict == NULL) { - LOG(Helper::LogLevel::LL_Error, "ZSTD_createDDict() failed! \n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD_createDDict() failed! \n"); throw std::runtime_error("ZSTD_createDDict() failed!"); } } @@ -45,13 +50,13 @@ namespace SPTAG ZSTD_CCtx *const cctx = ZSTD_createCCtx(); if (cctx == NULL) { - LOG(Helper::LogLevel::LL_Error, "ZSTD_createCCtx() failed! \n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD_createCCtx() failed! \n"); throw std::runtime_error("ZSTD_createCCtx() failed!"); } size_t compressed_size = ZSTD_compress_usingCDict(cctx, (void *)comp_buffer.data(), est_compress_size, src.data(), src.size(), cdict); if (ZSTD_isError(compressed_size)) { - LOG(Helper::LogLevel::LL_Error, "ZSTD compress error %s, \n", ZSTD_getErrorName(compressed_size)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD compress error %s, \n", ZSTD_getErrorName(compressed_size)); throw std::runtime_error("ZSTD compress error"); } ZSTD_freeCCtx(cctx); @@ -66,14 +71,14 @@ namespace SPTAG ZSTD_DCtx* const dctx = ZSTD_createDCtx(); if (dctx == NULL) { - LOG(Helper::LogLevel::LL_Error, "ZSTD_createDCtx() failed! \n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD_createDCtx() failed! 
\n"); throw std::runtime_error("ZSTD_createDCtx() failed!"); } std::size_t const decomp_size = ZSTD_decompress_usingDDict(dctx, (void*)dst, dstCapacity, src, srcSize, ddict); if (ZSTD_isError(decomp_size)) { - LOG(Helper::LogLevel::LL_Error, "ZSTD decompress error %s, \n", ZSTD_getErrorName(decomp_size)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD decompress error %s, \n", ZSTD_getErrorName(decomp_size)); throw std::runtime_error("ZSTD decompress failed."); } ZSTD_freeDCtx(dctx); @@ -89,7 +94,7 @@ namespace SPTAG src.data(), src.size(), compress_level); if (ZSTD_isError(compressed_size)) { - LOG(Helper::LogLevel::LL_Error, "ZSTD compress error %s, \n", ZSTD_getErrorName(compressed_size)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD compress error %s, \n", ZSTD_getErrorName(compressed_size)); throw std::runtime_error("ZSTD compress error"); } buffer.resize(compressed_size); @@ -104,7 +109,7 @@ namespace SPTAG (void *)dst, dstCapacity, src, srcSize); if (ZSTD_isError(decomp_size)) { - LOG(Helper::LogLevel::LL_Error, "ZSTD decompress error %s, \n", ZSTD_getErrorName(decomp_size)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZSTD decompress error %s, \n", ZSTD_getErrorName(decomp_size)); throw std::runtime_error("ZSTD decompress failed."); } @@ -128,7 +133,7 @@ namespace SPTAG size_t dictSize = ZDICT_trainFromBuffer((void *)dictBuffer.data(), dictBufferCapacity, (void *)samplesBuffer.data(), &samplesSizes[0], nbSamples); if (ZDICT_isError(dictSize)) { - LOG(Helper::LogLevel::LL_Error, "ZDICT_trainFromBuffer() failed: %s \n", ZDICT_getErrorName(dictSize)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ZDICT_trainFromBuffer() failed: %s \n", ZDICT_getErrorName(dictSize)); throw std::runtime_error("ZDICT_trainFromBuffer() failed"); } dictBuffer.resize(dictSize); @@ -175,6 +180,51 @@ namespace SPTAG ZSTD_CDict *cdict; ZSTD_DDict *ddict; }; + #else + class Compressor + { + public: + Compressor(int level = 0, int bufferCapacity = 102400) + { + } + + virtual 
~Compressor() + { + } + + std::size_t TrainDict(const std::string &samplesBuffer, const size_t *samplesSizes, unsigned nbSamples) + { + return 0; + } + + std::string GetDictBuffer() + { + return ""; + } + + void SetDictBuffer(const std::string &buffer) + { + } + + std::string Compress(const std::string &src, const bool useDict) + { + + return src; + } + + std::size_t Decompress(const char *src, size_t srcSize, char *dst, size_t dstCapacity, const bool useDict) + { + memcpy(dst, src, srcSize); + return srcSize; + } + + // return the compressed sie + size_t GetCompressedSize(const std::string &src, bool useDict) + { + return src.size(); + } + }; + #endif } // SPANN } // SPTAG diff --git a/AnnService/inc/Core/SPANN/ExtraDynamicSearcher.h b/AnnService/inc/Core/SPANN/ExtraDynamicSearcher.h new file mode 100644 index 000000000..d5b88a8ae --- /dev/null +++ b/AnnService/inc/Core/SPANN/ExtraDynamicSearcher.h @@ -0,0 +1,2614 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifndef _SPTAG_SPANN_EXTRADYNAMICSEARCHER_H_ +#define _SPTAG_SPANN_EXTRADYNAMICSEARCHER_H_ + +#include "inc/Helper/VectorSetReader.h" +#include "inc/Helper/AsyncFileReader.h" +#include "IExtraSearcher.h" +#include "ExtraStaticSearcher.h" +#include "inc/Core/Common/TruthSet.h" +#include "inc/Helper/KeyValueIO.h" +#include "inc/Helper/ConcurrentSet.h" +#include "inc/Core/Common/FineGrainedLock.h" +#include "inc/Core/Common/Checksum.h" +#include "PersistentBuffer.h" +#include "inc/Core/Common/PostingSizeRecord.h" +#include "ExtraFileController.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef SPDK +#include "ExtraSPDKController.h" +#endif + +#ifdef ROCKSDB +#include "ExtraRocksDBController.h" +// enable rocksdb io_uring +extern "C" bool RocksDbIOUringEnable() { return true; } +#endif + +namespace SPTAG::SPANN { + template + class ExtraDynamicSearcher : public IExtraSearcher + { + struct AppendPair + { + std::string BKTID; + int headID; + std::string posting; + + AppendPair(std::string p_BKTID = "", int p_headID = -1, std::string p_posting = "") : BKTID(p_BKTID), headID(p_headID), posting(p_posting) {} + inline bool operator < (const AppendPair& rhs) const + { + return std::strcmp(BKTID.c_str(), rhs.BKTID.c_str()) < 0; + } + + inline bool operator > (const AppendPair& rhs) const + { + return std::strcmp(BKTID.c_str(), rhs.BKTID.c_str()) > 0; + } + }; + + class MergeAsyncJob : public Helper::ThreadPool::Job + { + private: + VectorIndex* m_index; + ExtraDynamicSearcher* m_extraIndex; + SizeType headID; + bool disableReassign; + std::function m_callback; + public: + MergeAsyncJob(VectorIndex* headIndex, ExtraDynamicSearcher* extraIndex, SizeType headID, bool disableReassign, std::function p_callback) + : m_index(headIndex), m_extraIndex(extraIndex), headID(headID), disableReassign(disableReassign), m_callback(std::move(p_callback)) {} + + ~MergeAsyncJob() {} + inline void exec(IAbortOperation* p_abort) { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot support job.exec(abort)!\n"); + } + inline void exec(void* p_workSpace, IAbortOperation* p_abort) override { + ErrorCode ret = m_extraIndex->MergePostings((ExtraWorkSpace*)p_workSpace, m_index, headID, !disableReassign); + if (ret != ErrorCode::Success) + m_extraIndex->m_asyncStatus = ret; + if (m_callback != nullptr) { + m_callback(); + } + } + }; + + class SplitAsyncJob : public Helper::ThreadPool::Job + { + private: + VectorIndex* m_index; + ExtraDynamicSearcher* m_extraIndex; + SizeType headID; + bool disableReassign; + std::function m_callback; + public: + SplitAsyncJob(VectorIndex* headIndex, ExtraDynamicSearcher* extraIndex, SizeType headID, bool disableReassign, std::function p_callback) + : m_index(headIndex), m_extraIndex(extraIndex), headID(headID), disableReassign(disableReassign), m_callback(std::move(p_callback)) {} + + ~SplitAsyncJob() {} + inline void exec(IAbortOperation* p_abort) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot support job.exec(abort)!\n"); + } + inline void exec(void* p_workSpace, IAbortOperation* p_abort) override { + ErrorCode ret = m_extraIndex->Split((ExtraWorkSpace*)p_workSpace, m_index, headID, !disableReassign); + if (ret != ErrorCode::Success) + m_extraIndex->m_asyncStatus = ret; + if (m_callback != nullptr) { + m_callback(); + } + } + }; + + class ReassignAsyncJob : public Helper::ThreadPool::Job + { + private: + VectorIndex* m_index; + ExtraDynamicSearcher* m_extraIndex; + std::shared_ptr vectorInfo; + SizeType HeadPrev; + std::function m_callback; + public: + ReassignAsyncJob(VectorIndex* headIndex, ExtraDynamicSearcher* extraIndex, + std::shared_ptr vectorInfo, SizeType HeadPrev, std::function p_callback) + : m_index(headIndex), m_extraIndex(extraIndex), vectorInfo(std::move(vectorInfo)), HeadPrev(HeadPrev), m_callback(std::move(p_callback)) {} + + ~ReassignAsyncJob() {} + + inline void exec(IAbortOperation* p_abort) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, 
"Cannot support job.exec(abort)!\n"); + } + + void exec(void* p_workSpace, IAbortOperation* p_abort) override { + ErrorCode ret = m_extraIndex->Reassign((ExtraWorkSpace*)p_workSpace, m_index, vectorInfo, HeadPrev); + if (ret != ErrorCode::Success) + m_extraIndex->m_asyncStatus = ret; + if (m_callback != nullptr) { + m_callback(); + } + } + }; + + class SPDKThreadPool : public Helper::ThreadPool + { + public: + void initSPDK(int numberOfThreads, ExtraDynamicSearcher* extraIndex) + { + m_abort.SetAbort(false); + for (int i = 0; i < numberOfThreads; i++) + { + m_threads.emplace_back([this, extraIndex] { + Job *j; + ExtraWorkSpace workSpace; + extraIndex->InitWorkSpace(&workSpace); + while (get(j)) + { + try + { + j->exec(&workSpace, &m_abort); + } + catch (std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ThreadPool: exception in %s %s\n", typeid(*j).name(), e.what()); + } + delete j; + currentJobs--; + } + }); + } + } + }; + + private: + std::shared_ptr> m_freeWorkSpaceIds; + std::atomic m_workspaceCount = 0; + + Helper::Concurrent::ConcurrentPriorityQueue m_asyncAppendQueue; + std::mutex m_asyncAppendLock; + + std::shared_ptr db; + + COMMON::VersionLabel* m_versionMap; + Options* m_opt; + + std::mutex m_dataAddLock; + + std::mutex m_mergeLock; + + COMMON::FineGrainedRWLock m_rwLocks; + + COMMON::PostingSizeRecord m_postingSizes; + + COMMON::Checksum m_checkSum; + COMMON::Dataset m_checkSums; + + IndexStats m_stat; + + std::shared_ptr m_wal; + + std::mutex m_runningLock; + std::unordered_setm_splitList; + + Helper::Concurrent::ConcurrentMap m_mergeList; + + public: + ExtraDynamicSearcher(SPANN::Options& p_opt) { + m_opt = &p_opt; + m_metaDataSize = sizeof(int) + sizeof(uint8_t); + m_vectorInfoSize = p_opt.m_dim * sizeof(ValueType) + m_metaDataSize; + p_opt.m_postingPageLimit = max(p_opt.m_postingPageLimit, static_cast((p_opt.m_postingVectorLimit * m_vectorInfoSize + PageSize - 1) / PageSize)); + p_opt.m_searchPostingPageLimit = 
p_opt.m_postingPageLimit; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting index with posting page limit:%d\n", p_opt.m_postingPageLimit); + m_postingSizeLimit = p_opt.m_postingPageLimit * PageSize / m_vectorInfoSize; + m_bufferSizeLimit = p_opt.m_bufferLength * PageSize / m_vectorInfoSize; + + if(p_opt.m_storage == Storage::FILEIO) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ExtraDynamicSearcher:UseFileIO\n"); + db.reset(new FileIO(p_opt)); + } + else if (p_opt.m_storage == Storage::SPDKIO) { +#ifdef SPDK + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ExtraDynamicSearcher:UseSPDK\n"); + db.reset(new SPDKIO(p_opt)); +#else + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ExtraDynamicSearcher:SPDK unsupport! Use -DSPDK to enable SPDK when doing cmake.\n"); + return; +#endif + } + else if (p_opt.m_storage == Storage::ROCKSDBIO) { +#ifdef ROCKSDB + std::string indexDir = (p_opt.m_recovery)? p_opt.m_persistentBufferPath + FolderSep: p_opt.m_indexDirectory + FolderSep; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ExtraDynamicSearcher:UseKV\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ExtraDynamicSearcher:dbPath:%s\n", (indexDir + p_opt.m_KVFile).c_str()); + db.reset(new RocksDBIO((indexDir + p_opt.m_KVFile).c_str(), p_opt.m_useDirectIO, p_opt.m_enableWAL, p_opt.m_recovery)); +#else + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ExtraDynamicSearcher:RocksDB unsupport! 
Use -DROCKSDB to enable RocksDB when doing cmake.\n"); + return; +#endif + } + + + m_hardLatencyLimit = std::chrono::microseconds((int)(p_opt.m_latencyLimit) * 1000); + m_mergeThreshold = p_opt.m_mergeThreshold; + m_checkSum.Initialize(!p_opt.m_checksumCheck, 0, 0); + + int maxIOThreads = max(p_opt.m_ioThreads, (2 * max(p_opt.m_searchThreadNum, p_opt.m_iSSDNumberOfThreads) + + p_opt.m_insertThreadNum + p_opt.m_reassignThreadNum + p_opt.m_appendThreadNum)); + m_freeWorkSpaceIds.reset(new Helper::Concurrent::ConcurrentQueue()); + for (int i = 0; i < maxIOThreads; i++) { + m_freeWorkSpaceIds->push(i); + } + m_workspaceCount = maxIOThreads; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Posting size limit: %d, search limit: %f, merge threshold: %d\n", m_postingSizeLimit, p_opt.m_latencyLimit, m_mergeThreshold); + } + + ~ExtraDynamicSearcher() {} + + virtual bool Available() override + { + return db->Available(); + } + + //headCandidates: search data structrue for "vid" vector + //headID: the head vector that stands for vid + bool IsAssumptionBroken(VectorIndex* p_index, SizeType headID, QueryResult& headCandidates, SizeType vid) + { + p_index->SearchIndex(headCandidates); + int replicaCount = 0; + BasicResult* queryResults = headCandidates.GetResults(); + std::vector selections(static_cast(m_opt->m_replicaCount)); + for (int i = 0; i < headCandidates.GetResultNum() && replicaCount < m_opt->m_replicaCount; ++i) { + if (queryResults[i].VID == -1) { + break; + } + // RNG Check. 
+ bool rngAccpeted = true; + for (int j = 0; j < replicaCount; ++j) { + float nnDist = p_index->ComputeDistance( + p_index->GetSample(queryResults[i].VID), + p_index->GetSample(selections[j].node)); + if (nnDist < queryResults[i].Dist) { + rngAccpeted = false; + break; + } + } + if (!rngAccpeted) + continue; + + selections[replicaCount].node = queryResults[i].VID; + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "head:%d\n", queryResults[i].VID); + if (selections[replicaCount].node == headID) return false; + ++replicaCount; + } + return true; + } + + //Measure that in "headID" posting list, how many vectors break their assumption + int QuantifyAssumptionBroken(VectorIndex* p_index, SizeType headID, std::string& postingList, SizeType SplitHead, std::vector& newHeads, std::set& brokenID, int topK = 0, float ratio = 1.0) + { + int assumptionBrokenNum = 0; + int postVectorNum = postingList.size() / m_vectorInfoSize; + uint8_t* postingP = reinterpret_cast(&postingList.front()); + float minDist; + float maxDist; + float avgDist = 0; + std::vector distanceSet; + + for (int j = 0; j < postVectorNum; j++) { + uint8_t* vectorId = postingP + j * m_vectorInfoSize; + SizeType vid = *(reinterpret_cast(vectorId)); + uint8_t version = *(reinterpret_cast(vectorId + sizeof(int))); + float_t dist = p_index->ComputeDistance(reinterpret_cast(vectorId + m_metaDataSize), p_index->GetSample(headID)); + // if (dist < Epsilon) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "head found: vid: %d, head: %d\n", vid, headID); + avgDist += dist; + distanceSet.push_back(dist); + if (m_versionMap->Deleted(vid) || m_versionMap->GetVersion(vid) != version) continue; + COMMON::QueryResultSet headCandidates(reinterpret_cast(vectorId + m_metaDataSize), 64); + if (brokenID.find(vid) == brokenID.end() && IsAssumptionBroken(p_index, headID, headCandidates, vid)) { + /* + float_t headDist = p_index->ComputeDistance(headCandidates.GetTarget(), p_index->GetSample(SplitHead)); + float_t newHeadDist_1 = 
p_index->ComputeDistance(headCandidates.GetTarget(), p_index->GetSample(newHeads[0])); + float_t newHeadDist_2 = p_index->ComputeDistance(headCandidates.GetTarget(), p_index->GetSample(newHeads[1])); + + float_t splitDist = p_index->ComputeDistance(p_index->GetSample(SplitHead), p_index->GetSample(headID)); + + float_t headToNewHeadDist_1 = p_index->ComputeDistance(p_index->GetSample(headID), p_index->GetSample(newHeads[0])); + float_t headToNewHeadDist_2 = p_index->ComputeDistance(p_index->GetSample(headID), p_index->GetSample(newHeads[1])); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "broken vid to head distance: %f, to split head distance: %f\n", dist, headDist); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "broken vid to new head 1 distance: %f, to new head 2 distance: %f\n", newHeadDist_1, newHeadDist_2); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "head to spilit head distance: %f\n", splitDist); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "head to new head 1 distance: %f, to new head 2 distance: %f\n", headToNewHeadDist_1, headToNewHeadDist_2); + */ + assumptionBrokenNum++; + brokenID.insert(vid); + } + } + + if (assumptionBrokenNum != 0) { + std::sort(distanceSet.begin(), distanceSet.end()); + minDist = distanceSet[1]; + maxDist = distanceSet.back(); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "distance: min: %f, max: %f, avg: %f, 50th: %f\n", minDist, maxDist, avgDist/postVectorNum, distanceSet[distanceSet.size() * 0.5]); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "assumption broken num: %d\n", assumptionBrokenNum); + float_t splitDist = p_index->ComputeDistance(p_index->GetSample(SplitHead), p_index->GetSample(headID)); + + float_t headToNewHeadDist_1 = p_index->ComputeDistance(p_index->GetSample(headID), p_index->GetSample(newHeads[0])); + float_t headToNewHeadDist_2 = p_index->ComputeDistance(p_index->GetSample(headID), p_index->GetSample(newHeads[1])); + + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "head to spilt head distance: %f/%d/%.2f\n", splitDist, 
topK, ratio); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "head to new head 1 distance: %f, to new head 2 distance: %f\n", headToNewHeadDist_1, headToNewHeadDist_2); + } + + return assumptionBrokenNum; + } + + int QuantifySplitCaseA(VectorIndex* p_index, std::vector& newHeads, std::vector& postingLists, SizeType SplitHead, int split_order, std::set& brokenID) + { + int assumptionBrokenNum = 0; + assumptionBrokenNum += QuantifyAssumptionBroken(p_index, newHeads[0], postingLists[0], SplitHead, newHeads, brokenID); + assumptionBrokenNum += QuantifyAssumptionBroken(p_index, newHeads[1], postingLists[1], SplitHead, newHeads, brokenID); + int vectorNum = (postingLists[0].size() + postingLists[1].size()) / m_vectorInfoSize; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After Split%d, Top0 nearby posting lists, caseA : %d/%d\n", split_order, assumptionBrokenNum, vectorNum); + return assumptionBrokenNum; + } + + //Measure that around "headID", how many vectors break their assumption + //"headID" is the head vector before split + void QuantifySplitCaseB(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index, SizeType headID, std::vector& newHeads, SizeType SplitHead, int split_order, int assumptionBrokenNum_top0, std::set& brokenID) + { + COMMON::QueryResultSet nearbyHeads(reinterpret_cast(p_index->GetSample(headID)), 64); + std::vector postingLists; + p_index->SearchIndex(nearbyHeads); + std::string postingList; + BasicResult* queryResults = nearbyHeads.GetResults(); + int topk = 8; + int assumptionBrokenNum = assumptionBrokenNum_top0; + int assumptionBrokenNum_topK = assumptionBrokenNum_top0; + int i; + int containedHead = 0; + if (assumptionBrokenNum_top0 != 0) containedHead++; + int vectorNum = 0; + float furthestDist = 0; + for (i = 0; i < nearbyHeads.GetResultNum(); i++) { + if (queryResults[i].VID == -1) { + break; + } + furthestDist = queryResults[i].Dist; + if (i == topk) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After Split%d, Top%d nearby posting lists, caseB : %d 
in %d/%d\n", split_order, i, assumptionBrokenNum, containedHead, vectorNum); + topk *= 2; + } + if (queryResults[i].VID == newHeads[0] || queryResults[i].VID == newHeads[1]) continue; + db->Get(queryResults[i].VID, &postingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests)); + vectorNum += postingList.size() / m_vectorInfoSize; + int tempNum = QuantifyAssumptionBroken(p_index, queryResults[i].VID, postingList, SplitHead, newHeads, brokenID, i, queryResults[i].Dist / queryResults[1].Dist); + assumptionBrokenNum += tempNum; + if (tempNum != 0) containedHead++; + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After Split%d, Top%d nearby posting lists, caseB : %d in %d/%d\n", split_order, i, assumptionBrokenNum, containedHead, vectorNum); + } + + void QuantifySplit(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index, SizeType headID, std::vector& postingLists, std::vector& newHeads, SizeType SplitHead, int split_order) + { + std::set brokenID; + brokenID.clear(); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Split Quantify: %d, head1:%d, head2:%d\n", split_order, newHeads[0], newHeads[1]); + int assumptionBrokenNum = QuantifySplitCaseA(p_index, newHeads, postingLists, SplitHead, split_order, brokenID); + QuantifySplitCaseB(p_exWorkSpace, p_index, headID, newHeads, SplitHead, split_order, assumptionBrokenNum, brokenID); + } + + bool CheckIsNeedReassign(VectorIndex* p_index, std::vector& newHeads, ValueType* data, SizeType splitHead, float_t headToSplitHeadDist, float_t currentHeadDist, bool isInSplitHead, SizeType currentHead) + { + + float_t splitHeadDist = p_index->ComputeDistance(data, p_index->GetSample(splitHead)); + + if (isInSplitHead) { + if (splitHeadDist >= currentHeadDist) return false; + } + else { + float_t newHeadDist_1 = p_index->ComputeDistance(data, p_index->GetSample(newHeads[0])); + float_t newHeadDist_2 = p_index->ComputeDistance(data, p_index->GetSample(newHeads[1])); + if (splitHeadDist <= newHeadDist_1 && splitHeadDist <= newHeadDist_2) return 
false; + if (currentHeadDist <= newHeadDist_1 && currentHeadDist <= newHeadDist_2) return false; + } + return true; + } + + inline void Serialize(char* ptr, SizeType VID, std::uint8_t version, const void* vector) { + memcpy(ptr, &VID, sizeof(VID)); + memcpy(ptr + sizeof(VID), &version, sizeof(version)); + memcpy(ptr + m_metaDataSize, vector, m_vectorInfoSize - m_metaDataSize); + } + + void CalculatePostingDistribution(VectorIndex* p_index) + { + int top = m_postingSizeLimit / 10 + 1; + int page = m_opt->m_postingPageLimit + 1; + std::vector lengthDistribution(top, 0); + std::vector sizeDistribution(page + 2, 0); + int deletedHead = 0; + for (int i = 0; i < p_index->GetNumSamples(); i++) { + if (!p_index->ContainSample(i)) deletedHead++; + lengthDistribution[m_postingSizes.GetSize(i) / 10]++; + int size = m_postingSizes.GetSize(i) * m_vectorInfoSize; + if (size < PageSize) { + if (size < 512) sizeDistribution[0]++; + else if (size < 1024) sizeDistribution[1]++; + else sizeDistribution[2]++; + } + else { + sizeDistribution[size / PageSize + 2]++; + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Posting Length (Vector Num):\n"); + for (int i = 0; i < top; ++i) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d ~ %d: %d, \n", i * 10, (i + 1) * 10 - 1, lengthDistribution[i]); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Posting Length (Data Size):\n"); + for (int i = 0; i < page + 2; ++i) + { + if (i <= 2) { + if (i == 0) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "0 ~ 512 B: %d, \n", sizeDistribution[0] - deletedHead); + else if (i == 1) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "512 B ~ 1 KB: %d, \n", sizeDistribution[1]); + else SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "1 KB ~ 4 KB: %d, \n", sizeDistribution[2]); + } + else + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d ~ %d KB: %d, \n", (i - 2) * 4, (i - 1) * 4, sizeDistribution[i]); + } + } + + void PrintErrorInPosting(std::string &posting, SizeType headID) + { + SizeType postVectorNum = posting.size() / 
m_vectorInfoSize; + uint8_t *vectorId = reinterpret_cast(posting.data()); + for (int j = 0; j < postVectorNum; j++, vectorId += m_vectorInfoSize) + { + SizeType VID = *((SizeType *)(vectorId)); + if (VID < 0 || VID >= m_versionMap->Count()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "PrintErrorInPosting found wrong VID:%d in headID:%d (should be less than %d)\n", VID, + headID, m_versionMap->Count()); + } + } + } + + // TODO + ErrorCode RefineIndex(std::shared_ptr p_index, + bool p_prereassign, std::vector *p_headmapping, std::vector *p_mapping) override + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin RefineIndex\n"); + + COMMON::PostingSizeRecord new_postingSizes; + COMMON::Dataset new_checkSums; + if (!p_prereassign) + { + new_postingSizes.Initialize(p_index->GetNumSamples() - p_index->GetNumDeleted(), + p_index->m_iDataBlockSize, p_index->m_iDataCapacity); + new_checkSums.Initialize(p_index->GetNumSamples() - p_index->GetNumDeleted(), 1, + p_index->m_iDataBlockSize, p_index->m_iDataCapacity); + } + std::atomic_bool doneReassign = false; + while (!doneReassign) { + auto preReassignTimeBegin = std::chrono::high_resolution_clock::now(); + std::atomic finalcode = ErrorCode::Success; + doneReassign = true; + std::vector threads; + std::atomic_int nextPostingID(0); + int currentPostingNum = p_index->GetNumSamples(); + int limit = m_postingSizeLimit * m_opt->m_preReassignRatio; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Batch PreReassign, Current PostingNum: %d, Current Limit: %d\n", currentPostingNum, limit); + auto func = [&]() + { + ErrorCode ret; + int index = 0; + ExtraWorkSpace workSpace; + InitWorkSpace(&workSpace); + while (true) + { + index = nextPostingID.fetch_add(1); + if (index < currentPostingNum) + { + if ((index & ((1 << 14) - 1)) == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Sent %.2lf%%...\n", index * 100.0 / currentPostingNum); + } + if (p_prereassign) + { + if (m_postingSizes.GetSize(index) >= limit) + { + doneReassign = false; + 
Split(&workSpace, p_index.get(), index, false, p_prereassign); + } + } + else + { + if (!p_index->ContainSample(index)) + continue; + + // ForceCompaction + std::string postingList; + if ((ret = db->Get(index, &postingList, MaxTimeout, &(workSpace.m_diskRequests))) != + ErrorCode::Success || + !m_checkSum.ValidateChecksum(postingList.c_str(), (int)(postingList.size()), + *m_checkSums[index])) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "RefineIndex failed to get posting %d, required size:%d, read size:%d " + "checksum issue:%d\n", + index, (int)(m_postingSizes.GetSize(index) * m_vectorInfoSize), + (int)(postingList.size()), (int)(ret == ErrorCode::Success)); + PrintErrorInPosting(postingList, index); + finalcode = ret; + return; + } + SizeType postVectorNum = (SizeType)(postingList.size() / m_vectorInfoSize); + auto *postingP = reinterpret_cast(&postingList.front()); + uint8_t *vectorId = postingP; + int vectorCount = 0; + for (int j = 0; j < postVectorNum; + j++, vectorId += m_vectorInfoSize) + { + uint8_t version = *(vectorId + sizeof(SizeType)); + SizeType VID = *((SizeType *)(vectorId)); + + if (m_versionMap->Deleted(VID) || m_versionMap->GetVersion(VID) != version) + continue; + + *((SizeType *)(vectorId)) = (*p_mapping)[VID]; + if (j != vectorCount) + { + memcpy(postingP + vectorCount * m_vectorInfoSize, vectorId, m_vectorInfoSize); + } + vectorCount++; + } + + postingList.resize(vectorCount * m_vectorInfoSize); + new_postingSizes.UpdateSize(p_headmapping->at(index), vectorCount); + *new_checkSums[p_headmapping->at(index)] = + m_checkSum.CalcChecksum(postingList.c_str(), (int)(postingList.size())); + if ((ret = db->Put(p_headmapping->at(index), postingList, MaxTimeout, + &(workSpace.m_diskRequests))) != + ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "RefineIndex Failed to write back compacted posting\n"); + finalcode = ret; + return; + } + if (m_opt->m_consistencyCheck && (ret = db->Check(p_headmapping->at(index), 
new_postingSizes.GetSize(p_headmapping->at(index)) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "RefineIndex: Check failed after Put %d\n", + p_headmapping->at(index)); + finalcode = ret; + return; + } + } + } + else + { + return; + } + } + }; + for (int j = 0; j < m_opt->m_iSSDNumberOfThreads; j++) { threads.emplace_back(func); } + for (auto& thread : threads) { thread.join(); } + auto preReassignTimeEnd = std::chrono::high_resolution_clock::now(); + double elapsedSeconds = std::chrono::duration_cast(preReassignTimeEnd - preReassignTimeBegin).count(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "rebuild cost: %.2lf s\n", elapsedSeconds); + + if (finalcode != ErrorCode::Success) + return finalcode; + + if (p_prereassign) + { + p_index->SaveIndex(m_opt->m_indexDirectory + FolderSep + m_opt->m_headIndexFolder); + Checkpoint(m_opt->m_indexDirectory); + CalculatePostingDistribution(p_index.get()); + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving posting size\n"); + std::string p_persistenRecord = m_opt->m_indexDirectory + FolderSep + m_opt->m_ssdInfoFile; + new_postingSizes.Save(p_persistenRecord); + std::string p_checksumPath = m_opt->m_indexDirectory + FolderSep + m_opt->m_checksumFile; + new_checkSums.Save(p_checksumPath); + db->Checkpoint(m_opt->m_indexDirectory); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: ReWriting SSD Info\n"); + } + return ErrorCode::Success; + } + + ErrorCode Split(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index, const SizeType headID, bool reassign = false, bool preReassign = false, bool requirelock = true) + { + auto splitBegin = std::chrono::high_resolution_clock::now(); + std::vector newHeadsID; + std::vector newPostingLists; + ErrorCode ret; + bool theSameHead = false; + double elapsedMSeconds; + { + std::unique_lock lock(m_rwLocks[headID], std::defer_lock); + if (requirelock) lock.lock(); + + if (!p_index->ContainSample(headID)) return 
ErrorCode::Success; + + std::string postingList; + auto splitGetBegin = std::chrono::high_resolution_clock::now(); + if ((ret=db->Get(headID, &postingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != + ErrorCode::Success)// || !m_checkSum.ValidateChecksum(postingList.c_str(), (int)(postingList.size()), *m_checkSums[headID])) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split fail to get oversized postings: key=%d size=%d\n", headID, m_postingSizes.GetSize(headID)); + return ret; + } + auto splitGetEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(splitGetEnd - splitGetBegin).count(); + m_stat.m_getCost += elapsedMSeconds; + // reinterpret postingList to vectors and IDs + auto* postingP = reinterpret_cast(&postingList.front()); + SizeType postVectorNum = (SizeType)(postingList.size() / m_vectorInfoSize); + + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "DEBUG: db get Posting %d successfully with length %d real length:%d vectorNum:%d\n", headID, (int)(postingList.size()), m_postingSizes.GetSize(headID), postVectorNum); + COMMON::Dataset smallSample(postVectorNum, m_opt->m_dim, p_index->m_iDataBlockSize, p_index->m_iDataCapacity, (ValueType*)postingP, true, nullptr, m_metaDataSize, m_vectorInfoSize); + //COMMON::Dataset smallSample(0, m_opt->m_dim, p_index->m_iDataBlockSize, p_index->m_iDataCapacity); // smallSample[i] -> VID + //std::vector localIndicesInsert(postVectorNum); // smallSample[i] = j <-> localindices[j] = i + //std::vector localIndicesInsertVersion(postVectorNum); + std::vector localIndices(postVectorNum); + int index = 0; + uint8_t* vectorId = postingP; + for (int j = 0; j < postVectorNum; j++, vectorId += m_vectorInfoSize) + { + //LOG(Helper::LogLevel::LL_Info, "vector index/total:id: %d/%d:%d\n", j, m_postingSizes[headID].load(), *(reinterpret_cast(vectorId))); + uint8_t version = *(vectorId + sizeof(int)); + int VID = *((int*)(vectorId)); + //if (VID >= m_versionMap->Count()) 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "DEBUG: vector ID:%d total size:%d\n", VID, m_versionMap->Count()); + if (m_versionMap->Deleted(VID) || m_versionMap->GetVersion(VID) != version) continue; + + //localIndicesInsert[index] = VID; + //localIndicesInsertVersion[index] = version; + //smallSample.AddBatch(1, (ValueType*)(vectorId + m_metaDataSize)); + localIndices[index] = j; + index++; + } + // double gcEndTime = sw.getElapsedMs(); + // m_splitGcCost += gcEndTime; + + if (!preReassign && index < m_postingSizeLimit) + { + + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "DEBUG: in place or not prereassign & index < m_postingSizeLimit. GC begin...\n"); + char* ptr = (char*)(postingList.c_str()); + for (int j = 0; j < index; j++, ptr += m_vectorInfoSize) + { + if (j == localIndices[j]) continue; + memcpy(ptr, postingList.c_str() + localIndices[j] * m_vectorInfoSize, m_vectorInfoSize); + //Serialize(ptr, localIndicesInsert[j], localIndicesInsertVersion[j], smallSample[j]); + } + postingList.resize(index * m_vectorInfoSize); + m_postingSizes.UpdateSize(headID, index); + *m_checkSums[headID] = m_checkSum.CalcChecksum(postingList.c_str(), (int)(postingList.size())); + if ((ret=db->Put(headID, postingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split Fail to write back postings\n"); + return ret; + } + if (m_opt->m_consistencyCheck && (ret = db->Check(headID, m_postingSizes.GetSize(headID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split: Check failed after Put %d\n", headID); + return ret; + } + m_stat.m_garbageNum++; + auto GCEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(GCEnd - splitBegin).count(); + m_stat.m_garbageCost += elapsedMSeconds; + { + std::lock_guard tmplock(m_runningLock); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"erase: %d\n", headID); + m_splitList.erase(headID); + } + 
//SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "GC triggered: %d, new length: %d\n", headID, index); + return ErrorCode::Success; + } + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Resize\n"); + localIndices.resize(index); + + auto clusterBegin = std::chrono::high_resolution_clock::now(); + // k = 2, maybe we can change the split number, now it is fixed + SPTAG::COMMON::KmeansArgs args(2, smallSample.C(), (SizeType)localIndices.size(), 1, p_index->GetDistCalcMethod()); + std::shuffle(localIndices.begin(), localIndices.end(), std::mt19937(std::random_device()())); + + int numClusters = SPTAG::COMMON::KmeansClustering(smallSample, localIndices, 0, (SizeType)localIndices.size(), args, 1000, 100.0F, false, nullptr); + + auto clusterEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(clusterEnd - clusterBegin).count(); + m_stat.m_clusteringCost += elapsedMSeconds; + // int numClusters = ClusteringSPFresh(smallSample, localIndices, 0, localIndices.size(), args, 10, false, m_opt->m_virtualHead); + if (numClusters <= 1) + { + int cut = 1; + if (m_opt->m_oneClusterCutMax) cut = m_postingSizeLimit; + std::string newpostingList(cut * m_vectorInfoSize, '\0'); + char* ptr = (char*)(newpostingList.c_str()); + float totaldist = 0.0f; + for (int j = 0; j < cut; j++, ptr += m_vectorInfoSize) + { + totaldist += p_index->ComputeDistance(ptr + sizeof(int) + 1, args.centers); + memcpy(ptr, postingList.c_str() + localIndices[j] * m_vectorInfoSize, m_vectorInfoSize); + //Serialize(ptr, localIndicesInsert[j], localIndicesInsertVersion[j], smallSample[j]); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Cluserting Failed (The same vector), Cluster total dist:%f Only Keep %d vectors.\n", totaldist, cut); + + m_postingSizes.UpdateSize(headID, cut); + *m_checkSums[headID] = m_checkSum.CalcChecksum(newpostingList.c_str(), (int)(newpostingList.size())); + if ((ret=db->Put(headID, newpostingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != 
ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split fail to override postings cut to limit\n"); + return ret; + } + if (m_opt->m_consistencyCheck && (ret = db->Check(headID, m_postingSizes.GetSize(headID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split: Consolidate Check failed after Put %d\n", headID); + return ret; + } + { + std::lock_guard tmplock(m_runningLock); + m_splitList.erase(headID); + } + return ErrorCode::Success; + } + + long long newHeadVID = -1; + int first = 0; + newPostingLists.resize(2); + for (int k = 0; k < 2; k++) { + if (args.counts[k] == 0) continue; + + newPostingLists[k].resize(args.counts[k] * m_vectorInfoSize); + char* ptr = (char*)(newPostingLists[k].c_str()); + for (int j = 0; j < args.counts[k]; j++, ptr += m_vectorInfoSize) + { + memcpy(ptr, postingList.c_str() + localIndices[first + j] * m_vectorInfoSize, m_vectorInfoSize); + //Serialize(ptr, localIndicesInsert[localIndices[first + j]], localIndicesInsertVersion[localIndices[first + j]], smallSample[localIndices[first + j]]); + } + if (!theSameHead && p_index->ComputeDistance(args.centers + k * args._D, p_index->GetSample(headID)) < Epsilon) { + newHeadsID.push_back(headID); + newHeadVID = headID; + theSameHead = true; + m_postingSizes.UpdateSize(newHeadVID, args.counts[k]); + *m_checkSums[newHeadVID] = + m_checkSum.CalcChecksum(newPostingLists[k].c_str(), (int)(newPostingLists[k].size())); + auto splitPutBegin = std::chrono::high_resolution_clock::now(); + if (!preReassign && (ret=db->Put(newHeadVID, newPostingLists[k], MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to override postings\n"); + return ret; + } + if (m_opt->m_consistencyCheck && (ret = db->Check(newHeadVID, m_postingSizes.GetSize(newHeadVID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split: Cluster Write Check 
failed after Put %d\n", newHeadVID); + return ret; + } + auto splitPutEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(splitPutEnd - splitPutBegin).count(); + m_stat.m_putCost += elapsedMSeconds; + m_stat.m_theSameHeadNum++; + } + else { + int begin, end = 0; + p_index->AddIndexId(args.centers + k * args._D, 1, m_opt->m_dim, begin, end); + { + std::lock_guard tmplock(m_runningLock); + m_vectorTranslateMap->AddBatch(1); + } + if (m_opt->m_excludehead) + { + SizeType VID = *((SizeType*)(postingP + args.clusterIdx[k] * m_vectorInfoSize)); + uint8_t version = *((uint8_t*)(postingP + args.clusterIdx[k] * m_vectorInfoSize + sizeof(SizeType))); + *(m_vectorTranslateMap->At(begin)) = VID; + m_versionMap->IncVersion(VID, &version); + } + else + { + *(m_vectorTranslateMap->At(begin)) = MaxSize; + } + newHeadVID = begin; + newHeadsID.push_back(begin); + auto splitPutBegin = std::chrono::high_resolution_clock::now(); + if (!preReassign && (ret=db->Put(newHeadVID, newPostingLists[k], MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to add new postings\n"); + return ret; + } + auto splitPutEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(splitPutEnd - splitPutBegin).count(); + m_stat.m_putCost += elapsedMSeconds; + + std::lock_guard tmplock(m_dataAddLock); + if (m_postingSizes.AddBatch(1) == ErrorCode::MemoryOverFlow || m_checkSums.AddBatch(1) == ErrorCode::MemoryOverFlow) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "MemoryOverFlow: NnewHeadVID: %d, Map Size:%d\n", + newHeadVID, m_postingSizes.BufferSize()); + return ErrorCode::MemoryOverFlow; + } + m_postingSizes.UpdateSize(newHeadVID, args.counts[k]); + *m_checkSums[newHeadVID] = + m_checkSum.CalcChecksum(newPostingLists[k].c_str(), (int)(newPostingLists[k].size())); + if (m_opt->m_consistencyCheck && (ret = db->Check(newHeadVID, 
m_postingSizes.GetSize(newHeadVID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split: Cluster Write Check failed after Put %d\n", newHeadVID); + return ret; + } + + auto updateHeadBegin = std::chrono::high_resolution_clock::now(); + p_index->AddIndexIdx(begin, end); + auto updateHeadEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(updateHeadEnd - updateHeadBegin).count(); + m_stat.m_updateHeadCost += elapsedMSeconds; + } + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Head id: %d split into : %d, length: %d\n", headID, newHeadVID, args.counts[k]); + first += args.counts[k]; + } + if (!theSameHead) { + p_index->DeleteIndex(headID); + m_postingSizes.UpdateSize(headID, 0); + *m_checkSums[headID] = 0; + if ((ret=db->Delete(headID)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to delete old posting in Split\n"); + return ret; + } + } + } + { + std::lock_guard tmplock(m_runningLock); + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"erase: %d\n", headID); + m_splitList.erase(headID); + } + m_stat.m_splitNum++; + if (reassign) { + auto reassignScanBegin = std::chrono::high_resolution_clock::now(); + + CollectReAssign(p_exWorkSpace, p_index, headID, newPostingLists, newHeadsID, theSameHead); + + auto reassignScanEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(reassignScanEnd - reassignScanBegin).count(); + + m_stat.m_reassignScanCost += elapsedMSeconds; + } + auto splitEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(splitEnd - splitBegin).count(); + m_stat.m_splitCost += elapsedMSeconds; + return ErrorCode::Success; + } + + ErrorCode MergePostings(ExtraWorkSpace *p_exWorkSpace, VectorIndex* p_index, SizeType headID, bool reassign = false) + { + { + if (!m_mergeLock.try_lock()) { + auto* curJob = new MergeAsyncJob(p_index, this, headID, reassign, nullptr); 
+ m_splitThreadPool->add(curJob); + return ErrorCode::Success; + } + std::unique_lock lock(m_rwLocks[headID]); + + if (!p_index->ContainSample(headID)) { + m_mergeLock.unlock(); + return ErrorCode::Success; + } + + std::string mergedPostingList; + std::set vectorIdSet; + + std::string currentPostingList; + ErrorCode ret; + if ((ret = db->Get(headID, ¤tPostingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != + ErrorCode::Success || + !m_checkSum.ValidateChecksum(currentPostingList.c_str(), (int)(currentPostingList.size()), *m_checkSums[headID])) + { + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, + "Fail to get original merge postings: %d, required size:%d, get size:%d, checksum issue:%d\n", + headID, (int)(m_postingSizes.GetSize(headID) * m_vectorInfoSize), + (int)(currentPostingList.size()), (int)(ret == ErrorCode::Success)); + PrintErrorInPosting(currentPostingList, headID); + m_mergeLock.unlock(); + return ret; + } + + auto* postingP = reinterpret_cast(¤tPostingList.front()); + size_t postVectorNum = currentPostingList.size() / m_vectorInfoSize; + int currentLength = 0; + uint8_t* vectorId = postingP; + for (int j = 0; j < postVectorNum; j++, vectorId += m_vectorInfoSize) + { + int VID = *((int*)(vectorId)); + uint8_t version = *(vectorId + sizeof(int)); + if (m_versionMap->Deleted(VID) || m_versionMap->GetVersion(VID) != version) continue; + vectorIdSet.insert(VID); + mergedPostingList += currentPostingList.substr(j * m_vectorInfoSize, m_vectorInfoSize); + currentLength++; + } + int totalLength = currentLength; + + if (currentLength > m_mergeThreshold) + { + if ((ret=db->Put(headID, mergedPostingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Merge Fail to write back postings\n"); + m_mergeLock.unlock(); + return ret; + } + + m_postingSizes.UpdateSize(headID, currentLength); + *m_checkSums[headID] = + m_checkSum.CalcChecksum(mergedPostingList.c_str(), (int)(mergedPostingList.size())); 
+ + if (m_opt->m_consistencyCheck && (ret = db->Check(headID, m_postingSizes.GetSize(headID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Merge: Check failed after Put %d\n", headID); + m_mergeLock.unlock(); + return ret; + } + m_mergeList.unsafe_erase(headID); + m_mergeLock.unlock(); + return ErrorCode::Success; + } + + QueryResult queryResults(p_index->GetSample(headID), m_opt->m_internalResultNum, false); + p_index->SearchIndex(queryResults); + + std::string nextPostingList; + for (int i = 1; i < queryResults.GetResultNum(); ++i) + { + BasicResult* queryResult = queryResults.GetResult(i); + int nextLength = m_postingSizes.GetSize(queryResult->VID); + if (currentLength + nextLength < m_postingSizeLimit && m_mergeList.find(queryResult->VID) == m_mergeList.end()) + { + { + std::unique_lock anotherLock(m_rwLocks[queryResult->VID], std::defer_lock); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"Locked: %d, to be lock: %d\n", headID, queryResult->VID); + if (m_rwLocks.hash_func(queryResult->VID) != m_rwLocks.hash_func(headID)) anotherLock.lock(); + if (!p_index->ContainSample(queryResult->VID)) continue; + if ((ret=db->Get(queryResult->VID, &nextPostingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success || + !m_checkSum.ValidateChecksum(nextPostingList.c_str(), (int)(nextPostingList.size()), *m_checkSums[queryResult->VID])) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Fail to get to be merged postings: %d, required size:%d get size:%d, " + "checksum issue:%d\n", + queryResult->VID, + (int)(m_postingSizes.GetSize(queryResult->VID) * m_vectorInfoSize), + (int)(nextPostingList.size()), (int)(ret == ErrorCode::Success)); + PrintErrorInPosting(nextPostingList, queryResult->VID); + m_mergeLock.unlock(); + return ret; + } + postingP = reinterpret_cast(nextPostingList.data()); + postVectorNum = nextPostingList.size() / m_vectorInfoSize; + if (currentLength + postVectorNum > m_postingSizeLimit) + 
{ + continue; + } + + nextLength = 0; + vectorId = postingP; + for (int j = 0; j < postVectorNum; j++, vectorId += m_vectorInfoSize) + { + int VID = *((int*)(vectorId)); + uint8_t version = *(vectorId + sizeof(int)); + if (m_versionMap->Deleted(VID) || m_versionMap->GetVersion(VID) != version) continue; + if (vectorIdSet.find(VID) == vectorIdSet.end()) { + mergedPostingList += nextPostingList.substr(j * m_vectorInfoSize, m_vectorInfoSize); + totalLength++; + } + nextLength++; + } + if (currentLength > nextLength) + { + p_index->DeleteIndex(queryResult->VID); + if ((ret=db->Put(headID, mergedPostingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "MergePostings fail to override postings after merge\n"); + m_mergeLock.unlock(); + return ret; + } + m_postingSizes.UpdateSize(queryResult->VID, 0); + *m_checkSums[queryResult->VID] = 0; + m_postingSizes.UpdateSize(headID, totalLength); + *m_checkSums[headID] = + m_checkSum.CalcChecksum(mergedPostingList.c_str(), (int)(mergedPostingList.size())); + if (m_opt->m_consistencyCheck && (ret = db->Check(headID, m_postingSizes.GetSize(headID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "MergePostings fail to check old posting %d in Merge\n", headID); + m_mergeLock.unlock(); + return ret; + } + if ((ret=db->Delete(queryResult->VID)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to delete old posting in Merge\n"); + m_mergeLock.unlock(); + return ret; + } + } else + { + p_index->DeleteIndex(headID); + if ((ret=db->Put(queryResult->VID, mergedPostingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "MergePostings fail to override postings after merge\n"); + m_mergeLock.unlock(); + return ret; + } + m_postingSizes.UpdateSize(queryResult->VID, totalLength); + *m_checkSums[queryResult->VID] = + 
m_checkSum.CalcChecksum(mergedPostingList.c_str(), (int)(mergedPostingList.size())); + m_postingSizes.UpdateSize(headID, 0); + *m_checkSums[headID] = 0; + if (m_opt->m_consistencyCheck && (ret = db->Check(queryResult->VID, m_postingSizes.GetSize(queryResult->VID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "MergePostings fail to check nearby posting %d in Merge\n", queryResult->VID); + m_mergeLock.unlock(); + return ret; + } + if ((ret = db->Delete(headID)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to delete old posting in Merge\n"); + m_mergeLock.unlock(); + return ret; + } + } + if (m_rwLocks.hash_func(queryResult->VID) != m_rwLocks.hash_func(headID)) anotherLock.unlock(); + } + + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"Release: %d, Release: %d\n", headID, queryResult->VID); + lock.unlock(); + m_mergeLock.unlock(); + + if (reassign) + { + SizeType deletedHead = -1; + /* ReAssign */ + if (currentLength > nextLength) + { + /* ReAssign queryResult->VID*/ + postingP = reinterpret_cast(&nextPostingList.front()); + for (int j = 0; j < nextLength; j++) { + uint8_t* vectorId = postingP + j * m_vectorInfoSize; + // SizeType vid = *(reinterpret_cast(vectorId)); + ValueType* vector = reinterpret_cast(vectorId + m_metaDataSize); + float origin_dist = p_index->ComputeDistance(p_index->GetSample(queryResult->VID), vector); + float current_dist = p_index->ComputeDistance(p_index->GetSample(headID), vector); + if (current_dist > origin_dist) + ReassignAsync(p_index, std::make_shared((char*)vectorId, m_vectorInfoSize), headID); + } + deletedHead = queryResult->VID; + } else + { + /* ReAssign headID*/ + postingP = reinterpret_cast(¤tPostingList.front()); + for (int j = 0; j < currentLength; j++) { + uint8_t* vectorId = postingP + j * m_vectorInfoSize; + // SizeType vid = *(reinterpret_cast(vectorId)); + ValueType* vector = reinterpret_cast(vectorId + m_metaDataSize); + float origin_dist = 
p_index->ComputeDistance(p_index->GetSample(headID), vector); + float current_dist = p_index->ComputeDistance(p_index->GetSample(queryResult->VID), vector); + if (current_dist > origin_dist) + ReassignAsync(p_index, std::make_shared((char*)vectorId, m_vectorInfoSize), queryResult->VID); + } + deletedHead = headID; + } + + if (m_opt->m_excludehead) + { + SizeType vid = (SizeType)(*(m_vectorTranslateMap->At(deletedHead))); + if (!m_versionMap->Deleted(vid)) + { + std::shared_ptr vectorinfo = + std::make_shared(m_vectorInfoSize, ' '); + Serialize(&(vectorinfo->front()), vid, m_versionMap->GetVersion(vid), + p_index->GetSample(deletedHead)); + ReassignAsync(p_index, vectorinfo, -1); + } + } + } + + m_mergeList.unsafe_erase(headID); + m_stat.m_mergeNum++; + + return ErrorCode::Success; + } + } + mergedPostingList.resize(currentLength * m_vectorInfoSize); + if ((ret=db->Put(headID, mergedPostingList, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Merge Fail to write back postings\n"); + return ret; + } + + m_postingSizes.UpdateSize(headID, currentLength); + *m_checkSums[headID] = + m_checkSum.CalcChecksum(mergedPostingList.c_str(), (int)(mergedPostingList.size())); + + if (m_opt->m_consistencyCheck && (ret = db->Check(headID, m_postingSizes.GetSize(headID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Merge: Check failed after put original posting %d\n", headID); + return ret; + } + m_mergeList.unsafe_erase(headID); + m_mergeLock.unlock(); + } + return ErrorCode::Success; + } + + inline void SplitAsync(VectorIndex* p_index, SizeType headID, std::function p_callback = nullptr) + { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"Into SplitAsync, current headID: %d, size: %d\n", headID, m_postingSizes.GetSize(headID)); + // tbb::concurrent_hash_map::const_accessor headIDAccessor; + // if (m_splitList.find(headIDAccessor, headID)) { + // return; + // } + 
// tbb::concurrent_hash_map::value_type workPair(headID, headID); + // m_splitList.insert(workPair); + { + std::lock_guard tmplock(m_runningLock); + + if (m_splitList.find(headID) != m_splitList.end()) { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"Already in queue\n"); + return; + } + m_splitList.insert(headID); + } + + auto* curJob = new SplitAsyncJob(p_index, this, headID, m_opt->m_disableReassign, p_callback); + m_splitThreadPool->add(curJob); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Add to thread pool\n"); + } + + inline void MergeAsync(VectorIndex* p_index, SizeType headID, std::function p_callback = nullptr) + { + if (m_mergeList.find(headID) != m_mergeList.end()) { + return; + } + Helper::Concurrent::ConcurrentMap::value_type workPair(headID, headID); + m_mergeList.insert(workPair); + + auto* curJob = new MergeAsyncJob(p_index, this, headID, m_opt->m_disableReassign, p_callback); + m_splitThreadPool->add(curJob); + } + + inline void ReassignAsync(VectorIndex* p_index, std::shared_ptr vectorInfo, SizeType HeadPrev, std::function p_callback = nullptr) + { + auto* curJob = new ReassignAsyncJob(p_index, this, std::move(vectorInfo), HeadPrev, p_callback); + m_splitThreadPool->add(curJob); + } + + ErrorCode CollectReAssign(ExtraWorkSpace *p_exWorkSpace, VectorIndex *p_index, SizeType headID, + std::vector &postingLists, std::vector &newHeadsID, + bool theSameHead) + { + auto headVector = reinterpret_cast(p_index->GetSample(headID)); + if (m_opt->m_excludehead && !theSameHead) + { + SizeType vid = (SizeType)(*(m_vectorTranslateMap->At(headID))); + if (!m_versionMap->Deleted(vid)) + { + std::shared_ptr vectorinfo = std::make_shared(m_vectorInfoSize, ' '); + Serialize(&(vectorinfo->front()), vid, m_versionMap->GetVersion(vid), headVector); + ReassignAsync(p_index, vectorinfo, -1); + } + } + std::vector newHeadsDist; + std::set reAssignVectorsTopK; + newHeadsDist.push_back(p_index->ComputeDistance(p_index->GetSample(headID), p_index->GetSample(newHeadsID[0]))); 
+ newHeadsDist.push_back(p_index->ComputeDistance(p_index->GetSample(headID), p_index->GetSample(newHeadsID[1]))); + for (int i = 0; i < postingLists.size(); i++) { + auto& postingList = postingLists[i]; + size_t postVectorNum = postingList.size() / m_vectorInfoSize; + auto* postingP = reinterpret_cast(&postingList.front()); + for (int j = 0; j < postVectorNum; j++) { + uint8_t* vectorId = postingP + j * m_vectorInfoSize; + SizeType vid = *(reinterpret_cast(vectorId)); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "VID: %d, Head: %d\n", vid, newHeadsID[i]); + uint8_t version = *(reinterpret_cast(vectorId + sizeof(int))); + ValueType* vector = reinterpret_cast(vectorId + m_metaDataSize); + if (reAssignVectorsTopK.find(vid) == reAssignVectorsTopK.end() && !m_versionMap->Deleted(vid) && m_versionMap->GetVersion(vid) == version) { + m_stat.m_reAssignScanNum++; + float dist = p_index->ComputeDistance(p_index->GetSample(newHeadsID[i]), vector); + if (CheckIsNeedReassign(p_index, newHeadsID, vector, headID, newHeadsDist[i], dist, true, newHeadsID[i])) { + ReassignAsync(p_index, std::make_shared((char*)vectorId, m_vectorInfoSize), newHeadsID[i]); + reAssignVectorsTopK.insert(vid); + } + } + } + } + if (m_opt->m_reassignK > 0) { + std::vector HeadPrevTopK; + newHeadsDist.clear(); + newHeadsDist.resize(0); + postingLists.clear(); + postingLists.resize(0); + COMMON::QueryResultSet nearbyHeads(headVector, m_opt->m_reassignK); + p_index->SearchIndex(nearbyHeads); + BasicResult* queryResults = nearbyHeads.GetResults(); + for (int i = 0; i < nearbyHeads.GetResultNum(); i++) { + auto vid = queryResults[i].VID; + if (vid == -1) break; + + if (find(newHeadsID.begin(), newHeadsID.end(), vid) == newHeadsID.end()) { + HeadPrevTopK.push_back(vid); + newHeadsID.push_back(vid); + newHeadsDist.push_back(queryResults[i].Dist); + } + } + auto reassignScanIOBegin = std::chrono::high_resolution_clock::now(); + ErrorCode ret; + if ((ret=db->MultiGet(HeadPrevTopK, &postingLists, 
m_hardLatencyLimit, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success || !ValidatePostings(HeadPrevTopK, postingLists)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ReAssign can't get all the near postings\n"); + return ret; + } + + auto reassignScanIOEnd = std::chrono::high_resolution_clock::now(); + auto elapsedMSeconds = std::chrono::duration_cast(reassignScanIOEnd - reassignScanIOBegin).count(); + m_stat.m_reassignScanIOCost += elapsedMSeconds; + + for (int i = 0; i < postingLists.size(); i++) { + auto& postingList = postingLists[i]; + size_t postVectorNum = postingList.size() / m_vectorInfoSize; + auto* postingP = reinterpret_cast(postingList.data()); + for (int j = 0; j < postVectorNum; j++) { + uint8_t* vectorId = postingP + j * m_vectorInfoSize; + SizeType vid = *(reinterpret_cast(vectorId)); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d: VID: %d, Head: %d, size:%d/%d\n", i, vid, HeadPrevTopK[i], postingLists.size(), HeadPrevTopK.size()); + uint8_t version = *(reinterpret_cast(vectorId + sizeof(int))); + ValueType* vector = reinterpret_cast(vectorId + m_metaDataSize); + if (reAssignVectorsTopK.find(vid) == reAssignVectorsTopK.end() && !m_versionMap->Deleted(vid) && m_versionMap->GetVersion(vid) == version) { + m_stat.m_reAssignScanNum++; + float dist = p_index->ComputeDistance(p_index->GetSample(HeadPrevTopK[i]), vector); + if (CheckIsNeedReassign(p_index, newHeadsID, vector, headID, newHeadsDist[i], dist, false, HeadPrevTopK[i])) { + ReassignAsync(p_index, std::make_shared((char*)vectorId, m_vectorInfoSize), HeadPrevTopK[i]); + reAssignVectorsTopK.insert(vid); + } + } + } + } + } + return ErrorCode::Success; + } + + bool RNGSelection(std::vector& selections, ValueType* queryVector, VectorIndex* p_index, SizeType p_fullID, int& replicaCount, int checkHeadID = -1) + { + QueryResult queryResults(queryVector, m_opt->m_internalResultNum, false); + p_index->SearchIndex(queryResults); + + replicaCount = 0; + for (int i = 0; i < 
queryResults.GetResultNum() && replicaCount < m_opt->m_replicaCount; ++i) + { + BasicResult* queryResult = queryResults.GetResult(i); + if (queryResult->VID == -1) { + break; + } + // RNG Check. + bool rngAccpeted = true; + for (int j = 0; j < replicaCount; ++j) + { + float nnDist = p_index->ComputeDistance(p_index->GetSample(queryResult->VID), + p_index->GetSample(selections[j].node)); + if (m_opt->m_rngFactor * nnDist <= queryResult->Dist) + { + rngAccpeted = false; + break; + } + } + if (!rngAccpeted) continue; + selections[replicaCount].node = queryResult->VID; + selections[replicaCount].tonode = p_fullID; + selections[replicaCount].distance = queryResult->Dist; + if (selections[replicaCount].node == checkHeadID) { + return false; + } + ++replicaCount; + } + return true; + } + + void InitWorkSpace(ExtraWorkSpace* p_exWorkSpace, bool clear = false) override + { + if (clear) { + p_exWorkSpace->Clear(m_opt->m_searchInternalResultNum, (max(m_opt->m_postingPageLimit, m_opt->m_searchPostingPageLimit) + m_opt->m_bufferLength) << PageSizeEx, true, m_opt->m_enableDataCompression); + } + else { + p_exWorkSpace->Initialize(m_opt->m_maxCheck, m_opt->m_hashExp, m_opt->m_searchInternalResultNum, (max(m_opt->m_postingPageLimit, m_opt->m_searchPostingPageLimit) + m_opt->m_bufferLength) << PageSizeEx, true, m_opt->m_enableDataCompression); + int wid = 0; + if (m_freeWorkSpaceIds == nullptr || !m_freeWorkSpaceIds->try_pop(wid)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "FreeWorkSpaceIds is not initalized or the workspace number is not enough! 
Please increase iothread number.\n"); + p_exWorkSpace->m_diskRequests[0].m_status = -1; + return; + } + p_exWorkSpace->m_diskRequests[0].m_status = wid; + p_exWorkSpace->m_callback = [m_freeWorkSpaceIds = m_freeWorkSpaceIds, wid] () { + if (m_freeWorkSpaceIds) m_freeWorkSpaceIds->push(wid); + }; + } + } + + ErrorCode AsyncAppend(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index, SizeType headID, int appendNum, std::string& appendPosting, int reassignThreshold = 0) + { + if (m_asyncAppendQueue.size() >= m_opt->m_asyncAppendQueueSize) { + std::lock_guard lock(m_asyncAppendLock); + if (m_asyncAppendQueue.size() < m_opt->m_asyncAppendQueueSize) { + m_asyncAppendQueue.push(AppendPair(p_index->GetPriorityID(headID), headID, appendPosting)); + return ErrorCode::Success; + } + + AppendPair workPair; + ErrorCode ret; + while (m_asyncAppendQueue.try_pop(workPair)) { + if ((ret = Append(p_exWorkSpace, p_index, workPair.headID, 1, workPair.posting, reassignThreshold)) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncAppend: Append failed in async queue processing, headID: %d\n", workPair.headID); + return ret; + } + } + } else { + m_asyncAppendQueue.push(AppendPair(p_index->GetPriorityID(headID), headID, appendPosting)); + } + return ErrorCode::Success; + } + + ErrorCode Append(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index, SizeType headID, int appendNum, std::string& appendPosting, int reassignThreshold = 0) + { + auto appendBegin = std::chrono::high_resolution_clock::now(); + if (appendPosting.empty()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error! 
empty append posting!\n"); + } + + if (appendNum == 0) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Error!, headID :%d, appendNum:%d\n", headID, appendNum); + } + + checkDeleted: + if (!p_index->ContainSample(headID)) { + for (int i = 0; i < appendNum; i++) + { + uint32_t idx = i * m_vectorInfoSize; + SizeType VID = *(int*)(&appendPosting[idx]); + uint8_t version = *(uint8_t*)(&appendPosting[idx + sizeof(int)]); + auto vectorInfo = std::make_shared(appendPosting.c_str() + idx, m_vectorInfoSize); + if (m_versionMap->GetVersion(VID) == version) { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Head Miss To ReAssign: VID: %d, current version: %d\n", *(int*)(&appendPosting[idx]), version); + m_stat.m_headMiss++; + ReassignAsync(p_index, vectorInfo, headID); + } + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Head Miss Do Not To ReAssign: VID: %d, version: %d, current version: %d\n", *(int*)(&appendPosting[idx]), m_versionMap->GetVersion(*(int*)(&appendPosting[idx])), version); + } + return ErrorCode::Success; + } + double appendIOSeconds = 0; + { + //std::shared_lock lock(m_rwLocks[headID]); //ROCKSDB + std::unique_lock lock(m_rwLocks[headID]); //SPDK + if (!p_index->ContainSample(headID)) { + lock.unlock(); + goto checkDeleted; + } + if (m_postingSizes.GetSize(headID) + appendNum > (m_postingSizeLimit + m_bufferSizeLimit)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "After appending, the number of vectors exceeds the postingsize + buffersize (%d + %d)! 
Do split now...\n", m_postingSizeLimit, m_bufferSizeLimit); + Split(p_exWorkSpace, p_index, headID, !m_opt->m_disableReassign, false, false); + lock.unlock(); + goto checkDeleted; + } + + ErrorCode ret; + auto appendIOBegin = std::chrono::high_resolution_clock::now(); + *m_checkSums[headID] = + m_checkSum.AppendChecksum(*m_checkSums[headID], appendPosting.c_str(), (int)(appendPosting.size())); + if ((ret=db->Merge(headID, appendPosting, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Merge failed! Posting Size:%d, limit: %d\n", m_postingSizes.GetSize(headID), m_postingSizeLimit); + GetDBStats(); + return ret; + } + auto appendIOEnd = std::chrono::high_resolution_clock::now(); + appendIOSeconds = std::chrono::duration_cast(appendIOEnd - appendIOBegin).count(); + m_postingSizes.IncSize(headID, appendNum); + if (m_opt->m_consistencyCheck && (ret = db->Check(headID, m_postingSizes.GetSize(headID) * m_vectorInfoSize, nullptr)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Append: Check failed after Merge %d, append %d vectors with size %d\n", headID, appendNum, (int)(appendPosting.size())); + return ret; + } + } + if (m_postingSizes.GetSize(headID) > (m_postingSizeLimit + reassignThreshold)) { + // SizeType VID = *(int*)(&appendPosting[0]); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Split Triggered by inserting VID: %d, reAssign: %d\n", VID, reassignThreshold); + // GetDBStats(); + // if (m_postingSizes.GetSize(headID) > 120) { + // GetDBStats(); + // } + if (!reassignThreshold) SplitAsync(p_index, headID); + else Split(p_exWorkSpace, p_index, headID, !m_opt->m_disableReassign); + // SplitAsync(p_index, headID); + } + auto appendEnd = std::chrono::high_resolution_clock::now(); + double elapsedMSeconds = std::chrono::duration_cast(appendEnd - appendBegin).count(); + if (!reassignThreshold) { + m_stat.m_appendTaskNum++; + m_stat.m_appendIOCost += appendIOSeconds; + 
m_stat.m_appendCost += elapsedMSeconds; + } + // } else { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ReAssign Append To: %d\n", headID); + // } + return ErrorCode::Success; + } + + ErrorCode Reassign(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index, std::shared_ptr vectorInfo, SizeType HeadPrev) + { + SizeType VID = *((SizeType*)vectorInfo->c_str()); + uint8_t version = *((uint8_t*)(vectorInfo->c_str() + sizeof(VID))); + // return; + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ReassignID: %d, version: %d, current version: %d, HeadPrev: %d\n", VID, version, m_versionMap->GetVersion(VID), HeadPrev); + if (m_versionMap->Deleted(VID) || m_versionMap->GetVersion(VID) != version) { + return ErrorCode::Success; + } + auto reassignBegin = std::chrono::high_resolution_clock::now(); + + m_stat.m_reAssignNum++; + + auto selectBegin = std::chrono::high_resolution_clock::now(); + std::vector selections(static_cast(m_opt->m_replicaCount)); + int replicaCount; + bool isNeedReassign = RNGSelection(selections, (ValueType*)(vectorInfo->c_str() + m_metaDataSize), p_index, VID, replicaCount, HeadPrev); + auto selectEnd = std::chrono::high_resolution_clock::now(); + auto elapsedMSeconds = std::chrono::duration_cast(selectEnd - selectBegin).count(); + m_stat.m_selectCost += elapsedMSeconds; + + auto reassignAppendBegin = std::chrono::high_resolution_clock::now(); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Need ReAssign\n"); + if (isNeedReassign && m_versionMap->GetVersion(VID) == version) { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Update Version: VID: %d, version: %d, current version: %d\n", VID, version, m_versionMap.GetVersion(VID)); + m_versionMap->IncVersion(VID, &version); + (*vectorInfo)[sizeof(VID)] = version; + + //LOG(Helper::LogLevel::LL_Info, "Reassign: oldVID:%d, replicaCount:%d, candidateNum:%d, dist0:%f\n", oldVID, replicaCount, i, selections[0].distance); + for (int i = 0; i < replicaCount && m_versionMap->GetVersion(VID) == version; i++) { + 
//LOG(Helper::LogLevel::LL_Info, "Reassign: headID :%d, oldVID:%d, newVID:%d, posting length: %d, dist: %f, string size: %d\n", headID, oldVID, VID, m_postingSizes[headID].load(), selections[i].distance, newPart.size()); + ErrorCode tmp = Append(p_exWorkSpace, p_index, selections[i].node, 1, *vectorInfo, 3); + if (ErrorCode::Success != tmp) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Head Miss: VID: %d, current version: %d, another re-assign\n", VID, version); + return tmp; + } + } + } + auto reassignAppendEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(reassignAppendEnd - reassignAppendBegin).count(); + m_stat.m_reAssignAppendCost += elapsedMSeconds; + + auto reassignEnd = std::chrono::high_resolution_clock::now(); + elapsedMSeconds = std::chrono::duration_cast(reassignEnd - reassignBegin).count(); + m_stat.m_reAssignCost += elapsedMSeconds; + return ErrorCode::Success; + } + + bool LoadIndex(Options& p_opt, COMMON::VersionLabel& p_versionMap, COMMON::Dataset& p_vectorTranslateMap, std::shared_ptr m_index) override { + m_versionMap = &p_versionMap; + m_opt = &p_opt; + m_vectorTranslateMap = &p_vectorTranslateMap; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "DataBlockSize: %d, Capacity: %d\n", m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + std::string versionmapPath = m_opt->m_indexDirectory + FolderSep + m_opt->m_deleteIDFile; + std::string postingSizePath = m_opt->m_indexDirectory + FolderSep + m_opt->m_ssdInfoFile; + std::string checksumPath = m_opt->m_indexDirectory + FolderSep + m_opt->m_checksumFile; + if (m_opt->m_recovery) { + versionmapPath = m_opt->m_persistentBufferPath + FolderSep + m_opt->m_deleteIDFile; + postingSizePath = m_opt->m_persistentBufferPath + FolderSep + m_opt->m_ssdInfoFile; + checksumPath = m_opt->m_persistentBufferPath + FolderSep + m_opt->m_checksumFile; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: Loading version map\n"); + m_versionMap->Load(versionmapPath, 
m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: Loading posting size\n"); + m_postingSizes.Load(postingSizePath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: Loading posting checksum\n"); + m_checkSums.Load(checksumPath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: Current vector num: %d.\n", m_versionMap->Count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery:Current posting num: %d.\n", m_postingSizes.GetPostingNum()); + } + else if (m_opt->m_storage == Storage::ROCKSDBIO) { + m_versionMap->Load(versionmapPath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + m_postingSizes.Load(postingSizePath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + m_checkSums.Load(checksumPath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Current vector num: %d.\n", m_versionMap->Count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Current posting num: %d.\n", m_postingSizes.GetPostingNum()); + } else if (m_opt->m_storage == Storage::SPDKIO || m_opt->m_storage == Storage::FILEIO) { + if (fileexists((m_opt->m_indexDirectory + FolderSep + m_opt->m_ssdIndex).c_str())) { + m_versionMap->Initialize(m_opt->m_vectorSize, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Copying data from static to SPDK\n"); + std::shared_ptr storeExtraSearcher; + storeExtraSearcher.reset(new ExtraStaticSearcher()); + if (!storeExtraSearcher->LoadIndex(*m_opt, *m_versionMap, p_vectorTranslateMap, m_index)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Load Static Index Initialize Error\n"); + return false; + } + int totalPostingNum = m_index->GetNumSamples(); + + m_postingSizes.Initialize((SizeType)(totalPostingNum), m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + 
m_checkSums.Initialize((SizeType)(totalPostingNum), 1, m_opt->m_datasetRowsInBlock, + m_opt->m_datasetCapacity); + + std::vector threads; + std::atomic_size_t vectorsSent(0); + ErrorCode ret = ErrorCode::Success; + auto func = [&]() { + ExtraWorkSpace workSpace; + InitWorkSpace(&workSpace); + size_t index = 0; + while (true) + { + index = vectorsSent.fetch_add(1); + if (index < totalPostingNum) + { + + if ((index & ((1 << 14) - 1)) == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Copy to SPDK: Sent %.2lf%%...\n", + index * 100.0 / totalPostingNum); + } + std::string tempPosting; + if (storeExtraSearcher->GetWritePosting(&workSpace, index, tempPosting) != + ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Static Index Read Posting fail\n"); + ret = ErrorCode::Fail; + return; + } + int vectorNum = (int)(tempPosting.size() / (m_vectorInfoSize - sizeof(uint8_t))); + + if (vectorNum > m_postingSizeLimit) + vectorNum = m_postingSizeLimit; + auto *postingP = reinterpret_cast(&tempPosting.front()); + std::string newPosting(m_vectorInfoSize * vectorNum, '\0'); + char *ptr = (char *)(newPosting.c_str()); + for (int j = 0; j < vectorNum; ++j, ptr += m_vectorInfoSize) + { + char *vectorInfo = postingP + j * (m_vectorInfoSize - sizeof(uint8_t)); + int VID = *(reinterpret_cast(vectorInfo)); + uint8_t version = m_versionMap->GetVersion(VID); + memcpy(ptr, &VID, sizeof(int)); + memcpy(ptr + sizeof(int), &version, sizeof(uint8_t)); + memcpy(ptr + sizeof(int) + sizeof(uint8_t), vectorInfo + sizeof(int), + m_vectorInfoSize - sizeof(uint8_t) - sizeof(int)); + } + if (GetWritePosting(&workSpace, index, newPosting, true) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Index Write Posting fail\n"); + ret = ErrorCode::Fail; + return; + } + } + else + { + return; + } + } + }; + for (int j = 0; j < m_opt->m_iSSDNumberOfThreads; j++) { threads.emplace_back(func); } + for (auto& thread : threads) { thread.join(); } + if (ret != ErrorCode::Success) 
+ return false; + } else { + m_versionMap->Load(versionmapPath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + m_postingSizes.Load(postingSizePath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + m_checkSums.Load(checksumPath, m_opt->m_datasetRowsInBlock, m_opt->m_datasetCapacity); + } + } + if (m_opt->m_update) { + if (m_splitThreadPool == nullptr) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: initialize thread pools, append: %d, reassign %d\n", m_opt->m_appendThreadNum, m_opt->m_reassignThreadNum); + + m_splitThreadPool = std::make_shared(); + m_splitThreadPool->initSPDK(m_opt->m_appendThreadNum, this); + //m_reassignThreadPool = std::make_shared(); + //m_reassignThreadPool->initSPDK(m_opt->m_reassignThreadNum, this); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: finish initialization\n"); + } + + if (m_opt->m_enableWAL && !m_opt->m_persistentBufferPath.empty()) { + std::string p_persistenWAL = m_opt->m_persistentBufferPath + FolderSep + "WAL"; + std::shared_ptr pdb; +#ifdef ROCKSDB + pdb.reset(new RocksDBIO(p_persistenWAL.c_str(), false, false)); + m_wal.reset(new PersistentBuffer(pdb)); +#else + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPFresh: Wal only support RocksDB! 
Please use -DROCKSDB when doing cmake.\n"); + return false; +#endif + } + } + + /** recover the previous WAL **/ + if (m_opt->m_recovery && m_opt->m_enableWAL && m_wal) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: WAL\n"); + std::string assignment; + int countAssignment = 0; + if (!m_wal->StartToScan(assignment)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: No log\n"); + return true; + } + ExtraWorkSpace workSpace; + InitWorkSpace(&workSpace); + do { + countAssignment++; + if (countAssignment % 10000 == 0) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Process %d logs\n", countAssignment); + char* ptr = (char*)(assignment.c_str()); + SizeType VID = *(reinterpret_cast(ptr)); + if (assignment.size() == m_vectorInfoSize) { + if (VID >= m_versionMap->GetVectorNum()) { + if (m_versionMap->AddBatch(VID - m_versionMap->GetVectorNum() + 1) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "MemoryOverFlow: VID: %d, Map Size:%d\n", VID, m_versionMap->BufferSize()); + return false; + } + } + std::shared_ptr vectorSet; + vectorSet.reset(new BasicVectorSet(ByteArray((std::uint8_t*)ptr + sizeof(SizeType) + sizeof(uint8_t), sizeof(ValueType) * 1 * m_opt->m_dim, false), + GetEnumValueType(), m_opt->m_dim, 1)); + AddIndex(&workSpace, vectorSet, m_index, VID); + } else { + m_versionMap->Delete(VID); + } + } while (m_wal->NextToScan(assignment)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: No more to repeat, wait for rebalance\n"); + while(!AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + } + return true; + } + bool ValidatePostings( + std::vector &pids, std::vector> &postings) + { + if (!m_opt->m_checksumInRead) return true; + + ErrorCode ret; + for (int i = 0; i < pids.size(); i++) + { + if (!m_checkSum.ValidateChecksum((const char *)(postings[i].GetBuffer()), + postings[i].GetAvailableSize(), *m_checkSums[pids[i]])) + { + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, + "ValidatePostings fail: posting 
id:%d, required size:%d, buffer size:%d, checksum:%d\n", + pids[i], (int)(m_postingSizes.GetSize(pids[i]) * m_vectorInfoSize), (int)(postings[i].GetAvailableSize()), (int)(*m_checkSums[pids[i]])); + //return false; + } + } + return true; + } + + bool ValidatePostings(std::vector &pids, std::vector &postings) + { + if (!m_opt->m_checksumInRead) + return true; + + ErrorCode ret; + for (int i = 0; i < pids.size(); i++) + { + if (!m_checkSum.ValidateChecksum(postings[i].c_str(), + postings[i].size(), *m_checkSums[pids[i]])) + { + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, + "ValidatePostings fail: posting id:%d, required size:%d, buffer size:%d, checksum:%d\n", + pids[i], (int)(m_postingSizes.GetSize(pids[i]) * m_vectorInfoSize), + (int)(postings[i].size()), (int)(*m_checkSums[pids[i]])); + // return false; + } + } + return true; + } + + virtual ErrorCode SearchIndex(ExtraWorkSpace* p_exWorkSpace, + QueryResult& p_queryResults, + std::shared_ptr p_index, + SearchStats* p_stats, std::set* truth, std::map>* found) override + { + if (p_stats) p_stats->m_exSetUpLatency = 0; + + COMMON::QueryResultSet& queryResults = *((COMMON::QueryResultSet*) & p_queryResults); + if (queryResults.GetResult(0)->VID != -1) + { + int head = 0; + for (int i = 0; i < queryResults.GetResultNum(); ++i) + { + SPTAG::BasicResult* ri = queryResults.GetResult(i); + if (ri->VID != -1 && !m_versionMap->Deleted(ri->VID) && !p_exWorkSpace->m_deduper.CheckAndSet(ri->VID)) + { + if (head != i) + { + SPTAG::BasicResult* rhead = queryResults.GetResult(head); + *rhead = *ri; + ri->VID = -1; + ri->Dist = MaxDist; + } + + ++head; + } + else + { + ri->VID = -1; + ri->Dist = MaxDist; + } + } + } + + int diskRead = 0; + int diskIO = 0; + int listElements = 0; + + double compLatency = 0; + double readLatency = 0; + std::chrono::microseconds remainLimit; + if (p_stats) remainLimit = m_hardLatencyLimit - std::chrono::microseconds((int)p_stats->m_totalLatency); + else remainLimit = m_hardLatencyLimit; + + auto 
readStart = std::chrono::high_resolution_clock::now(); + if (db->MultiGet(p_exWorkSpace->m_postingIDs, p_exWorkSpace->m_pageBuffers, remainLimit, &(p_exWorkSpace->m_diskRequests)) != ErrorCode::Success || + !ValidatePostings(p_exWorkSpace->m_postingIDs, p_exWorkSpace->m_pageBuffers)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[SearchIndex] read postings fail!\n"); + return ErrorCode::DiskIOFail; + } + auto readEnd = std::chrono::high_resolution_clock::now(); + readLatency += ((double)std::chrono::duration_cast(readEnd - readStart).count()); + + const auto postingListCount = static_cast(p_exWorkSpace->m_postingIDs.size()); + for (uint32_t pi = 0; pi < postingListCount; ++pi) { + auto curPostingID = p_exWorkSpace->m_postingIDs[pi]; + auto& buffer = (p_exWorkSpace->m_pageBuffers[pi]); + char* p_postingListFullData = (char*)(buffer.GetBuffer()); + int vectorNum = (int)(buffer.GetAvailableSize() / m_vectorInfoSize); + + diskIO += ((buffer.GetAvailableSize() + PageSize - 1) >> PageSizeEx); + diskRead += (int)(buffer.GetAvailableSize()); + + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "DEBUG: postingList %d size:%d m_vectorInfoSize:%d vectorNum:%d\n", pi, (int)(postingList.size()), m_vectorInfoSize, vectorNum); + int realNum = vectorNum; + listElements += vectorNum; + auto compStart = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < vectorNum; i++) { + char* vectorInfo = p_postingListFullData + i * m_vectorInfoSize; + int vectorID = *(reinterpret_cast(vectorInfo)); + + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "DEBUG: vectorID:%d\n", vectorID); + if (m_versionMap->Deleted(vectorID)) { + realNum--; + listElements--; + continue; + } + if(p_exWorkSpace->m_deduper.CheckAndSet(vectorID)) { + listElements--; + continue; + } + auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), vectorInfo + m_metaDataSize); + queryResults.AddPoint(vectorID, distance2leaf); + } + auto compEnd = std::chrono::high_resolution_clock::now(); + if 
(realNum <= m_mergeThreshold) MergeAsync(p_index.get(), curPostingID); // TODO: Control merge + + compLatency += ((double)std::chrono::duration_cast(compEnd - compStart).count()); + + if (truth) { + for (int i = 0; i < vectorNum; ++i) { + char* vectorInfo = p_postingListFullData + i * m_vectorInfoSize; + int vectorID = *(reinterpret_cast(vectorInfo)); + if (truth->count(vectorID) != 0) + (*found)[curPostingID].insert(vectorID); + } + } + } + + if (p_stats) + { + p_stats->m_compLatency = compLatency / 1000; + p_stats->m_diskReadLatency = readLatency / 1000; + p_stats->m_totalListElementsCount = listElements; + p_stats->m_diskIOCount = diskIO; + p_stats->m_diskAccessCount = diskRead / 1024; + } + queryResults.SetScanned(listElements); + return ErrorCode::Success; + } + + virtual ErrorCode SearchIndexWithoutParsing(ExtraWorkSpace* p_exWorkSpace) + { + int retry = 0; + ErrorCode ret = ErrorCode::Undefined; + while (retry < 2 && ret != ErrorCode::Success) + { + ret = db->MultiGet(p_exWorkSpace->m_postingIDs, p_exWorkSpace->m_pageBuffers, m_hardLatencyLimit, + &(p_exWorkSpace->m_diskRequests)); + retry++; + } + if (ret == ErrorCode::Success && + !ValidatePostings(p_exWorkSpace->m_postingIDs, p_exWorkSpace->m_pageBuffers)) + { + return ErrorCode::DiskIOFail; + } + return ret; + } + + virtual ErrorCode SearchNextInPosting(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_headResults, + QueryResult& p_queryResults, + std::shared_ptr& p_index, const VectorIndex* p_spann) + { + COMMON::QueryResultSet& headResults = *((COMMON::QueryResultSet*) & p_headResults); + COMMON::QueryResultSet& queryResults = *((COMMON::QueryResultSet*) & p_queryResults); + bool foundResult = false; + BasicResult* head = headResults.GetResult(p_exWorkSpace->m_ri); + while (!foundResult && p_exWorkSpace->m_pi < p_exWorkSpace->m_postingIDs.size()) { + if (head && head->VID != -1 && p_exWorkSpace->m_ri <= p_exWorkSpace->m_pi) { + if (!m_versionMap->Deleted(head->VID) && 
!p_exWorkSpace->m_deduper.CheckAndSet(head->VID) && + (p_exWorkSpace->m_filterFunc == nullptr || p_exWorkSpace->m_filterFunc(p_spann->GetMetadata(head->VID)))) { + queryResults.AddPoint(head->VID, head->Dist); + foundResult = true; + } + head = headResults.GetResult(++p_exWorkSpace->m_ri); + continue; + } + auto& buffer = (p_exWorkSpace->m_pageBuffers[p_exWorkSpace->m_pi]); + char* p_postingListFullData = (char*)(buffer.GetBuffer()); + int vectorNum = (int)(buffer.GetAvailableSize() / m_vectorInfoSize); + while (p_exWorkSpace->m_offset < vectorNum) { + char* vectorInfo = p_postingListFullData + p_exWorkSpace->m_offset * m_vectorInfoSize; + p_exWorkSpace->m_offset++; + + int vectorID = *(reinterpret_cast(vectorInfo)); + if (vectorID >= m_versionMap->Count()) return ErrorCode::Key_OverFlow; + if (m_versionMap->Deleted(vectorID)) continue; + if (p_exWorkSpace->m_deduper.CheckAndSet(vectorID)) continue; + if (p_exWorkSpace->m_filterFunc != nullptr && !p_exWorkSpace->m_filterFunc(p_spann->GetMetadata(vectorID))) continue; + + auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), vectorInfo + m_metaDataSize); + queryResults.AddPoint(vectorID, distance2leaf); + foundResult = true; + break; + } + if (p_exWorkSpace->m_offset == vectorNum) { + p_exWorkSpace->m_pi++; + p_exWorkSpace->m_offset = 0; + } + } + while (!foundResult && head && head->VID != -1) { + if (!m_versionMap->Deleted(head->VID) && !p_exWorkSpace->m_deduper.CheckAndSet(head->VID) && + (p_exWorkSpace->m_filterFunc == nullptr || p_exWorkSpace->m_filterFunc(p_spann->GetMetadata(head->VID)))) { + queryResults.AddPoint(head->VID, head->Dist); + foundResult = true; + } + head = headResults.GetResult(++p_exWorkSpace->m_ri); + } + if (foundResult) p_queryResults.SetScanned(p_queryResults.GetScanned() + 1); + return (foundResult) ? 
ErrorCode::Success : ErrorCode::VectorNotFound;
    }

    // Resumable "next result" entry point for iterative search: on the first call after a
    // new query (m_loadPosting set) it issues the raw posting-list load, resets the three
    // iteration cursors (head-result index m_ri, page index m_pi, in-page offset m_offset),
    // then delegates to SearchNextInPosting to surface one more result.
    // NOTE(review): template arguments (e.g. on std::shared_ptr) appear to have been
    // stripped by text extraction throughout this file — restore before compiling.
    virtual ErrorCode SearchIterativeNext(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_headResults,
        QueryResult& p_query,
        std::shared_ptr p_index, const VectorIndex* p_spann)
    {
        if (p_exWorkSpace->m_loadPosting) {
            ErrorCode ret = SearchIndexWithoutParsing(p_exWorkSpace);
            if (ret != ErrorCode::Success) return ret;
            // Fresh query: rewind all iteration state before consuming results.
            p_exWorkSpace->m_ri = 0;
            p_exWorkSpace->m_pi = 0;
            p_exWorkSpace->m_offset = 0;
            p_exWorkSpace->m_loadPosting = false;
        }

        return SearchNextInPosting(p_exWorkSpace, p_headResults, p_query, p_index, p_spann);
    }

    // Serialize the full on-disk payload of one posting list: for each selection edge
    // belonging to postingListId, append (vector ID, vector data). With
    // p_enableDeltaEncoding the stored vector is (vector - headVector); with
    // p_enablePostingListRearrange the layout becomes [all vectors][all IDs] instead of
    // interleaved [ID,vector] pairs. Throws std::runtime_error on a selection mismatch.
    std::string GetPostingListFullData(
        int postingListId,
        size_t p_postingListSize,
        Selection& p_selections,
        std::shared_ptr p_fullVectors,
        bool p_enableDeltaEncoding = false,
        bool p_enablePostingListRearrange = false,
        const ValueType* headVector = nullptr)
    {
        std::string postingListFullData("");
        std::string vectors("");
        std::string vectorIDs("");
        size_t selectIdx = p_selections.lower_bound(postingListId);
        // iterate over all the vectors in the posting list
        for (int i = 0; i < p_postingListSize; ++i)
        {
            if (p_selections[selectIdx].node != postingListId)
            {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Selection ID NOT MATCH! node:%d offset:%zu\n", postingListId, selectIdx);
                throw std::runtime_error("Selection ID mismatch");
            }
            std::string vectorID("");
            std::string vector("");

            int vid = p_selections[selectIdx++].tonode;
            vectorID.append(reinterpret_cast(&vid), sizeof(int));

            ValueType* p_vector = reinterpret_cast(p_fullVectors->GetVector(vid));
            if (p_enableDeltaEncoding)
            {
                // Store the difference against the posting's head vector.
                DimensionType n = p_fullVectors->Dimension();
                std::vector p_vector_delta(n);
                for (auto j = 0; j < n; j++)
                {
                    p_vector_delta[j] = p_vector[j] - headVector[j];
                }
                vector.append(reinterpret_cast(&p_vector_delta[0]), p_fullVectors->PerVectorDataSize());
            }
            else
            {
                vector.append(reinterpret_cast(p_vector), p_fullVectors->PerVectorDataSize());
            }

            if (p_enablePostingListRearrange)
            {
                vectorIDs += vectorID;
                vectors += vector;
            }
            else
            {
                postingListFullData += (vectorID + vector);
            }
        }
        if (p_enablePostingListRearrange)
        {
            // Rearranged layout: vectors first, IDs appended at the end.
            return vectors + vectorIDs;
        }
        return postingListFullData;
    }

    // Build the SSD-resident (tail) index: maps head vector IDs, searches replica
    // candidates batch by batch, cuts oversized posting lists, then writes all posting
    // lists to the KV store along with version map, posting sizes and checksums.
    bool BuildIndex(std::shared_ptr& p_reader, std::shared_ptr p_headIndex, Options& p_opt, COMMON::VersionLabel& p_versionMap, COMMON::Dataset& p_vectorTranslateMap, SizeType upperBound = -1) override {
        m_versionMap = &p_versionMap;
        m_vectorTranslateMap = &p_vectorTranslateMap;
        m_opt = &p_opt;

        int numThreads = m_opt->m_iSSDNumberOfThreads;
        int candidateNum = m_opt->m_internalResultNum;
        std::unordered_map headVectorIDS;
        if (m_opt->m_headIDFile.empty()) {
            SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Not found VectorIDTranslate!\n");
            return false;
        }

        // Map each head vector's global VID -> its row in the head index.
        for (int i = 0; i < p_vectorTranslateMap.R(); i++)
        {
            headVectorIDS[static_cast(*(p_vectorTranslateMap[i]))] = i;
        }
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loaded %u Vector IDs\n", static_cast(headVectorIDS.size()));

        SizeType fullCount = 0;
        {
            auto fullVectors = p_reader->GetVectorSet();
            fullCount = fullVectors->Count();
            // NOTE(review): m_vectorInfoSize is computed here using m_metaDataSize, but
            // m_metaDataSize is (re)assigned only a few lines below — verify ordering.
            m_vectorInfoSize = fullVectors->PerVectorDataSize() + m_metaDataSize;
        }
        if
(upperBound > 0) fullCount = upperBound;

        // Per-vector metadata layout: vector ID (int) + version (uint8_t).
        // m_metaDataSize = sizeof(int) + sizeof(uint8_t) + sizeof(float);
        m_metaDataSize = sizeof(int) + sizeof(uint8_t);

        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build SSD Index.\n");

        // One selection slot per (vector, replica) pair; spills to tmpdir when batched.
        Selection selections(static_cast(fullCount) * m_opt->m_replicaCount, m_opt->m_tmpdir);
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Full vector count:%d Edge bytes:%llu selection size:%zu, capacity size:%zu\n", fullCount, sizeof(Edge), selections.m_selections.size(), selections.m_selections.capacity());
        std::vector replicaCount(fullCount);
        std::vector postingListSize(p_headIndex->GetNumSamples());
        for (auto& pls : postingListSize) pls = 0;
        std::unordered_set emptySet;
        SizeType batchSize = (fullCount + m_opt->m_batches - 1) / m_opt->m_batches;

        auto t1 = std::chrono::high_resolution_clock::now();
        if (p_opt.m_batches > 1)
        {
            if (selections.SaveBatch() != ErrorCode::Success)
            {
                return false;
            }
        }
        {
            SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Preparation done, start candidate searching.\n");
            SizeType sampleSize = m_opt->m_samples;
            std::vector samples(sampleSize, 0);
            // Process the full vector set in m_batches slices to bound memory.
            for (int i = 0; i < m_opt->m_batches; i++) {
                SizeType start = i * batchSize;
                SizeType end = min(start + batchSize, fullCount);
                auto fullVectors = p_reader->GetVectorSet(start, end);
                if (m_opt->m_distCalcMethod == DistCalcMethod::Cosine && !p_reader->IsNormalized()) fullVectors->Normalize(m_opt->m_iSSDNumberOfThreads);

                if (p_opt.m_batches > 1) {
                    if (selections.LoadBatch(static_cast(start) * p_opt.m_replicaCount, static_cast(end) * p_opt.m_replicaCount) != ErrorCode::Success)
                    {
                        return false;
                    }
                    // emptySet holds batch-local indices of head vectors (skipped in RNG).
                    emptySet.clear();
                    for (auto& pair : headVectorIDS) {
                        if (pair.first >= start && pair.first < end) emptySet.insert(pair.first - start);
                    }
                }
                else {
                    for (auto& pair : headVectorIDS) {
                        emptySet.insert(pair.first);
                    }
                }

                // Sample non-head vectors to estimate head-index recall for this batch.
                int sampleNum = 0;
                for (int j = start; j < end && sampleNum < sampleSize; j++)
                {
                    if (headVectorIDS.count(j) == 0) samples[sampleNum++] = j - start;
                }

                float acc = 0;
                for (int j = 0; j < sampleNum; j++)
                {
                    COMMON::Utils::atomic_float_add(&acc, COMMON::TruthSet::CalculateRecall(p_headIndex.get(), fullVectors->GetVector(samples[j]), candidateNum));
                }
                acc = acc / sampleNum;
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Batch %d vector(%d,%d) loaded with %d vectors (%zu) HeadIndex acc @%d:%f.\n", i, start, end, fullVectors->Count(), selections.m_selections.size(), candidateNum, acc);

                // Fill this batch's replica candidates (RNG rule) into selections.
                p_headIndex->ApproximateRNG(fullVectors, emptySet, candidateNum, selections.m_selections.data(), m_opt->m_replicaCount, numThreads, m_opt->m_gpuSSDNumTrees, m_opt->m_gpuSSDLeafSize, m_opt->m_rngFactor, m_opt->m_numGPUs);
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Batch %d finished!\n", i);

                // Accumulate posting-list sizes and per-vector replica counts.
                for (SizeType j = start; j < end; j++) {
                    replicaCount[j] = 0;
                    size_t vecOffset = j * (size_t)m_opt->m_replicaCount;
                    if (headVectorIDS.count(j) == 0) {
                        for (int resNum = 0; resNum < m_opt->m_replicaCount && selections[vecOffset + resNum].node != INT_MAX; resNum++) {
                            ++postingListSize[selections[vecOffset + resNum].node];
                            selections[vecOffset + resNum].tonode = j;
                            ++replicaCount[j];
                        }
                    } else if (!p_opt.m_excludehead) {
                        // A head vector gets a single replica in its own posting list.
                        selections[vecOffset].node = headVectorIDS[j];
                        selections[vecOffset].tonode = j;
                        ++postingListSize[selections[vecOffset].node];
                        ++replicaCount[j];
                    }
                }

                if (p_opt.m_batches > 1)
                {
                    if (selections.SaveBatch() != ErrorCode::Success)
                    {
                        return false;
                    }
                }
            }
        }
        auto t2 = std::chrono::high_resolution_clock::now();
        // NOTE(review): the duration_cast unit was lost in extraction; "/ 60.0" suggests
        // the original cast was to seconds (printed as minutes) — confirm.
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Searching replicas ended. Search Time: %.2lf mins\n", ((double)std::chrono::duration_cast(t2 - t1).count()) / 60.0);

        if (p_opt.m_batches > 1)
        {
            if (selections.LoadBatch(0, static_cast(fullCount) * p_opt.m_replicaCount) != ErrorCode::Success)
            {
                return false;
            }
        }

        // Sort results either in CPU or GPU
        VectorIndex::SortSelections(&selections.m_selections);

        auto t3 = std::chrono::high_resolution_clock::now();
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Time to sort selections:%.2lf sec.\n", ((double)std::chrono::duration_cast(t3 - t2).count()) + ((double)std::chrono::duration_cast(t3 - t2).count()) / 1000);
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Posting size limit: %d\n", m_postingSizeLimit);
        {
            // Histogram of replica counts (non-head vectors) before the posting cut.
            std::vector replicaCountDist(m_opt->m_replicaCount + 1, 0);
            for (int i = 0; i < replicaCount.size(); ++i)
            {
                if (headVectorIDS.count(i) > 0) continue;
                ++replicaCountDist[replicaCount[i]];
            }

            SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Before Posting Cut:\n");
            for (int i = 0; i < replicaCountDist.size(); ++i)
            {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Replica Count Dist: %d, %d\n", i, replicaCountDist[i]);
            }
        }

        // Cut posting lists to relaxLimit in parallel; vectors losing their last replica
        // are collected into zeroReplicaSet for later re-insertion.
        Helper::Concurrent::ConcurrentSet zeroReplicaSet;
        std::atomic_int64_t originalSize(0), relaxSize(0);
        {
            std::vector mythreads;
            mythreads.reserve(m_opt->m_iSSDNumberOfThreads);
            std::atomic_size_t sent(0);
            int relaxLimit = m_postingSizeLimit + m_bufferSizeLimit;
            for (int tid = 0; tid < m_opt->m_iSSDNumberOfThreads; tid++)
            {
                mythreads.emplace_back([&, tid]() {
                    size_t i = 0;
                    while (true)
                    {
                        // Work-stealing over posting-list indices via atomic counter.
                        i = sent.fetch_add(1);
                        if (i < postingListSize.size())
                        {
                            if (postingListSize[i] <= m_postingSizeLimit)
                                originalSize += postingListSize[i];
                            else
                                originalSize += m_postingSizeLimit;

                            if (postingListSize[i] <= relaxLimit)
                            {
                                relaxSize += postingListSize[i];
                                continue;
                            }
                            relaxSize += relaxLimit;

                            // Locate this posting's first selection (sorted by node).
                            std::size_t selectIdx =
                                std::lower_bound(selections.m_selections.begin(), selections.m_selections.end(), i,
Selection::g_edgeComparer) -
                                selections.m_selections.begin();

                            // Drop replicas beyond relaxLimit; track vectors that lose
                            // their last replica.
                            for (size_t dropID = relaxLimit;
                                 dropID < postingListSize[i]; ++dropID)
                            {
                                int tonode = selections.m_selections[selectIdx + dropID].tonode;
                                --replicaCount[tonode];
                                if (replicaCount[tonode] == 0)
                                {
                                    zeroReplicaSet.insert(tonode);
                                }
                            }
                            postingListSize[i] = relaxLimit;
                        }
                        else
                        {
                            return;
                        }
                    }
                });
            }
            for (auto &t : mythreads)
            {
                t.join();
            }
            mythreads.clear();
        }
        {
            // Histogram of replica counts after the posting cut.
            std::vector replicaCountDist(m_opt->m_replicaCount + 1, 0);
            for (int i = 0; i < replicaCount.size(); ++i)
            {
                if (headVectorIDS.count(i) > 0) continue;
                ++replicaCountDist[replicaCount[i]];
            }

            SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After Posting Cut:\n");
            for (int i = 0; i < replicaCountDist.size(); ++i)
            {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Replica Count Dist: %d, %d\n", i, replicaCountDist[i]);
            }
        }
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Posting cut original:%lld relax:%lld\n", originalSize.load(),
                     relaxSize.load());

        // if (m_opt->m_outputEmptyReplicaID)
        // {
        //     std::vector replicaCountDist(m_opt->m_replicaCount + 1, 0);
        //     auto ptr = SPTAG::f_createIO();
        //     if (ptr == nullptr || !ptr->Initialize("EmptyReplicaID.bin", std::ios::binary | std::ios::out)) {
        //         SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to create EmptyReplicaID.bin!\n");
        //         return false;
        //     }
        //     for (int i = 0; i < replicaCount.size(); ++i)
        //     {
        //         if (headVectorIDS.count(i) > 0) continue;

        //         ++replicaCountDist[replicaCount[i]];

        //         if (replicaCount[i] < 2)
        //         {
        //             long long vid = i;
        //             if (ptr->WriteBinary(sizeof(vid), reinterpret_cast(&vid)) != sizeof(vid)) {
        //                 SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failt to write EmptyReplicaID.bin!");
        //                 return false;
        //             }
        //         }
        //     }

        //     SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After Posting Cut:\n");
        //     for (int i = 0; i < replicaCountDist.size(); ++i)
        //     {
        //         SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Replica Count Dist: %d, %d\n", i, replicaCountDist[i]);
        //     }
        // }


        auto t4 = std::chrono::high_resolution_clock::now();
        SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Time to perform posting cut:%.2lf sec.\n", ((double)std::chrono::duration_cast(t4 - t3).count()) + ((double)std::chrono::duration_cast(t4 - t3).count()) / 1000);

        auto fullVectors = p_reader->GetVectorSet();
        if (m_opt->m_distCalcMethod == DistCalcMethod::Cosine && !p_reader->IsNormalized() && !p_headIndex->m_pQuantizer) fullVectors->Normalize(m_opt->m_iSSDNumberOfThreads);

        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: initialize versionMap\n");
        m_versionMap->Initialize(fullCount, p_headIndex->m_iDataBlockSize, p_headIndex->m_iDataCapacity);

        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: Writing values to DB\n");

        m_postingSizes.Initialize((SizeType)(postingListSize.size()), p_headIndex->m_iDataBlockSize,
                                  p_headIndex->m_iDataCapacity);
        for (int i = 0; i < postingListSize.size(); i++)
        {
            m_postingSizes.UpdateSize(i, postingListSize[i].load());
        }

        m_checkSums.Initialize((SizeType)(postingListSize.size()), 1, p_headIndex->m_iDataBlockSize,
                               p_headIndex->m_iDataCapacity);

        if (ErrorCode::Success != WriteDownAllPostingToDB(selections, fullVectors)) return false;

        // Re-insert vectors that lost every replica during the cut (unless zero
        // replicas are explicitly allowed).
        if (m_opt->m_update && !m_opt->m_allowZeroReplica && zeroReplicaSet.size() > 0)
        {
            SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: initialize thread pools, append: %d, reassign %d\n", m_opt->m_appendThreadNum, m_opt->m_reassignThreadNum);
            m_splitThreadPool = std::make_shared();
            m_splitThreadPool->initSPDK(m_opt->m_appendThreadNum, this);
            SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: finish initialization, zeroReplicaCount:%d\n", (int)(zeroReplicaSet.size()));

            ExtraWorkSpace workSpace;
            InitWorkSpace(&workSpace);
            for (SizeType it : zeroReplicaSet)
            {
                // Wrap the single raw vector (non-owning ByteArray) and re-add it.
                std::shared_ptr vectorSet(new BasicVectorSet(ByteArray((std::uint8_t*)fullVectors->GetVector(it), sizeof(ValueType) * m_opt->m_dim, false),
                                                             GetEnumValueType(), m_opt->m_dim, 1));
                if (AddIndex(&workSpace, vectorSet, p_headIndex, it) != ErrorCode::Success) {
                    SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to add index for zero replica ID: %d\n", it);
                    return false;
                }
            }
            // Drain background append/split jobs before persisting.
            while (!AllFinished())
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(20));
            }

            if (p_headIndex->SaveIndex(m_opt->m_indexDirectory + FolderSep + m_opt->m_headIndexFolder) != ErrorCode::Success) {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to save head index!\n");
                return false;
            }

            if (m_vectorTranslateMap->Save(m_opt->m_indexDirectory + FolderSep + m_opt->m_headIDFile) != ErrorCode::Success) {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to save vector ID translate map!\n");
                return false;
            }
        }

        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: Writing SSD Info and checkSum\n");
        m_postingSizes.Save(m_opt->m_indexDirectory + FolderSep + m_opt->m_ssdInfoFile);
        m_checkSums.Save(m_opt->m_indexDirectory + FolderSep + m_opt->m_checksumFile);

        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: save versionMap\n");
        m_versionMap->Save(m_opt->m_indexDirectory + FolderSep + m_opt->m_deleteIDFile);

        auto t5 = std::chrono::high_resolution_clock::now();
        double elapsedSeconds = std::chrono::duration_cast(t5 - t1).count();
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total used time: %.2lf minutes (about %.2lf hours).\n", elapsedSeconds / 60.0, elapsedSeconds / 3600.0);
        return true;
    }

    // Serialize every posting list (ID+version+vector per entry), compute its checksum,
    // and Put it into the KV store; parallelized over worker threads that claim posting
    // indices via an atomic counter. Returns the first error any worker hit.
    ErrorCode WriteDownAllPostingToDB(Selection& p_postingSelections, std::shared_ptr p_fullVectors) {

        std::vector threads;
        std::atomic_size_t vectorsSent(0);
        ErrorCode ret = ErrorCode::Success;
        auto func = [&]()
        {
            ExtraWorkSpace workSpace;
            InitWorkSpace(&workSpace);
            size_t index = 0;
            while (true)
            {
                index = vectorsSent.fetch_add(1);
                if (index < m_postingSizes.GetPostingNum()) {
                    std::string postinglist(m_vectorInfoSize * m_postingSizes.GetSize(index), '\0');
char* ptr = (char*)postinglist.c_str();
                    std::size_t selectIdx = p_postingSelections.lower_bound((int)index);
                    for (int j = 0; j < m_postingSizes.GetSize(index); ++j)
                    {
                        if (p_postingSelections[selectIdx].node != index) {
                            SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Selection ID NOT MATCH\n");
                            ret = ErrorCode::Fail;
                            return;
                        }
                        SizeType fullID = p_postingSelections[selectIdx++].tonode;
                        // if (id == 0) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ID: %d\n", fullID);
                        uint8_t version = m_versionMap->GetVersion(fullID);
                        // First Vector ID, then version, then Vector
                        Serialize(ptr, fullID, version, p_fullVectors->GetVector(fullID));
                        ptr += m_vectorInfoSize;
                    }
                    // Checksum the serialized posting before writing it out.
                    *m_checkSums[index] = m_checkSum.CalcChecksum(postinglist.c_str(), (int)(postinglist.size()));
                    ErrorCode tmp;
                    if ((tmp = db->Put(index, postinglist, MaxTimeout, &(workSpace.m_diskRequests))) !=
                        ErrorCode::Success)
                    {
                        SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[WriteDB] Put fail!\n");
                        ret = tmp;
                        return;
                    }
                    // Optional read-back consistency check right after the write.
                    if (m_opt->m_consistencyCheck && (tmp = db->Check(index, m_postingSizes.GetSize(index) * m_vectorInfoSize, nullptr)) !=
                        ErrorCode::Success)
                    {
                        SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "WriteDB: Check failed after Put %d\n", index);
                        ret = tmp;
                        return;
                    }
                }
                else
                {
                    return;
                }
            }
        };

        for (int j = 0; j < m_opt->m_iSSDNumberOfThreads; j++) { threads.emplace_back(func); }
        for (auto& thread : threads) { thread.join(); }
        return ret;
    }

    // Append each vector of p_vectorSet (global IDs begin..begin+Count-1) to the posting
    // lists chosen by RNG selection; optionally logs to the WAL first, then appends
    // synchronously or asynchronously depending on m_asyncAppendQueueSize.
    ErrorCode AddIndex(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr& p_vectorSet,
                       std::shared_ptr p_index, SizeType begin) override {

        for (int v = 0; v < p_vectorSet->Count(); v++) {
            SizeType VID = begin + v;
            std::vector selections(static_cast(m_opt->m_replicaCount));
            int replicaCount;
            RNGSelection(selections, (ValueType*)(p_vectorSet->GetVector(v)), p_index.get(), VID, replicaCount);

            uint8_t version = m_versionMap->GetVersion(VID);
            std::string appendPosting(m_vectorInfoSize, '\0');
            Serialize((char*)(appendPosting.c_str()), VID, version, p_vectorSet->GetVector(v));
            // Write-ahead log the assignment before touching posting lists.
            if (m_opt->m_enableWAL && m_wal) {
                m_wal->PutAssignment(appendPosting);
            }
            for (int i = 0; i < replicaCount; i++)
            {
                // AppendAsync(selections[i].node, 1, appendPosting_ptr);
                ErrorCode ret;
                if (m_opt->m_asyncAppendQueueSize > 0) {
                    if ((ret = AsyncAppend(p_exWorkSpace, p_index.get(), selections[i].node, 1, appendPosting)) != ErrorCode::Success)
                        return ret;
                } else {
                    if ((ret = Append(p_exWorkSpace, p_index.get(), selections[i].node, 1, appendPosting)) !=
                        ErrorCode::Success)
                        return ret;
                }
            }
        }
        return ErrorCode::Success;
    }

    // Tombstone-delete a vector: log to WAL (if enabled), then mark it deleted in the
    // version map. Returns VectorNotFound if it was already deleted/absent.
    ErrorCode DeleteIndex(SizeType p_id) override {
        if (m_opt->m_enableWAL && m_wal) {
            std::string assignment(sizeof(SizeType), '\0');
            memcpy((char*)assignment.c_str(), &p_id, sizeof(SizeType));
            m_wal->PutAssignment(assignment);
        }
        if (m_versionMap->Delete(p_id)) return ErrorCode::Success;
        return ErrorCode::VectorNotFound;
    }

    // Debug lookup: search the head index for the query's nearest postings, scan each
    // (checksum-validated) posting list, and return the VID of an exact match
    // (distance < 1e-6); -1 when nothing matches.
    SizeType SearchVector(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr& p_vectorSet,
                          std::shared_ptr p_index, int testNum = 64, SizeType VID = -1) override {

        QueryResult queryResults(p_vectorSet->GetVector(0), testNum, false);
        p_index->SearchIndex(queryResults);

        std::set checked;
        std::string postingList;
        for (int i = 0; i < queryResults.GetResultNum(); ++i)
        {
            // Skip postings that fail to load or fail checksum validation.
            if (db->Get(queryResults.GetResult(i)->VID, &postingList, MaxTimeout,
                        &(p_exWorkSpace->m_diskRequests)) != ErrorCode::Success ||
                !m_checkSum.ValidateChecksum(postingList.c_str(), (int)(postingList.size()), *m_checkSums[queryResults.GetResult(i)->VID]))
            {
                continue;
            }
            int vectorNum = (int)(postingList.size() / m_vectorInfoSize);

            for (int j = 0; j < vectorNum; j++) {
                char* vectorInfo = (char* )postingList.data() + j * m_vectorInfoSize;
                int vectorID = *(reinterpret_cast(vectorInfo));
                if(checked.find(vectorID) != checked.end() || m_versionMap->Deleted(vectorID)) {
                    continue;
                }
checked.insert(vectorID);
                if (VID != -1 && VID == vectorID) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Find %d in %dth posting\n", VID, i);
                auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), vectorInfo + m_metaDataSize);
                // Exact (near-zero distance) match found.
                if (distance2leaf < 1e-6) return vectorID;
            }
        }
        return -1;
    }

    // Force a garbage-collection pass: run a (non-splitting) Split over every live head
    // sample to compact its posting list.
    void ForceGC(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index) override {
        for (int i = 0; i < p_index->GetNumSamples(); i++) {
            if (!p_index->ContainSample(i)) continue;
            Split(p_exWorkSpace, p_index, i, false);
        }
    }

    // True once all queued background split jobs have drained.
    bool AllFinished() { return m_splitThreadPool->allClear(); } // && m_reassignThreadPool->allClear(); }
    void ForceCompaction() override { db->ForceCompaction(); }
    void GetDBStats() override {
        db->GetStat();
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "remain splitJobs: %d, reassignJobs: %d, running split: %d, running reassign: %d\n", m_splitThreadPool->jobsize(), 0, m_splitThreadPool->runningJobs(), 0);
    }

    int64_t GetNumBlocks() override
    {
        return db->GetNumBlocks();
    }

    void GetIndexStats(int finishedInsert, bool cost, bool reset) override { m_stat.PrintStat(finishedInsert, cost, reset); }

    // A posting is valid when its ID is in range and it is non-empty.
    bool CheckValidPosting(SizeType postingID) override {
        return (postingID < m_postingSizes.GetPostingNum()) && (m_postingSizes.GetSize(postingID) > 0);
    }

    // Validate one posting: ID range, recorded size, on-disk metadata via db->Check, and
    // (optionally, when m_checksumInRead and a workspace are provided) its checksum.
    virtual ErrorCode CheckPosting(SizeType postingID, std::vector *visited = nullptr,
                                   ExtraWorkSpace *p_exWorkSpace = nullptr) override
    {
        if (postingID < 0 || postingID >= m_postingSizes.GetPostingNum())
        {
            SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[CheckPosting]: Error postingID %d (should be 0 ~ %d)\n",
                         postingID, m_postingSizes.GetPostingNum());
            return ErrorCode::Key_OverFlow;
        }
        if (m_postingSizes.GetSize(postingID) < 0)
        {
            SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[CheckPosting]: postingID %d has wrong size:%d\n", postingID,
                         m_postingSizes.GetSize(postingID));
            return ErrorCode::Posting_SizeError;
        }
        ErrorCode ret = db->Check(postingID, m_postingSizes.GetSize(postingID) * m_vectorInfoSize, visited);
        if (ret != ErrorCode::Success)
        {
            SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[CheckPosting]: postingID %d has wrong meta data\n",
                         postingID);
            return ret;
        }


        if (m_opt->m_checksumInRead && p_exWorkSpace != nullptr)
        {
            std::string posting;
            if ((ret = db->Get(postingID, &posting, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) !=
                    ErrorCode::Success ||
                !m_checkSum.ValidateChecksum(posting.c_str(), (int)(posting.size()), *m_checkSums[postingID]))
            {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[CheckPosting] Get checksum fail %d!\n", postingID);
                // NOTE(review): if Get succeeded but the checksum failed, ret is still
                // Success here and is returned as Success — looks like a bug; confirm.
                return ret;
            }
        }
        return ErrorCode::Success;
    }

    // Read (write=false) or write (write=true) a posting blob. On write the posting size
    // and checksum are refreshed and an optional consistency Check runs; on read the
    // checksum is validated against the stored value.
    ErrorCode GetWritePosting(ExtraWorkSpace* p_exWorkSpace, SizeType pid, std::string& posting, bool write = false) override {
        ErrorCode ret;
        if (write) {
            if ((ret = db->Put(pid, posting, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success)
            {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[GetWritePosting] Put fail!\n");
                return ret;
            }

            m_postingSizes.UpdateSize(pid, posting.size() / m_vectorInfoSize);
            *m_checkSums[pid] = m_checkSum.CalcChecksum(posting.c_str(), (int)(posting.size()));
            if (m_opt->m_consistencyCheck && (ret = db->Check(pid, m_postingSizes.GetSize(pid) * m_vectorInfoSize, nullptr)) != ErrorCode::Success)
            {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[GetWritePosting] Check fail!\n");
                return ret;
            }
            // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "PostingSize: %d\n", m_postingSizes.GetSize(pid));
        } else {
            if ((ret = db->Get(pid, &posting, MaxTimeout, &(p_exWorkSpace->m_diskRequests))) != ErrorCode::Success ||
                !m_checkSum.ValidateChecksum(posting.c_str(), (int)(posting.size()), *m_checkSums[pid]))
            {
                SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[GetWritePosting] Get fail!\n");
                // NOTE(review): same pattern as CheckPosting — a checksum-only failure
                // returns ret == Success; confirm intended.
                return ret;
            }
        }
        return ErrorCode::Success;
    }

    // Persist all mutable index state under `prefix`:
    ErrorCode Checkpoint(std::string prefix) override {
        /**flush SPTAG, versionMap, 
block mapping, block pool**/
        /** Wait **/
        // Block until all background update jobs have drained; propagate any async error.
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Waiting for index update complete\n");
        while(!AllFinished())
        {
            std::this_thread::sleep_for(std::chrono::milliseconds(20));
        }
        if (m_asyncStatus != ErrorCode::Success)
            return m_asyncStatus;

        std::string p_persistenMap = prefix + FolderSep + m_opt->m_deleteIDFile;
        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving version map\n");

        ErrorCode ret;
        if ((ret = m_versionMap->Save(p_persistenMap)) != ErrorCode::Success)
            return ret;

        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving posting size\n");
        std::string p_persistenRecord = prefix + FolderSep + m_opt->m_ssdInfoFile;
        if ((ret = m_postingSizes.Save(p_persistenRecord)) != ErrorCode::Success)
            return ret;

        SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving posting checksum\n");
        std::string p_checksumPath = prefix + FolderSep + m_opt->m_checksumFile;
        if ((ret = m_checkSums.Save(p_checksumPath)) != ErrorCode::Success)
            return ret;

        if ((ret = db->Checkpoint(prefix)) != ErrorCode::Success)
            return ret;
        if (m_opt->m_enableWAL && m_wal) {
            /** delete all the previous record **/
            SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Checkpoint done, delete previous record\n");
            m_wal->ClearPreviousRecord();
        }
        return ErrorCode::Success;
    }

    // Debug helper: load posting `vid`, filter out entries whose stored version differs
    // from the version map (stale entries), and return the live VIDs plus a vector set
    // with their raw vector data.
    ErrorCode GetPostingDebug(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr p_index, SizeType vid, std::vector& VIDs, std::shared_ptr& vecs) {
        std::string posting;
        db->Get(vid, &posting, MaxTimeout, &(p_exWorkSpace->m_diskRequests));
        int vectorNum = (int)(posting.size() / m_vectorInfoSize);
        int vectorNum_real = vectorNum;
        // First pass: count live (version-matching) entries.
        for (int j = 0; j < vectorNum; j++) {
            char* vectorInfo = (char*)posting.data() + j * m_vectorInfoSize;
            int vectorID = *(reinterpret_cast(vectorInfo));
            uint8_t version = *(reinterpret_cast(vectorInfo + sizeof(int)));
            if(m_versionMap->GetVersion(vectorID) != version) {
                vectorNum_real--;
            }

        }
        VIDs.resize(vectorNum_real);
        ByteArray vector_array = ByteArray::Alloc(sizeof(ValueType) * vectorNum_real * m_opt->m_dim);
        vecs.reset(new BasicVectorSet(vector_array, GetEnumValueType(), m_opt->m_dim, vectorNum_real));

        // Second pass: copy live entries out.
        for (int j = 0, i = 0; j < vectorNum; j++) {
            char* vectorInfo = (char*)posting.data() + j * m_vectorInfoSize;
            int vectorID = *(reinterpret_cast(vectorInfo));
            uint8_t version = *(reinterpret_cast(vectorInfo + sizeof(int)));
            if(m_versionMap->GetVersion(vectorID) != version) {
                continue;
            }
            VIDs[i] = vectorID;
            auto outVec = vecs->GetVector(i);
            memcpy(outVec, (void*)(vectorInfo + sizeof(int) + sizeof(uint8_t)), sizeof(ValueType) * m_opt->m_dim);
            i++;
        }
        return ErrorCode::Success;
    }

    private:

    // Bytes of per-vector metadata (ID + version) stored ahead of each vector.
    int m_metaDataSize = 0;

    // Total bytes per posting entry: metadata + vector payload.
    int m_vectorInfoSize = 0;

    // Hard cap on posting-list length (entries).
    int m_postingSizeLimit = INT_MAX;

    // Extra slack allowed above m_postingSizeLimit before cutting (relax limit).
    int m_bufferSizeLimit = INT_MAX;

    std::chrono::microseconds m_hardLatencyLimit = std::chrono::microseconds(2000);

    int m_mergeThreshold = 10;
    // First error reported by any async background job.
    ErrorCode m_asyncStatus = ErrorCode::Success;

    COMMON::Dataset* m_vectorTranslateMap;

    std::shared_ptr m_splitThreadPool;
    std::shared_ptr m_reassignThreadPool;
    };
} // namespace SPTAG
#endif // _SPTAG_SPANN_EXTRADYNAMICSEARCHER_H_
diff --git a/AnnService/inc/Core/SPANN/ExtraFileController.h b/AnnService/inc/Core/SPANN/ExtraFileController.h
new file mode 100644
index 000000000..70c1824e5
--- /dev/null
+++ b/AnnService/inc/Core/SPANN/ExtraFileController.h
@@ -0,0 +1,1161 @@
#ifndef _SPTAG_SPANN_EXTRAFILECONTROLLER_H_
#define _SPTAG_SPANN_EXTRAFILECONTROLLER_H_
#define USE_ASYNC_IO

#include "inc/Helper/KeyValueIO.h"
#include "inc/Core/Common/Dataset.h"
#include "inc/Core/Common/Labelset.h"
#include "inc/Core/Common/FineGrainedLock.h"
#include "inc/Core/VectorIndex.h"
#include "inc/Helper/ThreadPool.h"
#include "inc/Helper/ConcurrentSet.h"
#include "inc/Helper/AsyncFileReader.h"
#include "inc/Core/SPANN/Options.h"
#include 
#include 
#include 
#include 
#include 
#include 
namespace SPTAG::SPANN 
{
    typedef std::int64_t AddressType;
    // File-backed key-value store; BlockController manages the pool of fixed-size file
    // blocks (allocation, release, async read/write, checkpoint/restore of the pool).
    class FileIO : public Helper::KeyValueIO {
        class BlockController {
        private:
            char m_filePath[1024];
            std::shared_ptr m_fileHandle = nullptr;

            // Free-block pool, plus a reserve queue; m_available tracks which block IDs
            // are currently free (guards against double release).
            Helper::Concurrent::ConcurrentQueue m_blockAddresses;
            Helper::Concurrent::ConcurrentQueue m_blockAddresses_reserve;
            COMMON::Labelset m_available;

            // I/O statistics counters (submitted/completed ops and byte totals).
            std::atomic read_complete_vec = 0;
            std::atomic read_submit_vec = 0;
            std::atomic write_complete_vec = 0;
            std::atomic write_submit_vec = 0;
            std::atomic read_bytes_vec = 0;
            std::atomic write_bytes_vec = 0;
            std::atomic read_blocks_time_vec = 0;

            // File-growth policy knobs.
            float m_growthThreshold = 0.05;
            AddressType m_growthBlocks = 0;
            AddressType m_maxBlocks = 0;
            int m_batchSize = 64;
            int m_preIOCompleteCount = 0;
            int64_t m_preIOBytes = 0;
            bool m_disableCheckpoint = false;

            std::chrono::high_resolution_clock::time_point m_startTime;
            std::chrono::time_point m_preTime = std::chrono::high_resolution_clock::now();

            std::atomic m_batchReadTimes = 0;
            std::atomic m_batchReadTimeouts = 0;

            // Serializes file expansion.
            std::mutex m_expandLock;
            std::atomic m_totalAllocatedBlocks = 0;

        private:
            // Grow the backing file by blocksToAdd blocks.
            bool ExpandFile(AddressType blocksToAdd);
            // Whether the free pool is low enough to warrant expansion.
            bool NeedsExpansion(int psize);

            // static void Start(void* args);

            // static void FileIoLoop(void *arg);

            // static void FileIoCallback(bool success, void *cb_arg);

            // static void Stop(void* args);

        public:
            bool Initialize(SPANN::Options& p_opt);

            // Pop p_size free block addresses into p_data.
            bool GetBlocks(AddressType* p_data, int p_size);

            // Return p_size block addresses to the free pool.
            bool ReleaseBlocks(AddressType* p_data, int p_size);

            bool ReadBlocks(AddressType* p_data, std::string* p_value, const std::chrono::microseconds &timeout, std::vector* reqs);

            bool ReadBlocks(const std::vector& p_data, std::vector* p_value, const std::chrono::microseconds &timeout, std::vector* reqs);

            bool ReadBlocks(const std::vector& p_data, std::vector>& p_value, const std::chrono::microseconds& timeout, std::vector* reqs);

            bool WriteBlocks(AddressType* p_data, int p_size, const std::string& p_value, const std::chrono::microseconds& timeout, std::vector* reqs);

            bool IOStatistics();

            bool ShutDown();

            // Free blocks currently in the pool (unsafe_size: approximate under concurrency).
            int RemainBlocks() {
                return (int)(m_blockAddresses.unsafe_size());
            }

            int ReserveBlocks() {
                return (int)(m_blockAddresses_reserve.unsafe_size());
            }

            int TotalBlocks() {
                return (int)(m_totalAllocatedBlocks.load());
            }

            // Persist the free-block pool to "<prefix>_blockpool": merges the reserve
            // queue back, then writes [free count][total count][free block IDs...].
            ErrorCode Checkpoint(std::string prefix) {
                std::string filename = prefix + "_blockpool";
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Checkpoint - Starting block pool save...\n");

                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Checkpoint - Reload reserved blocks...\n");
                AddressType currBlockAddress = 0;
                while (m_blockAddresses_reserve.try_pop(currBlockAddress))
                {
                    m_blockAddresses.push(currBlockAddress);
                }
                AddressType blocks = RemainBlocks();
                AddressType totalBlocks = m_totalAllocatedBlocks.load();

                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Checkpoint - Total allocated blocks: %llu\n", static_cast(totalBlocks));
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Checkpoint - Remaining free blocks: %llu\n", blocks);
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Checkpoint - Saving to file: %s\n", filename.c_str());

                auto ptr = f_createIO();
                if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::out)) {
                    SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "FileIO::BlockController::Checkpoint - Failed to open file: %s\n", filename.c_str());
                    return ErrorCode::FailedCreateFile;
                }
                IOBINARY(ptr, WriteBinary, sizeof(AddressType), reinterpret_cast(&blocks));
                IOBINARY(ptr, WriteBinary, sizeof(AddressType), reinterpret_cast(&totalBlocks));
                for (auto it = m_blockAddresses.unsafe_begin(); it != m_blockAddresses.unsafe_end(); it++) {
                    IOBINARY(ptr, WriteBinary, sizeof(AddressType), reinterpret_cast(&(*it)));
                }
                /*
                int i = 0;
                for (auto it = 
m_blockAddresses.unsafe_begin(); it != m_blockAddresses.unsafe_end(); it++) {
                    std::cout << *it << " ";
                    i++;
                    if (i == 10) break;
                }
                std::cout << std::endl;
                */
                SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save Finish!\n");
                return ErrorCode::Success;
            }

            // Restore the free-block pool from "<prefix>_blockpool". When allowInit is
            // set and no file exists, initialize a fresh pool of startNumBlocks blocks;
            // otherwise read [free count][total count][free IDs...], skipping any ID
            // already marked available (double release).
            ErrorCode LoadBlockPool(std::string prefix, AddressType startNumBlocks, bool allowInit, int blockSize, int blockCapacity) {
                std::string blockfile = prefix + "_blockpool";
                if (allowInit && !fileexists(blockfile.c_str())) {
                    SPTAGLIB_LOG(Helper::LogLevel::LL_Info,
                                 "FileIO::BlockController::LoadBlockPool: initializing fresh pool (no existing file "
                                 "found: %s)\n",
                                 blockfile.c_str());
                    m_available.Initialize(startNumBlocks, blockSize, blockCapacity);
                    for(AddressType i = 0; i < startNumBlocks; i++) {
                        m_blockAddresses.push(i);
                        m_available.Insert(i);
                    }
                    m_totalAllocatedBlocks.store(startNumBlocks);

                    SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::LoadBlockPool: initialized with %llu blocks (%.2f GB)\n",
                                 static_cast(startNumBlocks), static_cast(startNumBlocks >> (30 - PageSizeEx)));
                } else {
                    SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Load blockpool from %s\n", blockfile.c_str());
                    auto ptr = f_createIO();
                    if (ptr == nullptr || !ptr->Initialize(blockfile.c_str(), std::ios::binary | std::ios::in)) {
                        SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open the blockpool file: %s\n", blockfile.c_str());
                        return ErrorCode::Fail;
                    }

                    AddressType blocks = 0;
                    AddressType totalAllocated = 0;

                    // Read block count
                    IOBINARY(ptr, ReadBinary, sizeof(AddressType), reinterpret_cast(&blocks));
                    IOBINARY(ptr, ReadBinary, sizeof(AddressType), reinterpret_cast(&totalAllocated));

                    SPTAGLIB_LOG(Helper::LogLevel::LL_Info,
                                 "FileIO::BlockController::LoadBlockPool: reading %llu free blocks into pool (%.2f GB), total allocated: %llu blocks\n",
                                 static_cast(blocks),
                                 static_cast(blocks >> (30 - PageSizeEx)),
                                 static_cast(totalAllocated));

                    m_available.Initialize(totalAllocated, blockSize, blockCapacity);
                    AddressType currBlockAddress = 0;
                    for (AddressType i = 0; i < blocks; ++i) {
                        IOBINARY(ptr, ReadBinary, sizeof(AddressType), reinterpret_cast(&currBlockAddress));
                        if (!m_available.Insert(currBlockAddress))
                        {
                            SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "LoadBlockPool: Double release blockid:%lld\n",
                                         currBlockAddress);
                            continue;
                        }
                        m_blockAddresses.push(currBlockAddress);
                    }

                    m_totalAllocatedBlocks.store(totalAllocated);

                    SPTAGLIB_LOG(Helper::LogLevel::LL_Info,
                                 "FileIO::BlockController::LoadBlockPool: block pool initialized. Available: %llu, Total allocated: %llu\n",
                                 static_cast(blocks),
                                 static_cast(totalAllocated));
                }
                /*
                int i = 0;
                for (auto it = m_blockAddresses.unsafe_begin(); it != m_blockAddresses.unsafe_end(); it++) {
                    std::cout << *it << " ";
                    i++;
                    if (i == 10) break;
                }
                std::cout << std::endl;
                */
                return ErrorCode::Success;
            }
        };

        // Page-granular LRU write-back cache in front of FileIO: evicted dirty pages are
        // written back via fileIO->Put. Pre-allocates `limit` page buffers with their
        // async I/O request structures (Windows OVERLAPPED / Linux iocb).
        class LRUCache {
            int64_t capacity;
            int64_t limit;
            int64_t size;
            std::list keys; // Page Address
            std::unordered_map::iterator>> cache; // Page Address -> Page Address in Cache
            int64_t queries;
            std::atomic hits;
            FileIO* fileIO;
            Helper::RequestQueue processIocp;
            std::vector reqs;
            std::vector> pageBuffers;

        public:
            LRUCache(int64_t capacity, int64_t limit, FileIO* fileIO) {
                this->capacity = capacity;
                // NOTE(review): `limit` is both an entry-size cap (shifted by PageSizeEx)
                // and the buffer count below — confirm the double use is intended.
                this->limit = min(capacity, (limit << PageSizeEx));
                this->size = 0;
                this->queries = 0;
                this->hits = 0;
                this->fileIO = fileIO;
                this->reqs.resize(limit);
                this->pageBuffers.resize(limit);
                for (int i = 0; i < limit; i++) {
                    this->pageBuffers[i].ReservePageBuffer(PageSize);
                    auto& req = this->reqs[i];
                    req.m_buffer = (char*)(this->pageBuffers[i].GetBuffer());
                    req.m_extension = &processIocp;

#ifdef _MSC_VER
                    memset(&(req.myres.m_col), 0, sizeof(OVERLAPPED));
                    req.myres.m_col.m_data = (void*)(&req);
#else
                    memset(&(req.myiocb), 0, sizeof(struct iocb));
                    req.myiocb.aio_buf 
= reinterpret_cast(req.m_buffer); + req.myiocb.aio_data = reinterpret_cast(&req); +#endif + } + } + + ~LRUCache() {} + + bool evict(SizeType key, void* value, int vsize, std::unordered_map::iterator>>::iterator& it) { // write value (if non-null) back to file, then drop the entry and shrink size + if (value != nullptr) { + std::string valstr((char*)value, vsize); + if (fileIO->Put(key, valstr, MaxTimeout, &reqs, false) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LRUCache: evict key:%d value size:%d to file failed\n", key, vsize); + return false; + } + } + + size -= it->second.first.size(); + keys.erase(it->second.second); + cache.erase(it); + return true; + } + + bool get(SizeType key, void* value, int& get_size) { + queries++; + auto it = cache.find(key); + if (it == cache.end()) { + return false; // cache miss: key not resident + } + if (get_size > it->second.first.size()) { // caller asked for more bytes than cached: clamp to the real size + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cache get error: key %d required size %d, real size = %d\n", key, get_size, (int)(it->second.first.size())); + get_size = (int)(it->second.first.size()); + } + // NOTE(review): recency is NOT refreshed here — the key is never spliced to the list head on read (ShardedLRUCache::get holds only a shared lock, so mutating the list would race); confirm read-no-promote is intended + memcpy((char*)value, it->second.first.data(), get_size); + hits++; + return true; + } + + bool put(SizeType key, void* value, int put_size) { + auto it = cache.find(key); + if (it != cache.end()) { + if (put_size > limit) { + evict(key, it->second.first.data(), it->second.first.size(), it); + return false; + } + keys.splice(keys.begin(), keys, it->second.second); + it->second.second = keys.begin(); + + auto delta_size = put_size - (int)(it->second.first.size()); + while ((int)(capacity - size) < delta_size && (keys.size() > 1)) { + auto last = keys.back(); + auto lastit = cache.find(last); + if (!evict(last, lastit->second.first.data(), lastit->second.first.size(), lastit)) { + return false; + } + } + it->second.first.resize(put_size); + memcpy(it->second.first.data(), (char*)value, put_size); + size += delta_size; + hits++; + return true; + } + if (put_size > limit) { 
+ return false; + } + while (put_size > (int)(capacity - size) && (!keys.empty())) { + auto last = keys.back(); + auto lastit = cache.find(last); + if (!evict(last, lastit->second.first.data(), lastit->second.first.size(), lastit)) { + return false; + } + } + auto keys_it = keys.insert(keys.begin(), key); + cache.insert({key, {std::string((char*)value, put_size), keys_it}}); + size += put_size; + return true; + } + + bool del(SizeType key) { + auto it = cache.find(key); + if (it == cache.end()) { + return false; // If the key does not exist, return false + } + evict(key, nullptr, 0, it); + return true; + } + + bool merge(SizeType key, void* value, int merge_size) { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge size: %lld\n", merge_size); + auto it = cache.find(key); + if (it == cache.end()) { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge key not found\n"); + std::string valstr; + if (fileIO->Get(key, &valstr, MaxTimeout, &reqs, false) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LRUCache: merge key not found in file\n"); + return false; // If the key does not exist, return false + } + cache.insert({key, {valstr, keys.insert(keys.begin(), key)}}); + size += valstr.size(); + it = cache.find(key); + } else { + hits++; + } + + if (merge_size + it->second.first.size() > limit) { + evict(key, it->second.first.data(), it->second.first.size(), it); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge size exceeded\n"); + return false; + } + keys.splice(keys.begin(), keys, it->second.second); + it->second.second = keys.begin(); + while((int)(capacity - size) < merge_size && (keys.size() > 1)) { + auto last = keys.back(); + auto lastit = cache.find(last); + if (!evict(last, lastit->second.first.data(), lastit->second.first.size(), lastit)) { + return false; + } + } + it->second.first.append((char*)value, merge_size); + size += merge_size; + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "LRUCache: merge 
success\n"); + return true; + } + + std::pair get_stat() { + return {queries, hits.load()}; + } + + bool flush() { + for (auto it = cache.begin(); it != cache.end(); it++) { + if (fileIO->Put(it->first, it->second.first, MaxTimeout, &reqs, false) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "LRUCache: evict key:%d value size:%d to file failed\n", it->first, (int)(it->second.first.size())); + return false; + } + } + cache.clear(); + keys.clear(); + size = 0; + return true; + } + }; + + class ShardedLRUCache { + int shards; + std::vector caches; + std::unique_ptr m_rwMutexs; + + public: + ShardedLRUCache(int shards, int64_t capacity, int64_t limit, FileIO* fileIO) : shards(shards) { + caches.resize(shards); + m_rwMutexs.reset(new std::shared_timed_mutex[shards]); + for (int i = 0; i < shards; i++) { + caches[i] = new LRUCache(capacity / shards, limit, fileIO); + } + if (capacity % shards != 0) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "LRUCache: capacity is not divisible by shards\n"); + } + } + + ~ShardedLRUCache() { + for (int i = 0; i < shards; i++) { + delete caches[i]; + } + } + + bool get(SizeType key, void* value, int& get_size) { + SizeType cid = hash(key); + std::shared_lock lock(m_rwMutexs[cid]); + return caches[cid]->get(key, value, get_size); + } + + bool put(SizeType key, void* value, int put_size) { + return caches[hash(key)]->put(key, value, put_size); + } + + bool del(SizeType key) { + return caches[hash(key)]->del(key); + } + + bool merge(SizeType key, void* value, int merge_size) { + return caches[hash(key)]->merge(key, value, merge_size); + } + + bool flush() { + for (int i = 0; i < shards; i++) { + if (!caches[i]->flush()) return false; + } + return true; + } + + std::shared_timed_mutex& getlock(SizeType key) + { + return m_rwMutexs[hash(key)]; + } + + SizeType hash(SizeType key) const + { + return key % shards; + } + + std::pair get_stat() { + int64_t queries = 0, hits = 0; + for (int i = 0; i < shards; i++) { + auto 
stat = caches[i]->get_stat(); + queries += stat.first; + hits += stat.second; + } + return {queries, hits}; + } + }; + + public: + FileIO(SPANN::Options& p_opt) { + m_mappingPath = p_opt.m_indexDirectory + FolderSep + p_opt.m_ssdMappingFile; + m_blockLimit = max(p_opt.m_postingPageLimit, p_opt.m_searchPostingPageLimit) + p_opt.m_bufferLength + 1; + m_bufferLimit = 1024; + m_shutdownCalled = true; + + m_pShardedLRUCache = nullptr; + int64_t capacity = static_cast(p_opt.m_cacheSize) << 30; + if (capacity > 0) { + m_pShardedLRUCache = new ShardedLRUCache(p_opt.m_cacheShards, capacity, m_blockLimit, this); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO: Using LRU Cache with capacity %d GB, limit %d pages, shards %d\n", p_opt.m_cacheSize, m_blockLimit, p_opt.m_cacheShards); + } + + if (p_opt.m_recovery) { + std::string recoverpath = p_opt.m_persistentBufferPath + FolderSep + p_opt.m_ssdMappingFile; + if (fileexists(recoverpath.c_str())) { + Load(recoverpath, 1024 * 1024, MaxSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load block mapping successfully from %s!\n", recoverpath.c_str()); + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot recover block mapping from %s!\n", recoverpath.c_str()); + return; + } + } + else { + if (fileexists(m_mappingPath.c_str())) { + Load(m_mappingPath, 1024 * 1024, MaxSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load block mapping successfully from %s!\n", m_mappingPath.c_str()); + /* + for (int i = 0; i < 10; i++) { + std::cout << "i=" << i << ": "; + for (int j = 0; j < 10; j++) { + std::cout << m_pBlockMapping[i][j] << " "; + } + std::cout << std::endl; + } + for (int i = m_pBlockMapping.R() - 10; i < m_pBlockMapping.R(); i++) { + std::cout << "i=" << i << ": "; + for (int j = 0; j < 10; j++) { + std::cout << m_pBlockMapping[i][j] << " "; + } + std::cout << std::endl; + } + */ + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Initialize block mapping successfully!\n"); + m_pBlockMapping.Initialize(0, 1, 1024 
* 1024, MaxSize); + } + } + for (int i = 0; i < m_bufferLimit; i++) { + m_buffer.push((uintptr_t)(new AddressType[m_blockLimit])); + } + m_compactionThreadPool = std::make_shared(); + m_compactionThreadPool->init(1); + + if (!m_pBlockController.Initialize(p_opt)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to Initialize FileIO!\n"); + return; + } + + m_shutdownCalled = false; + } + + ~FileIO() { + ShutDown(); + } + + bool Available() override + { + return !m_shutdownCalled; + } + + void ShutDown() override { + if (m_shutdownCalled) { + return; + } + m_shutdownCalled = true; + while (!m_key_reserve.empty()) + { + SizeType cleanKey = 0xffffffff; + if (m_key_reserve.try_pop(cleanKey)) + { + At(cleanKey) = 0xffffffffffffffff; + } + } + if (!m_mappingPath.empty()) Save(m_mappingPath); + // TODO: Should we add a lock here? + for (int i = 0; i < m_pBlockMapping.R(); i++) { + if (At(i) != 0xffffffffffffffff) { + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "delete at %d for addr: %llu\n", i, At(i)); + delete[]((AddressType*)At(i)); + At(i) = 0xffffffffffffffff; + } + } + while (!m_buffer.empty()) { + uintptr_t ptr = 0xffffffffffffffff; + if (m_buffer.try_pop(ptr)) delete[]((AddressType*)ptr); + } + m_pBlockController.ShutDown(); + if (m_pShardedLRUCache) { + delete m_pShardedLRUCache; + m_pShardedLRUCache = nullptr; + } + } + + inline uintptr_t& At(SizeType key) { + return *(m_pBlockMapping[key]); + } + + ErrorCode Get(const SizeType key, std::string* value, const std::chrono::microseconds &timeout, std::vector* reqs, bool useCache) { + auto get_begin_time = std::chrono::high_resolution_clock::now(); + SizeType r = m_pBlockMapping.R(); + + if (key >= r) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Key OverFlow! Key:%d R:%d\n", key, r); + return ErrorCode::Key_OverFlow; + } + AddressType* addr = (AddressType*)(At(key)); + if (((uintptr_t)addr) == 0xffffffffffffffff) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Key NotFound! 
Key:%d\n", key); + return ErrorCode::Key_NotFound; + } + + int size = (int)(addr[0]); + if (size < 0) return ErrorCode::Posting_SizeError; + + if (useCache && m_pShardedLRUCache) { + value->resize(size); + if (m_pShardedLRUCache->get(key, value->data(), size)) { + value->resize(size); + return ErrorCode::Success; + } + } + + // if (m_pBlockController.ReadBlocks((AddressType*)At(key), value)) { + // return ErrorCode::Success; + // } + auto begin_time = std::chrono::high_resolution_clock::now(); + auto result = m_pBlockController.ReadBlocks((AddressType*)At(key), value, timeout, reqs); + auto end_time = std::chrono::high_resolution_clock::now(); + read_time_vec += std::chrono::duration_cast(end_time - begin_time).count(); + get_times_vec++; + auto get_end_time = std::chrono::high_resolution_clock::now(); + get_time_vec += std::chrono::duration_cast(get_end_time - get_begin_time).count(); + return result ? ErrorCode::Success : ErrorCode::Fail; + } + + ErrorCode Get(const SizeType key, std::string* value, const std::chrono::microseconds &timeout, std::vector* reqs) override { + return Get(key, value, timeout, reqs, true); + } + + ErrorCode Get(const std::string& key, std::string* value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + return Get(std::stoi(key), value, timeout, reqs, true); + } + + ErrorCode MultiGet(const std::vector& keys, std::vector>& values, + const std::chrono::microseconds &timeout, std::vector* reqs) override { + std::vector blocks; + std::set lock_keys; + SizeType r; + int i = 0; + for (SizeType key : keys) { + r = m_pBlockMapping.R(); + if (key < r) { + AddressType* addr = (AddressType*)(At(key)); + if (m_pShardedLRUCache && ((uintptr_t)addr) != 0xffffffffffffffff && addr[0] >= 0) { + int size = (int)(addr[0]); + values[i].ReservePageBuffer(size); + if (m_pShardedLRUCache->get(key, values[i].GetBuffer(), size)) { + values[i].SetAvailableSize(size); + blocks.push_back(nullptr); + } + else { + blocks.push_back(addr); + } 
+ } else { + blocks.push_back(addr); + } + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read key:%d total key number:%d\n", key, r); + } + i++; + } + // if (m_pBlockController.ReadBlocks(blocks, values, timeout)) return ErrorCode::Success; + auto result = m_pBlockController.ReadBlocks(blocks, values, timeout, reqs); + return result ? ErrorCode::Success : ErrorCode::Fail; + } + + + ErrorCode MultiGet(const std::vector& keys, std::vector* values, + const std::chrono::microseconds& timeout, std::vector* reqs) { + std::vector blocks; + std::set lock_keys; + SizeType r; + values->resize(keys.size()); + int i = 0; + for (SizeType key : keys) { + r = m_pBlockMapping.R(); + if (key < r) { + AddressType* addr = (AddressType*)(At(key)); + if (m_pShardedLRUCache && ((uintptr_t)addr) != 0xffffffffffffffff && addr[0] >= 0) { + int size = (int)(addr[0]); + (*values)[i].resize(size); + if (m_pShardedLRUCache->get(key, (*values)[i].data(), size)) { + (*values)[i].resize(size); + blocks.push_back(nullptr); + } + else { + blocks.push_back(addr); + } + } + else { + blocks.push_back(addr); + } + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read key:%d total key number:%d\n", key, r); + } + i++; + } + // if (m_pBlockController.ReadBlocks(blocks, values, timeout)) return ErrorCode::Success; + auto result = m_pBlockController.ReadBlocks(blocks, values, timeout, reqs); + return result ? 
ErrorCode::Success : ErrorCode::Fail; + } + + ErrorCode MultiGet(const std::vector& keys, std::vector* values, + const std::chrono::microseconds &timeout, std::vector* reqs) override { + std::vector int_keys; + for (const auto& key : keys) { + int_keys.push_back(std::stoi(key)); + } + return MultiGet(int_keys, values, timeout, reqs); + } + + /* + ErrorCode Scan(const SizeType start_key, const int record_count, std::vector &values, const std::chrono::microseconds &timeout = (std::chrono::microseconds::max)(), std::vector* reqs = nullptr) { + std::vector keys; + std::vector blocks; + SizeType curr_key = start_key; + while(keys.size() < record_count && curr_key < m_pBlockMapping.R()) { + if (At(curr_key) == 0xffffffffffffffff) { + curr_key++; + continue; + } + keys.push_back(curr_key); + blocks.push_back((AddressType*)At(curr_key)); + curr_key++; + } + auto result = m_pBlockController.ReadBlocks(blocks, values, timeout, reqs); + return result ? ErrorCode::Success : ErrorCode::Fail; + } + */ + + ErrorCode Put(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs, bool useCache) { + int blocks = (int)(((value.size() + PageSize - 1) >> PageSizeEx)); + if (blocks >= m_blockLimit) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to put key:%d value:%lld since value too long!\n", key, value.size()); + return ErrorCode::Posting_OverFlow; + } + // Calculate whether more mapping blocks are needed + int delta = key + 1 - m_pBlockMapping.R(); + if (delta > 0) { + // std::lock_guard lock(m_updateMutex); + m_updateMutex.lock(); + delta = key + 1 - m_pBlockMapping.R(); + if (delta > 0) { + m_pBlockMapping.AddBatch(delta); + } + m_updateMutex.unlock(); + } + + std::shared_timed_mutex *lock = nullptr; + if (useCache && m_pShardedLRUCache) + { + lock = &(m_pShardedLRUCache->getlock(key)); + lock->lock(); + } + uintptr_t tmpblocks = 0xffffffffffffffff; + // If this key has not been assigned mapping blocks yet, allocate a batch. 
+ if (At(key) == 0xffffffffffffffff) { + // If there are spare blocks in m_buffer, use them directly; otherwise, allocate a new batch. + if (m_buffer.unsafe_size() > m_bufferLimit) { + while (!m_buffer.try_pop(tmpblocks)); + } + else { + tmpblocks = (uintptr_t)(new AddressType[m_blockLimit]); + } + // The 0th element of the block address list represents the data size; set it to -1. + memset((AddressType*)tmpblocks, -1, sizeof(AddressType) * m_blockLimit); + } else { + tmpblocks = At(key); + } + int64_t* postingSize = (int64_t*)tmpblocks; + // If postingSize is less than 0, it means the mapping block is newly allocated—directly + if (*postingSize < 0) { + if (!m_pBlockController.GetBlocks(postingSize + 1, blocks)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Put] Not enough blocks in the pool can be allocated!\n"); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + *postingSize = value.size(); + if (!useCache || m_pShardedLRUCache == nullptr || !m_pShardedLRUCache->put(key, (void*)(value.data()), (SPTAG::SizeType)(value.size()))) { + if (!m_pBlockController.WriteBlocks(postingSize + 1, blocks, value, timeout, reqs)) + { + m_pBlockController.ReleaseBlocks(postingSize + 1, blocks); + memset(postingSize + 1, -1, sizeof(AddressType) * blocks); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Put] Write new block failed!\n"); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + } + At(key) = tmpblocks; + } + else { + uintptr_t partialtmpblocks = 0xffffffffffffffff; + // Take a batch of mapping blocks from the buffer, and return a batch back later. + while (!m_buffer.try_pop(partialtmpblocks)); + // Acquire a new batch of disk blocks and write data directly. + // To ensure the effectiveness of the checkpoint, new blocks must be allocated for writing here. 
+ if (!m_pBlockController.GetBlocks((AddressType*)partialtmpblocks + 1, blocks)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Put] Not enough blocks in the pool can be allocated!\n"); + m_buffer.push(partialtmpblocks); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + *((int64_t*)partialtmpblocks) = value.size(); + if (!useCache || m_pShardedLRUCache == nullptr || !m_pShardedLRUCache->put(key, (void*)(value.data()), (SPTAG::SizeType)(value.size()))) { + if (!m_pBlockController.WriteBlocks((AddressType*)partialtmpblocks + 1, blocks, value, timeout, reqs)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Put] Write new block failed!\n"); + m_pBlockController.ReleaseBlocks((AddressType*)partialtmpblocks + 1, blocks); + m_buffer.push(partialtmpblocks); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + } + + // Release the original blocks + m_pBlockController.ReleaseBlocks(postingSize + 1, (*postingSize + PageSize - 1) >> PageSizeEx); + while (InterlockedCompareExchange(&At(key), partialtmpblocks, (uintptr_t)postingSize) != (uintptr_t)postingSize) { + postingSize = (int64_t*)At(key); + } + m_buffer.push((uintptr_t)postingSize); + } + if (lock) lock->unlock(); + return ErrorCode::Success; + } + + ErrorCode Put(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + return Put(key, value, timeout, reqs, true); + } + + ErrorCode Put(const std::string &key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + return Put(std::stoi(key), value, timeout, reqs, true); + } + + ErrorCode Check(const SizeType key, int size, std::vector *visited) override + { + SizeType r = m_pBlockMapping.R(); + + if (key >= r || At(key) == 0xffffffffffffffff) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Check] Key range error: key: %d, mapping size: %d\n", key, r); + return ErrorCode::Key_OverFlow; + } + + int64_t *postingSize = (int64_t *)At(key); 
+ if ((*postingSize) != size) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Check] Key %d failed: postingSize %d is not match real size %d\n", key, (int)(*postingSize), size); + return ErrorCode::Posting_SizeError; + } + + int blocks = ((*postingSize + PageSize - 1) >> PageSizeEx); + for (int i = 1; i <= blocks; i++) + { + if (postingSize[i] < 0 || postingSize[i] >= m_pBlockController.TotalBlocks()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Check] Key %d failed: error block id %d (should be 0 ~ %d)\n", key, + (int)(postingSize[i]), m_pBlockController.TotalBlocks()); + return ErrorCode::Block_IDError; + } + + if (visited == nullptr) + continue; + + if (visited->at(postingSize[i])) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Check] Block %lld double used!\n", postingSize[i]); + return ErrorCode::Block_IDError; + } + else + { + visited->at(postingSize[i]) = true; + } + } + return ErrorCode::Success; + } + + void PrintPostingDiff(std::string& p1, std::string& p2, const char* pos) { + if (p1.size() != p2.size()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Merge %s: p1 and p2 have different sizes: before=%u after=%u\n", pos, p1.size(), p2.size()); + return; + } + std::string diff = ""; + for (size_t i = 0; i < p1.size(); i+=4) { + if (p1[i] != p2[i]) { + diff += "[" + std::to_string(i) + "]:" + std::to_string(int(p1[i])) + "^" + std::to_string(int(p2[i])) + " "; + } + } + if (diff.size() != 0) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Merge %s: %s\n", pos, diff.c_str()); + } + } + + ErrorCode Merge(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) { + SizeType r = m_pBlockMapping.R(); + if (key >= r) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Key range error: key: %d, mapping size: %d\n", key, r); + return ErrorCode::Key_OverFlow; + } + + std::shared_timed_mutex *lock = nullptr; + if (m_pShardedLRUCache) + { + lock = &(m_pShardedLRUCache->getlock(key)); + lock->lock(); + } + + int64_t* 
postingSize = (int64_t*)At(key); + if (((uintptr_t)postingSize) == 0xffffffffffffffff || *postingSize < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Merge] Key %d failed: postingSize < 0\n", key); + if (lock) lock->unlock(); + return ErrorCode::Key_NotFound; + //return Put(key, value, timeout, reqs); + } + + auto newSize = *postingSize + value.size(); + int newblocks = ((newSize + PageSize - 1) >> PageSizeEx); + if (newblocks >= m_blockLimit) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Merge] Key %d failed: new size %lld bytes requires %d blocks (limit: %d)\n", + key, static_cast(newSize), newblocks, m_blockLimit); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Merge] Original size: %lld bytes, Merge size: %lld bytes\n", + static_cast(*postingSize), static_cast(value.size())); + if (lock) lock->unlock(); + return ErrorCode::Posting_OverFlow; + } + + //std::string before; + //Get(key, &before, timeout, reqs); + + auto sizeInPage = (*postingSize) % PageSize; // Actual size of the last block + int oldblocks = (*postingSize >> PageSizeEx); + int allocblocks = newblocks - oldblocks; + // If the last block is not full, we need to read it first, then append the new data, and write it back. 
+ if (sizeInPage != 0) { + std::string newValue; + AddressType readreq[] = { sizeInPage, *(postingSize + 1 + oldblocks) }; + if (!m_pBlockController.ReadBlocks(readreq, &newValue, timeout, reqs)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Merge] Cannot read original posting!\n"); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + //std::string lastblock = before.substr(before.size() - sizeInPage); + //PrintPostingDiff(lastblock, newValue, "0"); + newValue += value; + + uintptr_t tmpblocks = 0xffffffffffffffff; + while (!m_buffer.try_pop(tmpblocks)); + memcpy((AddressType*)tmpblocks, postingSize, sizeof(AddressType) * (oldblocks + 1)); + if (!m_pBlockController.GetBlocks((AddressType *)tmpblocks + 1 + oldblocks, allocblocks)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Merge] Not enough blocks in the pool can be allocated!\n"); + m_buffer.push(tmpblocks); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + *((int64_t*)tmpblocks) = newSize; + if (m_pShardedLRUCache == nullptr || !m_pShardedLRUCache->merge(key, (void *)(value.data()), value.size())) { + if (!m_pBlockController.WriteBlocks((AddressType *)tmpblocks + 1 + oldblocks, allocblocks, newValue, + timeout, reqs)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Merge] Write new block failed!\n"); + m_pBlockController.ReleaseBlocks((AddressType *)tmpblocks + 1 + oldblocks, allocblocks); + m_buffer.push(tmpblocks); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + } + + // This is also to ensure checkpoint correctness, so we release the partially used block and allocate a new one. + m_pBlockController.ReleaseBlocks(postingSize + 1 + oldblocks, 1); + while (InterlockedCompareExchange(&At(key), tmpblocks, (uintptr_t)postingSize) != (uintptr_t)postingSize) { + postingSize = (int64_t*)At(key); + } + m_buffer.push((uintptr_t)postingSize); + } + else { // Otherwise, directly allocate a new batch of blocks to append after the current ones. 
+ if (!m_pBlockController.GetBlocks(postingSize + 1 + oldblocks, allocblocks)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Merge] Not enough blocks in the pool can be allocated!\n"); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + + if (m_pShardedLRUCache == nullptr || !m_pShardedLRUCache->merge(key, (void *)(value.data()), value.size())) { + if (!m_pBlockController.WriteBlocks(postingSize + 1 + oldblocks, allocblocks, value, timeout, reqs)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Merge] Write new block failed!\n"); + m_pBlockController.ReleaseBlocks(postingSize + 1 + oldblocks, allocblocks); + if (lock) lock->unlock(); + return ErrorCode::DiskIOFail; + } + } + *postingSize = newSize; + } + /* + std::string after; + Get(key, &after, timeout, reqs); + before += value; + PrintPostingDiff(before, after, "1"); + */ + if (lock) lock->unlock(); + return ErrorCode::Success; + } + + ErrorCode Merge(const std::string &key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) { + return Merge(std::stoi(key), value, timeout, reqs); + } + + ErrorCode Delete(SizeType key) override { + SizeType r = m_pBlockMapping.R(); + if (key >= r) return ErrorCode::Key_OverFlow; + + std::shared_timed_mutex *lock = nullptr; + if (m_pShardedLRUCache) + { + lock = &(m_pShardedLRUCache->getlock(key)); + lock->lock(); + m_pShardedLRUCache->del(key); + } + + int64_t* postingSize = (int64_t*)At(key); + if (((uintptr_t)postingSize) == 0xffffffffffffffff || *postingSize < 0) + { + if (lock) lock->unlock(); + return ErrorCode::Key_NotFound; + } + + int blocks = ((*postingSize + PageSize - 1) >> PageSizeEx); + m_pBlockController.ReleaseBlocks(postingSize + 1, blocks); + m_buffer.push((uintptr_t)postingSize); + //m_key_reserve.push(key); + /* + while (m_key_reserve.unsafe_size() > m_bufferLimit) + { + SizeType cleanKey = 0; + if (m_key_reserve.try_pop(cleanKey)) + { + At(cleanKey) = 0xffffffffffffffff; + } + } + */ + At(key) = 
0xffffffffffffffff; + if (lock) lock->unlock(); + return ErrorCode::Success; + } + + ErrorCode Delete(const std::string &key) { + return Delete(std::stoi(key)); + } + + void ForceCompaction() { + Save(m_mappingPath); + } + + void GetStat() { + int remainBlocks = m_pBlockController.RemainBlocks(); + int reserveBlocks = m_pBlockController.ReserveBlocks(); + int totalBlocks = m_pBlockController.TotalBlocks(); + int remainGB = ((long long)(remainBlocks + reserveBlocks) >> (30 - PageSizeEx)); + // int remainGB = remainBlocks >> 20 << 2; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total %d blocks, Remain %d blocks, Reserve %d blocks, totally %d GB\n", totalBlocks, remainBlocks, reserveBlocks, remainGB); + double average_read_time = (double)read_time_vec / get_times_vec; + double average_get_time = (double)get_time_vec / get_times_vec; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Get times: %llu, get time: %llu us, read time: %llu us\n", get_times_vec.load(), get_time_vec.load(), read_time_vec.load()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Average read time: %lf us, average get time: %lf us\n", average_read_time, average_get_time); + if (m_pShardedLRUCache) { + auto cache_stat = m_pShardedLRUCache->get_stat(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Cache queries: %lld, Cache hits: %lld, Hit rates: %lf\n", cache_stat.first, cache_stat.second, cache_stat.second == 0 ? 
0 : (double)cache_stat.second / cache_stat.first); + } + m_pBlockController.IOStatistics(); + } + + ErrorCode Load(std::string path, SizeType blockSize, SizeType capacity) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load mapping From %s\n", path.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(path.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; + + SizeType CR, mycols; + IOBINARY(ptr, ReadBinary, sizeof(SizeType), (char*)&CR); + IOBINARY(ptr, ReadBinary, sizeof(SizeType), (char*)&mycols); + if (mycols > m_blockLimit) m_blockLimit = mycols; + + m_pBlockMapping.Initialize(CR, 1, blockSize, capacity); + for (int i = 0; i < CR; i++) { + At(i) = (uintptr_t)(new AddressType[m_blockLimit]); + IOBINARY(ptr, ReadBinary, sizeof(AddressType) * mycols, (char*)At(i)); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load mapping (%d,%d) Finish!\n", CR, mycols); + return ErrorCode::Success; + } + + ErrorCode Save(std::string path) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save mapping To %s\n", path.c_str()); + if (m_pShardedLRUCache) { + if (!m_pShardedLRUCache->flush()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to flush cache when saving mapping!\n"); + return ErrorCode::Fail; + } + } + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(path.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; + + SizeType CR = m_pBlockMapping.R(); + IOBINARY(ptr, WriteBinary, sizeof(SizeType), (char*)&CR); + IOBINARY(ptr, WriteBinary, sizeof(SizeType), (char*)&m_blockLimit); + std::vector empty(m_blockLimit, 0xffffffffffffffff); + for (int i = 0; i < CR; i++) { + if (At(i) == 0xffffffffffffffff) { + IOBINARY(ptr, WriteBinary, sizeof(AddressType) * m_blockLimit, (char*)(empty.data())); + } + else { + int64_t* postingSize = (int64_t*)At(i); + IOBINARY(ptr, WriteBinary, sizeof(AddressType) * m_blockLimit, (char*)postingSize); + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save 
mapping (%d,%d) Finish!\n", CR, m_blockLimit); + /* + for (int i = 0; i < 10; i++) { + std::cout << "i=" << i << ": "; + for (int j = 0; j < 10; j++) { + std::cout << m_pBlockMapping[i][j] << " "; + } + std::cout << std::endl; + } + for (int i = m_pBlockMapping.R() - 10; i < m_pBlockMapping.R(); i++) { + std::cout << "i=" << i << ": "; + for (int j = 0; j < 10; j++) { + std::cout << m_pBlockMapping[i][j] << " "; + } + std::cout << std::endl; + } + */ + return ErrorCode::Success; + } + + ErrorCode Checkpoint(std::string prefix) override { + std::string filename = prefix + FolderSep + m_mappingPath.substr(m_mappingPath.find_last_of(FolderSep) + 1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO: saving block mapping to %s\n", filename.c_str()); + Save(filename); + return m_pBlockController.Checkpoint(filename + "_postings"); + } + + int64_t GetNumBlocks() override + { + return m_pBlockController.TotalBlocks(); + } + + private: + std::atomic read_time_vec = 0; + std::atomic get_time_vec = 0; + std::atomic get_times_vec = 0; + + std::string m_mappingPath; + SizeType m_blockLimit; + COMMON::Dataset m_pBlockMapping; + SizeType m_bufferLimit; + Helper::Concurrent::ConcurrentQueue m_buffer; + Helper::Concurrent::ConcurrentQueue m_key_reserve; + + std::shared_ptr m_compactionThreadPool; + BlockController m_pBlockController; + ShardedLRUCache *m_pShardedLRUCache; + + bool m_shutdownCalled; + std::shared_mutex m_updateMutex; + }; +} +#endif diff --git a/AnnService/inc/Core/SPANN/ExtraRocksDBController.h b/AnnService/inc/Core/SPANN/ExtraRocksDBController.h new file mode 100644 index 000000000..732dda6f3 --- /dev/null +++ b/AnnService/inc/Core/SPANN/ExtraRocksDBController.h @@ -0,0 +1,381 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifndef _SPTAG_SPANN_EXTRAROCKSDBCONTROLLER_H_ +#define _SPTAG_SPANN_EXTRAROCKSDBCONTROLLER_H_ + +#include "inc/Helper/KeyValueIO.h" +#include "inc/Helper/StringConvert.h" + +#include "rocksdb/db.h" +#include "rocksdb/filter_policy.h" +#include "rocksdb/rate_limiter.h" +#include "rocksdb/slice.h" +#include "rocksdb/options.h" +#include "rocksdb/merge_operator.h" +#include "rocksdb/table.h" +#include "rocksdb/utilities/checkpoint.h" + +#include +#include +#include +#include + +namespace SPTAG::SPANN +{ + class RocksDBIO : public Helper::KeyValueIO + { + class AnnMergeOperator : public rocksdb::MergeOperator + { + public: + bool FullMergeV2(const rocksdb::MergeOperator::MergeOperationInput& merge_in, + rocksdb::MergeOperator::MergeOperationOutput* merge_out) const override + { + size_t length = (merge_in.existing_value)->size(); + for (const rocksdb::Slice& s : merge_in.operand_list) { + length += s.size(); + } + (merge_out->new_value).resize(length); + memcpy((char*)((merge_out->new_value).c_str()), + (merge_in.existing_value)->data(), (merge_in.existing_value)->size()); + size_t start = (merge_in.existing_value)->size(); + for (const rocksdb::Slice& s : merge_in.operand_list) { + memcpy((char*)((merge_out->new_value).c_str() + start), s.data(), s.size()); + start += s.size(); + } + return true; + } + + bool PartialMergeMulti(const rocksdb::Slice& key, + const std::deque& operand_list, + std::string* new_value, rocksdb::Logger* logger) const override + { + size_t length = 0; + for (const rocksdb::Slice& s : operand_list) { + length += s.size(); + } + new_value->resize(length); + size_t start = 0; + for (const rocksdb::Slice& s : operand_list) { + memcpy((char*)(new_value->c_str() + start), s.data(), s.size()); + start += s.size(); + } + return true; + } + + const char* Name() const override { + return "AnnMergeOperator"; + } + }; + + public: + RocksDBIO(const char* filePath, bool usdDirectIO, bool wal = false, bool recovery = false) { + dbPath = 
std::string(filePath); + //dbOptions.statistics = rocksdb::CreateDBStatistics(); + dbOptions.create_if_missing = true; + if (!wal) { + dbOptions.IncreaseParallelism(); + dbOptions.OptimizeLevelStyleCompaction(); + dbOptions.merge_operator.reset(new AnnMergeOperator); + // dbOptions.statistics = rocksdb::CreateDBStatistics(); + + // SST file size options + dbOptions.target_file_size_base = 128UL * 1024 * 1024; + dbOptions.target_file_size_multiplier = 2; + dbOptions.max_bytes_for_level_base = 16 * 1024UL * 1024 * 1024; + dbOptions.max_bytes_for_level_multiplier = 4; + dbOptions.max_subcompactions = 16; + dbOptions.num_levels = 4; + dbOptions.level0_file_num_compaction_trigger = 1; + dbOptions.level_compaction_dynamic_level_bytes = false; + dbOptions.write_buffer_size = 16UL * 1024 * 1024; + + // rate limiter options + // dbOptions.rate_limiter.reset(rocksdb::NewGenericRateLimiter(100UL << 20)); + + // blob options + dbOptions.enable_blob_files = true; + dbOptions.min_blob_size = 64; + dbOptions.blob_file_size = 8UL << 30; + dbOptions.blob_compression_type = rocksdb::CompressionType::kNoCompression; + dbOptions.enable_blob_garbage_collection = true; + dbOptions.compaction_pri = rocksdb::CompactionPri::kRoundRobin; + dbOptions.blob_garbage_collection_age_cutoff = 0.4; + // dbOptions.blob_garbage_collection_force_threshold = 0.5; + // dbOptions.blob_cache = rocksdb::NewLRUCache(5UL << 30); + // dbOptions.prepopulate_blob_cache = rocksdb::PrepopulateBlobCache::kFlushOnly; + + // dbOptions.env; + // dbOptions.sst_file_manager = std::shared_ptr(rocksdb::NewSstFileManager(dbOptions.env)); + // dbOptions.sst_file_manager->SetStatisticsPtr(dbOptions.statistics); + + // compression options + // dbOptions.compression = rocksdb::CompressionType::kLZ4Compression; + // dbOptions.bottommost_compression = rocksdb::CompressionType::kZSTD; + + // block cache options + rocksdb::BlockBasedTableOptions table_options; + // table_options.block_cache = 
rocksdb::NewSimCache(rocksdb::NewLRUCache(1UL << 30), (8UL << 30), -1); + table_options.block_cache = rocksdb::NewLRUCache(3UL << 30); + // table_options.no_block_cache = true; + + // filter options + table_options.filter_policy.reset(rocksdb::NewBloomFilterPolicy(10, true)); + table_options.optimize_filters_for_memory = true; + + dbOptions.table_factory.reset(rocksdb::NewBlockBasedTableFactory(table_options)); + } + + if (usdDirectIO) { + dbOptions.use_direct_io_for_flush_and_compaction = true; + dbOptions.use_direct_reads = true; + } + + auto s = rocksdb::DB::Open(dbOptions, dbPath, &db); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPFresh: New Rocksdb: %s\n", filePath); + if (s != rocksdb::Status::OK()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "\e[0;31mRocksdb Open Error\e[0m: %s\n", s.getState()); + ShutDown(); + } + } + + ~RocksDBIO() override { + /* + std::string stats; + db->GetProperty("rocksdb.stats", &stats); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "RocksDB Status: %s\n%s", dbPath.c_str(),stats.c_str()); + */ + ShutDown(); + } + + void ShutDown() override { + if (db == nullptr) return; + + db->Close(); + //DestroyDB(dbPath, dbOptions); + delete db; + db = nullptr; + } + + bool Available() override + { + return (db != nullptr); + } + + ErrorCode Get(const std::string& key, std::string* value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + auto s = db->Get(rocksdb::ReadOptions(), key, value); + if (s == rocksdb::Status::OK()) { + return ErrorCode::Success; + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "\e[0;31mError in Get\e[0m: %s, key: %d\n", s.getState(), *((SizeType*)(key.data()))); + return ErrorCode::Fail; + } + } + + ErrorCode Get(SizeType key, std::string* value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + std::string k((char*)&key, sizeof(SizeType)); + return Get(k, value, timeout, reqs); + } + + ErrorCode MultiGet(const std::vector& keys, std::vector>& values, + const 
std::chrono::microseconds& timeout, std::vector* reqs) override { + size_t num_keys = keys.size(); + + rocksdb::Slice* slice_keys = new rocksdb::Slice[num_keys]; + rocksdb::PinnableSlice* slice_values = new rocksdb::PinnableSlice[num_keys]; + rocksdb::Status* statuses = new rocksdb::Status[num_keys]; + + // RocksDB uses Slice for keys, which doen't own the memory. + // We need to convert SizeType keys to string slices. + + std::vector str_keys(num_keys); + + for (int i = 0; i < num_keys; i++) { + str_keys[i] = std::string((char*)&keys[i], sizeof(SizeType)); + } + + for (int i = 0; i < num_keys; i++) { + slice_keys[i] = rocksdb::Slice(str_keys[i]); + } + + db->MultiGet(rocksdb::ReadOptions(), db->DefaultColumnFamily(), + num_keys, slice_keys, slice_values, statuses); + + for (int i = 0; i < num_keys; i++) { + if (statuses[i] != rocksdb::Status::OK()) { + delete[] slice_keys; + delete[] slice_values; + delete[] statuses; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "\e[0;31mError in MultiGet\e[0m: %s, key: %d\n", statuses[i].getState(), keys[i]); + return ErrorCode::Fail; + } + memcpy(values[i].GetBuffer(), slice_values[i].data(), slice_values[i].size()); + values[i].SetAvailableSize(slice_values[i].size()); + } + + delete[] slice_keys; + delete[] slice_values; + delete[] statuses; + return ErrorCode::Success; + } + + ErrorCode MultiGet(const std::vector& keys, std::vector* values, + const std::chrono::microseconds &timeout, std::vector* reqs) override { + size_t num_keys = keys.size(); + + rocksdb::Slice* slice_keys = new rocksdb::Slice[num_keys]; + rocksdb::PinnableSlice* slice_values = new rocksdb::PinnableSlice[num_keys]; + rocksdb::Status* statuses = new rocksdb::Status[num_keys]; + + for (int i = 0; i < num_keys; i++) { + slice_keys[i] = rocksdb::Slice(keys[i]); + } + + db->MultiGet(rocksdb::ReadOptions(), db->DefaultColumnFamily(), + num_keys, slice_keys, slice_values, statuses); + + for (int i = 0; i < num_keys; i++) { + if (statuses[i] != 
rocksdb::Status::OK()) { + delete[] slice_keys; + delete[] slice_values; + delete[] statuses; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "\e[0;31mError in MultiGet\e[0m: %s, key: %d\n", statuses[i].getState(), *((SizeType*)(keys[i].data()))); + return ErrorCode::Fail; + } + values->push_back(slice_values[i].ToString()); + } + + delete[] slice_keys; + delete[] slice_values; + delete[] statuses; + return ErrorCode::Success; + } + + ErrorCode MultiGet(const std::vector& keys, std::vector* values, + const std::chrono::microseconds &timeout, std::vector* reqs) override { + std::vector str_keys; + + for (const auto& key : keys) { + str_keys.emplace_back((char*)(&key), sizeof(SizeType)); + } + + return MultiGet(str_keys, values, timeout, reqs); + } + + ErrorCode Put(const std::string& key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + auto s = db->Put(rocksdb::WriteOptions(), key, value); + if (s == rocksdb::Status::OK()) { + return ErrorCode::Success; + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "\e[0;31mError in Put\e[0m: %s, key: %d\n", s.getState(), *((SizeType*)(key.data()))); + return ErrorCode::Fail; + } + } + + ErrorCode Put(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + std::string k((char*)&key, sizeof(SizeType)); + return Put(k, value, timeout, reqs); + } + + ErrorCode Merge(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + if (value.empty()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error! 
empty append posting!\n"); + } + std::string k((char*)&key, sizeof(SizeType)); + auto s = db->Merge(rocksdb::WriteOptions(), k, value); + if (s == rocksdb::Status::OK()) { + return ErrorCode::Success; + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "\e[0;31mError in Merge\e[0m: %s, key: %d\n", s.getState(), key); + return ErrorCode::Fail; + } + } + + ErrorCode StartToScan(SizeType& key, std::string* value) { + it = db->NewIterator(rocksdb::ReadOptions()); + it->SeekToFirst(); + if (!it->Valid()) return ErrorCode::Fail; + key = *((SizeType*)(it->key().ToString().c_str())); + *value = it->value().ToString(); + return ErrorCode::Success; + } + + ErrorCode NextToScan(SizeType& key, std::string* value) { + it->Next(); + if (!it->Valid()) return ErrorCode::Fail; + key = *((SizeType*)(it->key().ToString().c_str())); + *value = it->value().ToString(); + return ErrorCode::Success; + } + + ErrorCode Delete(SizeType key) override { + std::string k((char*)&key, sizeof(SizeType)); + auto s = db->Delete(rocksdb::WriteOptions(), k); + if (s == rocksdb::Status::OK()) { + return ErrorCode::Success; + } + else { + return ErrorCode::Fail; + } + } + + ErrorCode DeleteRange(SizeType start, SizeType end) override { + std::string string_start((char*)&start, sizeof(SizeType)); + rocksdb::Slice slice_start = rocksdb::Slice(string_start); + std::string string_end((char*)&end, sizeof(SizeType)); + rocksdb::Slice slice_end = rocksdb::Slice(string_end); + auto s = db->DeleteRange(rocksdb::WriteOptions(), db->DefaultColumnFamily(), slice_start, slice_end); + if (s == rocksdb::Status::OK()) { + return ErrorCode::Success; + } + else { + return ErrorCode::Fail; + } + } + + void ForceCompaction() { + /* + std::string stats; + db->GetProperty("rocksdb.stats", &stats); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "RocksDB Status:\n%s", stats.c_str()); + */ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start Compaction\n"); + auto s = db->CompactRange(rocksdb::CompactRangeOptions(), nullptr, 
nullptr); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finish Compaction\n"); + + if (s != rocksdb::Status::OK()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "\e[0;31mRocksdb Compact Error\e[0m: %s\n", s.getState()); + } + /* + db->GetProperty("rocksdb.stats", &stats); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "RocksDB Status:\n%s", stats.c_str()); + */ + } + + void GetStat() { + if (dbOptions.statistics != nullptr) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "RocksDB statistics:\n %s\n", dbOptions.statistics->ToString().c_str()); + else + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "DB statistics not set!\n"); + } + + ErrorCode Checkpoint(std::string prefix) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "RocksDB: checkpoint\n"); + rocksdb::Checkpoint* checkpoint_ptr; + rocksdb::Checkpoint::Create(db, &checkpoint_ptr); + std::string filename = prefix + FolderSep + dbPath.substr(dbPath.find_last_of(FolderSep) + 1); + checkpoint_ptr->CreateCheckpoint(filename); + return ErrorCode::Success; + } + + private: + std::string dbPath; + rocksdb::DB* db{}; + rocksdb::Options dbOptions; + rocksdb::Iterator* it; + }; +} +#endif // _SPTAG_SPANN_EXTRAROCKSDBCONTROLLER_H_ diff --git a/AnnService/inc/Core/SPANN/ExtraSPDKController.h b/AnnService/inc/Core/SPANN/ExtraSPDKController.h new file mode 100644 index 000000000..396fb24be --- /dev/null +++ b/AnnService/inc/Core/SPANN/ExtraSPDKController.h @@ -0,0 +1,487 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifndef _SPTAG_SPANN_EXTRASPDKCONTROLLER_H_ +#define _SPTAG_SPANN_EXTRASPDKCONTROLLER_H_ + +#include "inc/Helper/KeyValueIO.h" +#include "inc/Helper/DiskIO.h" +#include "inc/Helper/ConcurrentSet.h" +#include "inc/Core/Common/Dataset.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Core/SPANN/Options.h" +#include "inc/Helper/ThreadPool.h" +#include +#include +#include +#include +#include +#include + +extern "C" { +#include "spdk/env.h" +#include "spdk/event.h" +#include "spdk/log.h" +#include "spdk/thread.h" +#include "spdk/bdev.h" +} + +namespace SPTAG::SPANN +{ + typedef std::int64_t AddressType; + class SPDKIO : public Helper::KeyValueIO + { + class BlockController { + private: + static constexpr const char* kUseMemImplEnv = "SPFRESH_SPDK_USE_MEM_IMPL"; + static constexpr const char* kUseSsdImplEnv = "SPFRESH_SPDK_USE_SSD_IMPL"; + static constexpr const char* kSpdkConfEnv = "SPFRESH_SPDK_CONF"; + static constexpr const char* kSpdkBdevNameEnv = "SPFRESH_SPDK_BDEV"; + + Helper::Concurrent::ConcurrentQueue m_blockAddresses; + Helper::Concurrent::ConcurrentQueue m_blockAddresses_reserve; + + std::string m_filePath; + bool m_useSsdImpl = false; + const char* m_ssdSpdkBdevName = nullptr; + pthread_t m_ssdSpdkTid; + volatile bool m_ssdSpdkThreadStartFailed = false; + volatile bool m_ssdSpdkThreadReady = false; + volatile bool m_ssdSpdkThreadExiting = false; + struct spdk_bdev *m_ssdSpdkBdev = nullptr; + struct spdk_bdev_desc *m_ssdSpdkBdevDesc = nullptr; + struct spdk_io_channel *m_ssdSpdkBdevIoChannel = nullptr; + + Helper::Concurrent::ConcurrentQueue m_submittedSubIoRequests; + + static int m_ssdInflight; + + bool m_useMemImpl = false; + static std::unique_ptr m_memBuffer; + + float m_growthThreshold = 0.05; + AddressType m_growthBlocks = 0; + AddressType m_maxBlocks = 0; + int m_batchSize = 64; + + static int m_ioCompleteCount; + int m_preIOCompleteCount = 0; + std::chrono::time_point m_preTime = std::chrono::high_resolution_clock::now(); + + std::mutex 
m_expandLock; + std::atomic m_totalAllocatedBlocks = 0; + + bool ExpandFile(AddressType blocksToAdd); + + bool NeedsExpansion(); + + static void* InitializeSpdk(void* args); + + static void SpdkStart(void* args); + + static void SpdkIoLoop(void *arg); + + static void SpdkBdevEventCallback(enum spdk_bdev_event_type type, struct spdk_bdev *bdev, void *event_ctx); + + static void SpdkBdevIoCallback(struct spdk_bdev_io *bdev_io, bool success, void *cb_arg); + + static void SpdkStop(void* args); + public: + bool Initialize(SPANN::Options& p_opt); + + // get p_size blocks from front, and fill in p_data array + bool GetBlocks(AddressType* p_data, int p_size); + + // release p_size blocks, put them at the end of the queue + bool ReleaseBlocks(AddressType* p_data, int p_size); + + // read a posting list. p_data[0] is the total data size, + // p_data[1], p_data[2], ..., p_data[((p_data[0] + PageSize - 1) >> PageSizeEx)] are the addresses of the blocks + // concat all the block contents together into p_value string. + bool ReadBlocks(AddressType* p_data, std::string* p_value, const std::chrono::microseconds& timeout, std::vector* reqs); + + // parallel read a list of posting lists. 
+ bool ReadBlocks(std::vector& p_data, std::vector* p_values, const std::chrono::microseconds& timeout, std::vector* reqs); + bool ReadBlocks(std::vector& p_data, std::vector>& p_values, const std::chrono::microseconds& timeout, std::vector* reqs); + + // write p_value into p_size blocks start from p_data + bool WriteBlocks(AddressType* p_data, int p_size, const std::string& p_value, const std::chrono::microseconds& timeout, std::vector* reqs); + + bool IOStatistics(); + + bool ShutDown(); + + int RemainBlocks() { + return m_blockAddresses.unsafe_size(); + } + + ErrorCode Checkpoint(std::string prefix) { + std::string filename = prefix + "_blockpool"; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDK: saving block pool\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Reload reserved blocks!\n"); + AddressType currBlockAddress = 0; + while (m_blockAddresses_reserve.try_pop(currBlockAddress)) + { + m_blockAddresses.push(currBlockAddress); + } + AddressType blocks = RemainBlocks(); + AddressType totalBlocks = m_totalAllocatedBlocks.load(); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::Checkpoint - Total allocated blocks: %llu\n", static_cast(totalBlocks)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::Checkpoint - Remaining free blocks: %llu\n", blocks); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::Checkpoint - Saving to file: %s\n", filename.c_str()); + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(filename.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::Checkpoint - Failed to open file: %s\n", filename.c_str()); + return ErrorCode::FailedCreateFile; + } + IOBINARY(ptr, WriteBinary, sizeof(AddressType), (char*)&blocks); + IOBINARY(ptr, WriteBinary, sizeof(AddressType), reinterpret_cast(&totalBlocks)); + for (auto it = m_blockAddresses.unsafe_begin(); it != m_blockAddresses.unsafe_end(); it++) { + IOBINARY(ptr, 
WriteBinary, sizeof(AddressType), (char*)&(*it)); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save Finish!\n"); + return ErrorCode::Success; + } + + ErrorCode LoadBlockPool(std::string prefix, AddressType startNumBlocks, bool allowinit) { + std::string blockfile = prefix + "_blockpool"; + if (allowinit && !fileexists(blockfile.c_str())) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::LoadBlockPool: initializing fresh pool (no existing file found: %s)\n", blockfile.c_str()); + for(AddressType i = 0; i < startNumBlocks; i++) { + m_blockAddresses.push(i); + } + m_totalAllocatedBlocks.store(startNumBlocks); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::LoadBlockPool: initialized with %llu blocks (%.2f GB)\n", + static_cast(startNumBlocks), static_cast(startNumBlocks >> (30 - PageSizeEx))); + } else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::Load blockpool from %s\n", blockfile.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(blockfile.c_str(), std::ios::binary | std::ios::in)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open the blockpool file: %s\n", blockfile.c_str()); + return ErrorCode::Fail; + } + AddressType blocks = 0; + AddressType totalAllocated = 0; + + // Read block count + IOBINARY(ptr, ReadBinary, sizeof(AddressType), reinterpret_cast(&blocks)); + IOBINARY(ptr, ReadBinary, sizeof(AddressType), reinterpret_cast(&totalAllocated)); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "SPDKIO::BlockController::LoadBlockPool: reading %llu free blocks into pool (%.2f GB), total allocated: %llu blocks\n", + static_cast(blocks), + static_cast(blocks >> (30 - PageSizeEx)), + static_cast(totalAllocated)); + + AddressType currBlockAddress = 0; + for (int i = 0; i < blocks; i++) { + IOBINARY(ptr, ReadBinary, sizeof(AddressType), (char*)&(currBlockAddress)); + m_blockAddresses.push(currBlockAddress); + } + m_totalAllocatedBlocks.store(totalAllocated); + + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "SPDKIO::BlockController::LoadBlockPool: block pool initialized. Available: %llu, Total allocated: %llu\n", + static_cast(blocks), + static_cast(totalAllocated)); + } + return ErrorCode::Success; + } + }; + + public: + SPDKIO(SPANN::Options& p_opt) + //const char* filePath, SizeType blockSize, SizeType capacity, SizeType postingBlocks, SizeType bufferSize = 1024, int batchSize = 64, bool recovery = false, int compactionThreads = 1) + { + m_opt = &p_opt; + m_mappingPath = p_opt.m_indexDirectory + FolderSep + p_opt.m_ssdMappingFile; + m_blockLimit = max(p_opt.m_postingPageLimit, p_opt.m_searchPostingPageLimit) + p_opt.m_bufferLength + 1; + m_bufferLimit = 1024; + m_shutdownCalled = true; + + if (p_opt.m_recovery) { + std::string recoverpath = p_opt.m_persistentBufferPath + FolderSep + p_opt.m_ssdMappingFile; + if (fileexists(recoverpath.c_str())) { + m_pBlockMapping.Load(recoverpath, 1024 * 1024, MaxSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load block mapping successfully from %s!\n", recoverpath.c_str()); + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot recover block mapping from %s!\n", recoverpath.c_str()); + return; + } + } + else { + if (fileexists(m_mappingPath.c_str())) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load blockmapping successfully from %s\n", m_mappingPath.c_str()); + Load(m_mappingPath, 1024 * 1024, MaxSize); + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Initialize block mapping successfully!\n"); + m_pBlockMapping.Initialize(0, 1, 1024 * 1024, MaxSize); + } + } + for (int i = 0; i < m_bufferLimit; i++) { + m_buffer.push((uintptr_t)(new AddressType[m_blockLimit])); + } + m_compactionThreadPool = std::make_shared(); + m_compactionThreadPool->init(1); + + if (!m_pBlockController.Initialize(p_opt)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to Initialize SPDK!\n"); + return; + } + + m_shutdownCalled = false; + } + + ~SPDKIO() { + ShutDown(); + } + + void ShutDown() 
override { + if (m_shutdownCalled) { + return; + } + if (!m_mappingPath.empty()) Save(m_mappingPath); + for (int i = 0; i < m_pBlockMapping.R(); i++) { + if (At(i) != 0xffffffffffffffff) delete[]((AddressType*)At(i)); + } + while (!m_buffer.empty()) { + uintptr_t ptr = 0xffffffffffffffff; + if (m_buffer.try_pop(ptr)) delete[]((AddressType*)ptr); + } + m_pBlockController.ShutDown(); + m_shutdownCalled = true; + } + + bool Available() override + { + return !m_shutdownCalled; + } + + inline uintptr_t& At(SizeType key) { + return *(m_pBlockMapping[key]); + } + + ErrorCode Get(const SizeType key, std::string* value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + if (key >= m_pBlockMapping.R()) return ErrorCode::Fail; + + if (m_pBlockController.ReadBlocks((AddressType*)At(key), value, timeout, reqs)) return ErrorCode::Success; + return ErrorCode::Fail; + } + + ErrorCode MultiGet(const std::vector& keys, std::vector* values, + const std::chrono::microseconds &timeout, std::vector* reqs) override { + std::vector blocks; + for (SizeType key : keys) { + if (key < m_pBlockMapping.R()) blocks.push_back((AddressType*)At(key)); + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read key:%d total key number:%d\n", key, m_pBlockMapping.R()); + } + } + if (m_pBlockController.ReadBlocks(blocks, values, timeout, reqs)) return ErrorCode::Success; + return ErrorCode::Fail; + } + + ErrorCode MultiGet(const std::vector& keys, std::vector>& values, const std::chrono::microseconds& timeout, std::vector* reqs) override { + std::vector blocks; + for (SizeType key : keys) { + if (key < m_pBlockMapping.R()) blocks.push_back((AddressType*)At(key)); + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read key:%d total key number:%d\n", key, m_pBlockMapping.R()); + } + } + if (m_pBlockController.ReadBlocks(blocks, values, timeout, reqs)) return ErrorCode::Success; + return ErrorCode::Fail; + } + + ErrorCode Put(const SizeType key, const std::string& 
value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + int blocks = ((value.size() + PageSize - 1) >> PageSizeEx); + if (blocks >= m_blockLimit) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failt to put key:%d value:%lld since value too long!\n", key, value.size()); + return ErrorCode::Fail; + } + int delta = key + 1 - m_pBlockMapping.R(); + if (delta > 0) { + { + std::lock_guard lock(m_updateMutex); + delta = key + 1 - m_pBlockMapping.R(); + if (delta > 0) { + m_pBlockMapping.AddBatch(delta); + } + } + } + if (At(key) == 0xffffffffffffffff) { + if (m_buffer.unsafe_size() > m_bufferLimit) { + uintptr_t tmpblocks = 0xffffffffffffffff; + while (!m_buffer.try_pop(tmpblocks)); + At(key) = tmpblocks; + } + else { + At(key) = (uintptr_t)(new AddressType[m_blockLimit]); + } + memset((AddressType*)At(key), -1, sizeof(AddressType) * m_blockLimit); + } + int64_t* postingSize = (int64_t*)At(key); + if (*postingSize < 0) { + m_pBlockController.GetBlocks(postingSize + 1, blocks); + m_pBlockController.WriteBlocks(postingSize + 1, blocks, value, timeout, reqs); + *postingSize = value.size(); + } + else { + uintptr_t tmpblocks = 0xffffffffffffffff; + while (!m_buffer.try_pop(tmpblocks)); + m_pBlockController.GetBlocks((AddressType*)tmpblocks + 1, blocks); + m_pBlockController.WriteBlocks((AddressType*)tmpblocks + 1, blocks, value, timeout, reqs); + *((int64_t*)tmpblocks) = value.size(); + + m_pBlockController.ReleaseBlocks(postingSize + 1, (*postingSize + PageSize -1) >> PageSizeEx); + while (InterlockedCompareExchange(&At(key), tmpblocks, (uintptr_t)postingSize) != (uintptr_t)postingSize) { + postingSize = (int64_t*)At(key); + } + m_buffer.push((uintptr_t)postingSize); + } + return ErrorCode::Success; + } + + ErrorCode Merge(SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) override { + if (key >= m_pBlockMapping.R()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Key range error: key: %d, mapping size: 
%d\n", key, m_pBlockMapping.R()); + return ErrorCode::Fail; + } + + int64_t* postingSize = (int64_t*)At(key); + auto newSize = *postingSize + value.size(); + int newblocks = ((newSize + PageSize - 1) >> PageSizeEx); + if (newblocks >= m_blockLimit) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Merge] Key %d failed: new size %lld bytes requires %d blocks (limit: %d)\n", + key, static_cast(newSize), newblocks, m_blockLimit); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "[Merge] Original size: %lld bytes, Merge size: %lld bytes\n", + static_cast(*postingSize), static_cast(value.size())); + return ErrorCode::Fail; + } + + auto sizeInPage = (*postingSize) % PageSize; + int oldblocks = (*postingSize >> PageSizeEx); + int allocblocks = newblocks - oldblocks; + if (sizeInPage != 0) { + std::string newValue; + AddressType readreq[] = { sizeInPage, *(postingSize + 1 + oldblocks) }; + m_pBlockController.ReadBlocks(readreq, &newValue, timeout, reqs); + newValue += value; + + uintptr_t tmpblocks = 0xffffffffffffffff; + while (!m_buffer.try_pop(tmpblocks)); + memcpy((AddressType*)tmpblocks, postingSize, sizeof(AddressType) * (oldblocks + 1)); + m_pBlockController.GetBlocks((AddressType*)tmpblocks + 1 + oldblocks, allocblocks); + m_pBlockController.WriteBlocks((AddressType*)tmpblocks + 1 + oldblocks, allocblocks, newValue, timeout, reqs); + *((int64_t*)tmpblocks) = newSize; + + m_pBlockController.ReleaseBlocks(postingSize + 1 + oldblocks, 1); + while (InterlockedCompareExchange(&At(key), tmpblocks, (uintptr_t)postingSize) != (uintptr_t)postingSize) { + postingSize = (int64_t*)At(key); + } + m_buffer.push((uintptr_t)postingSize); + } + else { + m_pBlockController.GetBlocks(postingSize + 1 + oldblocks, allocblocks); + m_pBlockController.WriteBlocks(postingSize + 1 + oldblocks, allocblocks, value, timeout, reqs); + *postingSize = newSize; + } + return ErrorCode::Success; + } + + ErrorCode Delete(SizeType key) override { + if (key >= m_pBlockMapping.R()) return ErrorCode::Fail; 
+ int64_t* postingSize = (int64_t*)At(key); + if (*postingSize < 0) return ErrorCode::Fail; + + int blocks = ((*postingSize + PageSize - 1) >> PageSizeEx); + m_pBlockController.ReleaseBlocks(postingSize + 1, blocks); + m_buffer.push((uintptr_t)postingSize); + At(key) = 0xffffffffffffffff; + return ErrorCode::Success; + } + + void ForceCompaction() { + Save(m_mappingPath); + } + + void GetStat() { + int remainBlocks = m_pBlockController.RemainBlocks(); + int remainGB = (long long)remainBlocks << PageSizeEx >> 30; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Remain %d blocks, totally %d GB\n", remainBlocks, remainGB); + m_pBlockController.IOStatistics(); + } + + ErrorCode Load(std::string path, SizeType blockSize, SizeType capacity) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load mapping From %s\n", path.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(path.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; + + SizeType CR, mycols; + IOBINARY(ptr, ReadBinary, sizeof(SizeType), (char*)&CR); + IOBINARY(ptr, ReadBinary, sizeof(SizeType), (char*)&mycols); + if (mycols > m_blockLimit) m_blockLimit = mycols; + + m_pBlockMapping.Initialize(CR, 1, blockSize, capacity); + for (int i = 0; i < CR; i++) { + At(i) = (uintptr_t)(new AddressType[m_blockLimit]); + IOBINARY(ptr, ReadBinary, sizeof(AddressType) * mycols, (char*)At(i)); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load mapping (%d,%d) Finish!\n", CR, mycols); + return ErrorCode::Success; + } + + ErrorCode Save(std::string path) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save mapping To %s\n", path.c_str()); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(path.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; + + SizeType CR = m_pBlockMapping.R(); + IOBINARY(ptr, WriteBinary, sizeof(SizeType), (char*)&CR); + IOBINARY(ptr, WriteBinary, sizeof(SizeType), (char*)&m_blockLimit); + std::vector empty(m_blockLimit, 
0xffffffffffffffff); + for (int i = 0; i < CR; i++) { + if (At(i) == 0xffffffffffffffff) { + IOBINARY(ptr, WriteBinary, sizeof(AddressType) * m_blockLimit, (char*)(empty.data())); + } + else { + int64_t* postingSize = (int64_t*)At(i); + IOBINARY(ptr, WriteBinary, sizeof(AddressType) * m_blockLimit, (char*)postingSize); + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save mapping (%d,%d) Finish!\n", CR, m_blockLimit); + return ErrorCode::Success; + } + + ErrorCode Checkpoint(std::string prefix) override { + std::string filename = prefix + FolderSep + m_mappingPath.substr(m_mappingPath.find_last_of(FolderSep) + 1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDK: saving block mapping to %s\n", filename.c_str()); + Save(filename); + return m_pBlockController.Checkpoint(filename + "_postings"); + } + + private: + SPANN::Options* m_opt; + std::string m_mappingPath; + SizeType m_blockLimit; + COMMON::Dataset m_pBlockMapping; + SizeType m_bufferLimit; + Helper::Concurrent::ConcurrentQueue m_buffer; + + std::shared_ptr m_compactionThreadPool; + BlockController m_pBlockController; + + bool m_shutdownCalled; + std::mutex m_updateMutex; + }; +} +#endif // _SPTAG_SPANN_EXTRASPDKCONTROLLER_H_ diff --git a/AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h b/AnnService/inc/Core/SPANN/ExtraStaticSearcher.h similarity index 51% rename from AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h rename to AnnService/inc/Core/SPANN/ExtraStaticSearcher.h index bea5552db..2149e0add 100644 --- a/AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h +++ b/AnnService/inc/Core/SPANN/ExtraStaticSearcher.h @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#ifndef _SPTAG_SPANN_EXTRASEARCHER_H_ -#define _SPTAG_SPANN_EXTRASEARCHER_H_ +#ifndef _SPTAG_SPANN_EXTRASTATICSEARCHER_H_ +#define _SPTAG_SPANN_EXTRASTATICSEARCHER_H_ #include "inc/Helper/VectorSetReader.h" #include "inc/Helper/AsyncFileReader.h" @@ -36,11 +36,11 @@ namespace SPTAG { auto f_out = f_createIO(); if (f_out == nullptr || !f_out->Initialize(m_tmpfile.c_str(), std::ios::out | std::ios::binary | (fileexists(m_tmpfile.c_str()) ? std::ios::in : 0))) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to save selection for batching!\n", m_tmpfile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to save selection for batching!\n", m_tmpfile.c_str()); return ErrorCode::FailedOpenFile; } if (f_out->WriteBinary(sizeof(Edge) * (m_end - m_start), (const char*)m_selections.data(), sizeof(Edge) * m_start) != sizeof(Edge) * (m_end - m_start)) { - LOG(Helper::LogLevel::LL_Error, "Cannot write to %s!\n", m_tmpfile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot write to %s!\n", m_tmpfile.c_str()); return ErrorCode::DiskIOFail; } std::vector batch_selection; @@ -53,14 +53,14 @@ namespace SPTAG { auto f_in = f_createIO(); if (f_in == nullptr || !f_in->Initialize(m_tmpfile.c_str(), std::ios::in | std::ios::binary)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to load selection batch!\n", m_tmpfile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to load selection batch!\n", m_tmpfile.c_str()); return ErrorCode::FailedOpenFile; } size_t readsize = end - start; m_selections.resize(readsize); if (f_in->ReadBinary(readsize * sizeof(Edge), (char*)m_selections.data(), start * sizeof(Edge)) != readsize * sizeof(Edge)) { - LOG(Helper::LogLevel::LL_Error, "Cannot read from %s! start:%zu size:%zu\n", m_tmpfile.c_str(), start, readsize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot read from %s! 
start:%zu size:%zu\n", m_tmpfile.c_str(), start, readsize); return ErrorCode::DiskIOFail; } m_start = start; @@ -77,7 +77,7 @@ namespace SPTAG Edge& operator[](size_t offset) { if (offset < m_start || offset >= m_end) { - LOG(Helper::LogLevel::LL_Error, "Error read offset in selections:%zu\n", offset); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error read offset in selections:%zu\n", offset); } return m_selections[offset - m_start]; } @@ -91,41 +91,67 @@ namespace SPTAG sizePostingListFullData = m_pCompressor->Decompress(buffer + listInfo->pageOffset, listInfo->listTotalBytes, p_postingListFullData, listInfo->listEleCount * m_vectorInfoSize, m_enableDictTraining);\ }\ catch (std::runtime_error& err) {\ - LOG(Helper::LogLevel::LL_Error, "Decompress postingList %d failed! %s, \n", listInfo->postingID, err.what());\ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Decompress postingList %d failed! %s, \n", listInfo - m_listInfos.data(), err.what());\ return;\ }\ if (sizePostingListFullData != listInfo->listEleCount * m_vectorInfoSize) {\ - LOG(Helper::LogLevel::LL_Error, "PostingList %d decompressed size not match! %zu, %d, \n", listInfo->postingID, sizePostingListFullData, listInfo->listEleCount * m_vectorInfoSize);\ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "PostingList %d decompressed size not match! 
%zu, %d, \n", listInfo - m_listInfos.data(), sizePostingListFullData, listInfo->listEleCount * m_vectorInfoSize);\ return;\ }\ }\ }\ +#define DecompressPostingIterative(){\ + p_postingListFullData = (char*)p_exWorkSpace->m_decompressBuffer.GetBuffer(); \ + if (listInfo->listEleCount != 0) { \ + std::size_t sizePostingListFullData;\ + try {\ + sizePostingListFullData = m_pCompressor->Decompress(buffer + listInfo->pageOffset, listInfo->listTotalBytes, p_postingListFullData, listInfo->listEleCount * m_vectorInfoSize, m_enableDictTraining);\ + if (sizePostingListFullData != listInfo->listEleCount * m_vectorInfoSize) {\ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "PostingList %d decompressed size not match! %zu, %d, \n", listInfo - m_listInfos.data(), sizePostingListFullData, listInfo->listEleCount * m_vectorInfoSize);\ + }\ + }\ + catch (std::runtime_error& err) {\ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Decompress postingList %d failed! %s, \n", listInfo - m_listInfos.data(), err.what());\ + }\ + }\ +}\ + #define ProcessPosting() \ for (int i = 0; i < listInfo->listEleCount; i++) { \ uint64_t offsetVectorID, offsetVector;\ - if (m_enablePostingListRearrange) { \ - offsetVectorID = (m_vectorInfoSize - sizeof(int)) * listInfo->listEleCount + sizeof(int) * i; \ - offsetVector = (m_vectorInfoSize - sizeof(int)) * i; \ - } else {\ - offsetVectorID = m_vectorInfoSize * i; \ - offsetVector = offsetVectorID + sizeof(int);\ - }\ + (this->*m_parsePosting)(offsetVectorID, offsetVector, i, listInfo->listEleCount);\ + int vectorID = *(reinterpret_cast(p_postingListFullData + offsetVectorID));\ + if (p_exWorkSpace->m_deduper.CheckAndSet(vectorID)) { listElements--; continue; } \ + (this->*m_parseEncoding)(p_index, listInfo, (ValueType*)(p_postingListFullData + offsetVector));\ + auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), p_postingListFullData + offsetVector); \ + queryResults.AddPoint(vectorID, distance2leaf); \ + } \ + +#define 
ProcessPostingOffset() \ + while (p_exWorkSpace->m_offset < listInfo->listEleCount) { \ + uint64_t offsetVectorID, offsetVector;\ + (this->*m_parsePosting)(offsetVectorID, offsetVector, p_exWorkSpace->m_offset, listInfo->listEleCount);\ + p_exWorkSpace->m_offset++;\ int vectorID = *(reinterpret_cast(p_postingListFullData + offsetVectorID));\ if (p_exWorkSpace->m_deduper.CheckAndSet(vectorID)) continue; \ - if (m_enableDeltaEncoding){\ - ValueType* headVector = (ValueType*)p_index->GetSample(listInfo->postingID);\ - COMMON::SIMDUtils::ComputeSum(reinterpret_cast(p_postingListFullData + offsetVector), headVector, m_iDataDimension);\ - }\ + if (p_exWorkSpace->m_filterFunc != nullptr && !p_exWorkSpace->m_filterFunc(p_spann->GetMetadata(vectorID))) continue; \ + (this->*m_parseEncoding)(p_index, listInfo, (ValueType*)(p_postingListFullData + offsetVector));\ auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), p_postingListFullData + offsetVector); \ queryResults.AddPoint(vectorID, distance2leaf); \ + foundResult = true;\ + break;\ + } \ + if (p_exWorkSpace->m_offset == listInfo->listEleCount) { \ + p_exWorkSpace->m_pi++; \ + p_exWorkSpace->m_offset = 0; \ } \ template - class ExtraFullGraphSearcher : public IExtraSearcher + class ExtraStaticSearcher : public IExtraSearcher { public: - ExtraFullGraphSearcher() + ExtraStaticSearcher() { m_enableDeltaEncoding = false; m_enablePostingListRearrange = false; @@ -133,62 +159,105 @@ namespace SPTAG m_enableDictTraining = true; } - virtual ~ExtraFullGraphSearcher() + virtual ~ExtraStaticSearcher() + { + } + + virtual bool Available() override + { + return m_available; + } + + void InitWorkSpace(ExtraWorkSpace* p_exWorkSpace, bool clear = false) override { + if (clear) { + p_exWorkSpace->Clear(m_opt->m_searchInternalResultNum, max(m_opt->m_postingPageLimit, m_opt->m_searchPostingPageLimit + 1) << PageSizeEx, false, m_opt->m_enableDataCompression); + } + else { + 
p_exWorkSpace->Initialize(m_opt->m_maxCheck, m_opt->m_hashExp, m_opt->m_searchInternalResultNum, max(m_opt->m_postingPageLimit, m_opt->m_searchPostingPageLimit + 1) << PageSizeEx, false, m_opt->m_enableDataCompression); + int wid = 0; + if (m_freeWorkSpaceIds == nullptr || !m_freeWorkSpaceIds->try_pop(wid)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "FreeWorkSpaceIds is not initalized or the workspace number is not enough! Please increase iothread number.\n"); + wid = m_workspaceCount.fetch_add(1); + } + for (auto & req : p_exWorkSpace->m_diskRequests) + { + req.m_status = wid; + } + p_exWorkSpace->m_callback = [m_freeWorkSpaceIds = m_freeWorkSpaceIds, wid] () { + if (m_freeWorkSpaceIds) m_freeWorkSpaceIds->push(wid); + }; + } } - virtual bool LoadIndex(Options& p_opt) { + virtual bool LoadIndex(Options& p_opt, COMMON::VersionLabel& p_versionMap, COMMON::Dataset& p_vectorTranslateMap, std::shared_ptr m_index) { m_extraFullGraphFile = p_opt.m_indexDirectory + FolderSep + p_opt.m_ssdIndex; std::string curFile = m_extraFullGraphFile; + p_opt.m_searchPostingPageLimit = max(p_opt.m_searchPostingPageLimit, static_cast((p_opt.m_postingVectorLimit * (p_opt.m_dim * sizeof(ValueType) + sizeof(int)) + PageSize - 1) / PageSize)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load index with posting page limit:%d\n", p_opt.m_searchPostingPageLimit); do { auto curIndexFile = f_createAsyncIO(); - if (curIndexFile == nullptr || !curIndexFile->Initialize(curFile.c_str(), std::ios::binary | std::ios::in, + if (curIndexFile == nullptr || !curIndexFile->Initialize(curFile.c_str(), #ifndef _MSC_VER #ifdef BATCH_READ - p_opt.m_searchInternalResultNum, 2, 2, p_opt.m_iSSDNumberOfThreads + O_RDONLY | O_DIRECT, p_opt.m_searchInternalResultNum, 2, 2, p_opt.m_iSSDNumberOfThreads #else - p_opt.m_searchInternalResultNum * p_opt.m_iSSDNumberOfThreads / p_opt.m_ioThreads + 1, 2, 2, p_opt.m_ioThreads + O_RDONLY | O_DIRECT, p_opt.m_searchInternalResultNum * p_opt.m_iSSDNumberOfThreads / 
p_opt.m_ioThreads + 1, 2, 2, p_opt.m_ioThreads #endif -/* -#ifdef BATCH_READ - max(p_opt.m_searchInternalResultNum*m_vectorInfoSize, 1 << 12), 2, 2, p_opt.m_iSSDNumberOfThreads -#else - p_opt.m_searchInternalResultNum* p_opt.m_iSSDNumberOfThreads / p_opt.m_ioThreads + 1, 2, 2, p_opt.m_ioThreads -#endif -*/ #else - (p_opt.m_searchPostingPageLimit + 1) * PageSize, 2, 2, p_opt.m_ioThreads + GENERIC_READ, (p_opt.m_searchPostingPageLimit + 1) * PageSize, 2, 2, (std::uint16_t)p_opt.m_ioThreads #endif )) { - LOG(Helper::LogLevel::LL_Error, "Cannot open file:%s!\n", curFile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open file:%s!\n", curFile.c_str()); return false; } m_indexFiles.emplace_back(curIndexFile); - m_listInfos.emplace_back(0); - m_enableDeltaEncoding = p_opt.m_enableDeltaEncoding; - m_enablePostingListRearrange = p_opt.m_enablePostingListRearrange; - m_enableDataCompression = p_opt.m_enableDataCompression; - m_enableDictTraining = p_opt.m_enableDictTraining; try { - m_totalListCount += LoadingHeadInfo(m_totalListCount, curFile, p_opt.m_searchPostingPageLimit, m_listInfos.back()); + m_totalListCount += LoadingHeadInfo(curFile, p_opt.m_searchPostingPageLimit, m_listInfos); } catch (std::exception& e) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error occurs when loading HeadInfo:%s\n", e.what()); return false; } curFile = m_extraFullGraphFile + "_" + std::to_string(m_indexFiles.size()); } while (fileexists(curFile.c_str())); + m_oneContext = (m_indexFiles.size() == 1); + + m_opt = &p_opt; + m_enableDeltaEncoding = p_opt.m_enableDeltaEncoding; + m_enablePostingListRearrange = p_opt.m_enablePostingListRearrange; + m_enableDataCompression = p_opt.m_enableDataCompression; + m_enableDictTraining = p_opt.m_enableDictTraining; + + if (m_enablePostingListRearrange) m_parsePosting = &ExtraStaticSearcher::ParsePostingListRearrange; + else m_parsePosting = &ExtraStaticSearcher::ParsePostingList; + if (m_enableDeltaEncoding) m_parseEncoding = 
&ExtraStaticSearcher::ParseDeltaEncoding; + else m_parseEncoding = &ExtraStaticSearcher::ParseEncoding; + m_listPerFile = static_cast((m_totalListCount + m_indexFiles.size() - 1) / m_indexFiles.size()); + p_versionMap.Load(p_opt.m_indexDirectory + FolderSep + p_opt.m_deleteIDFile, p_opt.m_datasetRowsInBlock, p_opt.m_datasetCapacity); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Current vector num: %d.\n", p_versionMap.Count()); + #ifndef _MSC_VER Helper::AIOTimeout.tv_nsec = p_opt.m_iotimeout * 1000; #endif + + m_freeWorkSpaceIds.reset(new Helper::Concurrent::ConcurrentQueue()); + int maxIOThreads = max(p_opt.m_searchThreadNum, p_opt.m_iSSDNumberOfThreads); + for (int i = 0; i < maxIOThreads; i++) { + m_freeWorkSpaceIds->push(i); + } + m_workspaceCount = maxIOThreads; + m_available = true; return true; } - virtual void SearchIndex(ExtraWorkSpace* p_exWorkSpace, + virtual ErrorCode SearchIndex(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_queryResults, std::shared_ptr p_index, SearchStats* p_stats, @@ -206,19 +275,11 @@ namespace SPTAG int unprocessed = 0; #endif - bool oneContext = (m_indexFiles.size() == 1); for (uint32_t pi = 0; pi < postingListCount; ++pi) { auto curPostingID = p_exWorkSpace->m_postingIDs[pi]; - int fileid = 0; - ListInfo* listInfo; - if (oneContext) { - listInfo = &(m_listInfos[0][curPostingID]); - } - else { - fileid = curPostingID / m_listPerFile; - listInfo = &(m_listInfos[fileid][curPostingID % m_listPerFile]); - } + ListInfo* listInfo = &(m_listInfos[curPostingID]); + int fileid = m_oneContext? 
0: curPostingID / m_listPerFile; #ifndef BATCH_READ Helper::DiskIO* indexFile = m_indexFiles[fileid].get(); @@ -229,22 +290,20 @@ namespace SPTAG listElements += listInfo->listEleCount; size_t totalBytes = (static_cast(listInfo->listPageCount) << PageSizeEx); - char* buffer = (char*)((p_exWorkSpace->m_pageBuffers[pi]).GetBuffer()); #ifdef ASYNC_READ auto& request = p_exWorkSpace->m_diskRequests[pi]; request.m_offset = listInfo->listOffset; request.m_readSize = totalBytes; - request.m_buffer = buffer; - request.m_status = (fileid << 16) | p_exWorkSpace->m_spaceID; + request.m_status = (fileid << 16) | (request.m_status & 0xffff); request.m_payload = (void*)listInfo; request.m_success = false; #ifdef BATCH_READ // async batch read - request.m_callback = [&p_exWorkSpace, &queryResults, &p_index, m_vectorInfoSize= m_vectorInfoSize, m_enableDeltaEncoding = m_enableDeltaEncoding, m_enablePostingListRearrange = m_enablePostingListRearrange, m_enableDictTraining = m_enableDictTraining, m_enableDataCompression = m_enableDataCompression, &m_pCompressor = m_pCompressor, m_iDataDimension = m_iDataDimension](Helper::AsyncReadRequest *request) + request.m_callback = [&p_exWorkSpace, &queryResults, &p_index, &request, &listElements, this](bool success) { - char* buffer = request->m_buffer; - ListInfo* listInfo = (ListInfo*)(request->m_payload); + char* buffer = request.m_buffer; + ListInfo* listInfo = (ListInfo*)(request.m_payload); // decompress posting list char* p_postingListFullData = buffer + listInfo->pageOffset; @@ -256,22 +315,23 @@ namespace SPTAG ProcessPosting(); }; #else // async read - request.m_callback = [&p_exWorkSpace](Helper::AsyncReadRequest* request) + request.m_callback = [&p_exWorkSpace, &request](bool success) { - p_exWorkSpace->m_processIocp.push(request); + p_exWorkSpace->m_processIocp.push(&request); }; ++unprocessed; if (!(indexFile->ReadFileAsync(request))) { - LOG(Helper::LogLevel::LL_Error, "Failed to read file!\n"); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file!\n"); unprocessed--; } #endif #else // sync read + char* buffer = (char*)((p_exWorkSpace->m_pageBuffers[pi]).GetBuffer()); auto numRead = indexFile->ReadBinary(totalBytes, buffer, listInfo->listOffset); if (numRead != totalBytes) { - LOG(Helper::LogLevel::LL_Error, "File %s read bytes, expected: %zu, acutal: %llu.\n", m_extraFullGraphFile.c_str(), totalBytes, numRead); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File %s read bytes, expected: %zu, acutal: %llu.\n", m_extraFullGraphFile.c_str(), totalBytes, numRead); throw std::runtime_error("File read mismatch"); } // decompress posting list @@ -313,7 +373,7 @@ namespace SPTAG { auto curPostingID = p_exWorkSpace->m_postingIDs[pi]; - ListInfo* listInfo = &(m_listInfos[curPostingID / m_listPerFile][curPostingID % m_listPerFile]); + ListInfo* listInfo = &(m_listInfos[curPostingID]); char* buffer = (char*)((p_exWorkSpace->m_pageBuffers[pi]).GetBuffer()); char* p_postingListFullData = buffer + listInfo->pageOffset; @@ -326,7 +386,7 @@ namespace SPTAG m_pCompressor->Decompress(buffer + listInfo->pageOffset, listInfo->listTotalBytes, p_postingListFullData, listInfo->listEleCount * m_vectorInfoSize, m_enableDictTraining); } catch (std::runtime_error& err) { - LOG(Helper::LogLevel::LL_Error, "Decompress postingList %d failed! %s, \n", curPostingID, err.what()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Decompress postingList %d failed! 
%s, \n", curPostingID, err.what()); continue; } } @@ -346,6 +406,180 @@ namespace SPTAG p_stats->m_diskIOCount = diskIO; p_stats->m_diskAccessCount = diskRead; } + queryResults.SetScanned(listElements); + return ErrorCode::Success; + } + + virtual ErrorCode SearchIndexWithoutParsing(ExtraWorkSpace *p_exWorkSpace) override + { + const uint32_t postingListCount = static_cast(p_exWorkSpace->m_postingIDs.size()); + + int diskRead = 0; + int diskIO = 0; + int listElements = 0; + +#if defined(ASYNC_READ) && !defined(BATCH_READ) + int unprocessed = 0; +#endif + + for (uint32_t pi = 0; pi < postingListCount; ++pi) + { + auto curPostingID = p_exWorkSpace->m_postingIDs[pi]; + ListInfo* listInfo = &(m_listInfos[curPostingID]); + int fileid = m_oneContext ? 0 : curPostingID / m_listPerFile; + +#ifndef BATCH_READ + Helper::DiskIO* indexFile = m_indexFiles[fileid].get(); +#endif + + diskRead += listInfo->listPageCount; + diskIO += 1; + listElements += listInfo->listEleCount; + + size_t totalBytes = (static_cast(listInfo->listPageCount) << PageSizeEx); + +#ifdef ASYNC_READ + auto& request = p_exWorkSpace->m_diskRequests[pi]; + request.m_offset = listInfo->listOffset; + request.m_readSize = totalBytes; + request.m_status = (fileid << 16) | (request.m_status & 0xffff); + request.m_payload = (void*)listInfo; + request.m_success = false; + +#ifdef BATCH_READ // async batch read + request.m_callback = [this](bool success) + { + //char* buffer = request.m_buffer; + //ListInfo* listInfo = (ListInfo*)(request.m_payload); + + // decompress posting list + /* + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression) + { + DecompressPosting(); + } + + ProcessPosting(); + */ + }; +#else // async read + request.m_callback = [&p_exWorkSpace, &request](bool success) + { + p_exWorkSpace->m_processIocp.push(&request); + }; + + ++unprocessed; + if (!(indexFile->ReadFileAsync(request))) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file!\n"); + 
unprocessed--; + } +#endif +#else // sync read + char* buffer = (char*)((p_exWorkSpace->m_pageBuffers[pi]).GetBuffer()); + auto numRead = indexFile->ReadBinary(totalBytes, buffer, listInfo->listOffset); + if (numRead != totalBytes) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File %s read bytes, expected: %zu, acutal: %llu.\n", m_extraFullGraphFile.c_str(), totalBytes, numRead); + return ErrorCode::DiskIOFail; + } + // decompress posting list + /* + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression) + { + DecompressPosting(); + } + + ProcessPosting(); + */ +#endif + } + +#ifdef ASYNC_READ +#ifdef BATCH_READ + int retry = 0; + bool success = false; + while (retry < 2 && !success) + { + success = BatchReadFileAsync(m_indexFiles, (p_exWorkSpace->m_diskRequests).data(), postingListCount); + retry++; + } +#else + while (unprocessed > 0) + { + Helper::AsyncReadRequest* request; + if (!(p_exWorkSpace->m_processIocp.pop(request))) break; + + --unprocessed; + char* buffer = request->m_buffer; + ListInfo* listInfo = static_cast(request->m_payload); + // decompress posting list + /* + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression) + { + DecompressPosting(); + } + + ProcessPosting(); + */ + } +#endif +#endif + return (success)? 
ErrorCode::Success: ErrorCode::DiskIOFail; + } + + virtual ErrorCode SearchNextInPosting(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_headResults, + QueryResult& p_queryResults, + std::shared_ptr& p_index, const VectorIndex* p_spann) override + { + COMMON::QueryResultSet& headResults = *((COMMON::QueryResultSet*) & p_headResults); + COMMON::QueryResultSet& queryResults = *((COMMON::QueryResultSet*) & p_queryResults); + bool foundResult = false; + BasicResult* head = headResults.GetResult(p_exWorkSpace->m_ri); + while (!foundResult && p_exWorkSpace->m_pi < p_exWorkSpace->m_postingIDs.size()) { + if (head && head->VID != -1 && p_exWorkSpace->m_ri <= p_exWorkSpace->m_pi && + (p_exWorkSpace->m_filterFunc == nullptr || p_exWorkSpace->m_filterFunc(p_spann->GetMetadata(head->VID)))) { + queryResults.AddPoint(head->VID, head->Dist); + head = headResults.GetResult(++p_exWorkSpace->m_ri); + foundResult = true; + continue; + } + char* buffer = (char*)((p_exWorkSpace->m_pageBuffers[p_exWorkSpace->m_pi]).GetBuffer()); + ListInfo* listInfo = static_cast(p_exWorkSpace->m_diskRequests[p_exWorkSpace->m_pi].m_payload); + // decompress posting list + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression && p_exWorkSpace->m_offset == 0) + { + DecompressPostingIterative(); + } + ProcessPostingOffset(); + } + if (!foundResult && head && head->VID != -1 && + (p_exWorkSpace->m_filterFunc == nullptr || p_exWorkSpace->m_filterFunc(p_spann->GetMetadata(head->VID)))) { + queryResults.AddPoint(head->VID, head->Dist); + head = headResults.GetResult(++p_exWorkSpace->m_ri); + foundResult = true; + } + if (foundResult) p_queryResults.SetScanned(p_queryResults.GetScanned() + 1); + return (foundResult)? 
ErrorCode::Success : ErrorCode::VectorNotFound; + } + + virtual ErrorCode SearchIterativeNext(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_headResults, + QueryResult& p_query, + std::shared_ptr p_index, const VectorIndex* p_spann) override + { + if (p_exWorkSpace->m_loadPosting) { + ErrorCode ret = SearchIndexWithoutParsing(p_exWorkSpace); + if (ret != ErrorCode::Success) return ret; + p_exWorkSpace->m_ri = 0; + p_exWorkSpace->m_pi = 0; + p_exWorkSpace->m_offset = 0; + p_exWorkSpace->m_loadPosting = false; + } + + return SearchNextInPosting(p_exWorkSpace, p_headResults, p_query, p_index, p_spann); } std::string GetPostingListFullData( @@ -353,8 +587,8 @@ namespace SPTAG size_t p_postingListSize, Selection &p_selections, std::shared_ptr p_fullVectors, - bool m_enableDeltaEncoding = false, - bool m_enablePostingListRearrange = false, + bool p_enableDeltaEncoding = false, + bool p_enablePostingListRearrange = false, const ValueType *headVector = nullptr) { std::string postingListFullData(""); @@ -366,7 +600,7 @@ namespace SPTAG { if (p_selections[selectIdx].node != postingListId) { - LOG(Helper::LogLevel::LL_Error, "Selection ID NOT MATCH! node:%d offset:%zu\n", postingListId, selectIdx); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Selection ID NOT MATCH! 
node:%d offset:%zu\n", postingListId, selectIdx); throw std::runtime_error("Selection ID mismatch"); } std::string vectorID(""); @@ -376,7 +610,7 @@ namespace SPTAG vectorID.append(reinterpret_cast(&vid), sizeof(int)); ValueType *p_vector = reinterpret_cast(p_fullVectors->GetVector(vid)); - if (m_enableDeltaEncoding) + if (p_enableDeltaEncoding) { DimensionType n = p_fullVectors->Dimension(); std::vector p_vector_delta(n); @@ -391,7 +625,7 @@ namespace SPTAG vector.append(reinterpret_cast(p_vector), p_fullVectors->PerVectorDataSize()); } - if (m_enablePostingListRearrange) + if (p_enablePostingListRearrange) { vectorIDs += vectorID; vectors += vector; @@ -401,43 +635,35 @@ namespace SPTAG postingListFullData += (vectorID + vector); } } - if (m_enablePostingListRearrange) + if (p_enablePostingListRearrange) { return vectors + vectorIDs; } return postingListFullData; } - bool BuildIndex(std::shared_ptr& p_reader, std::shared_ptr p_headIndex, Options& p_opt) { + bool BuildIndex(std::shared_ptr& p_reader, std::shared_ptr p_headIndex, Options& p_opt, COMMON::VersionLabel& p_versionMap, COMMON::Dataset& p_vectorTranslateMap, SizeType upperBound = -1) { std::string outputFile = p_opt.m_indexDirectory + FolderSep + p_opt.m_ssdIndex; if (outputFile.empty()) { - LOG(Helper::LogLevel::LL_Error, "Output file can't be empty!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Output file can't be empty!\n"); return false; } + m_opt = &p_opt; int numThreads = p_opt.m_iSSDNumberOfThreads; int candidateNum = p_opt.m_internalResultNum; - std::unordered_set headVectorIDS; + std::unordered_map headVectorIDS; if (p_opt.m_headIDFile.empty()) { - LOG(Helper::LogLevel::LL_Error, "Not found VectorIDTranslate!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Not found VectorIDTranslate!\n"); return false; } + for (int i = 0; i < p_vectorTranslateMap.R(); i++) { - auto ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize((p_opt.m_indexDirectory + FolderSep + 
p_opt.m_headIDFile).c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "failed open VectorIDTranslate: %s\n", p_opt.m_headIDFile.c_str()); - return false; - } - - std::uint64_t vid; - while (ptr->ReadBinary(sizeof(vid), reinterpret_cast(&vid)) == sizeof(vid)) - { - headVectorIDS.insert(static_cast(vid)); - } - LOG(Helper::LogLevel::LL_Info, "Loaded %u Vector IDs\n", static_cast(headVectorIDS.size())); + headVectorIDS[static_cast(*(p_vectorTranslateMap[i]))] = i; } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loaded %u Vector IDs\n", static_cast(headVectorIDS.size())); SizeType fullCount = 0; size_t vectorInfoSize = 0; @@ -446,9 +672,12 @@ namespace SPTAG fullCount = fullVectors->Count(); vectorInfoSize = fullVectors->PerVectorDataSize() + sizeof(int); } + if (upperBound > 0) fullCount = upperBound; + + p_versionMap.Initialize(fullCount, p_headIndex->m_iDataBlockSize, p_headIndex->m_iDataCapacity); Selection selections(static_cast(fullCount) * p_opt.m_replicaCount, p_opt.m_tmpdir); - LOG(Helper::LogLevel::LL_Info, "Full vector count:%d Edge bytes:%llu selection size:%zu, capacity size:%zu\n", fullCount, sizeof(Edge), selections.m_selections.size(), selections.m_selections.capacity()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Full vector count:%d Edge bytes:%llu selection size:%zu, capacity size:%zu\n", fullCount, sizeof(Edge), selections.m_selections.size(), selections.m_selections.capacity()); std::vector replicaCount(fullCount); std::vector postingListSize(p_headIndex->GetNumSamples()); for (auto& pls : postingListSize) pls = 0; @@ -464,7 +693,7 @@ namespace SPTAG } } { - LOG(Helper::LogLevel::LL_Info, "Preparation done, start candidate searching.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Preparation done, start candidate searching.\n"); SizeType sampleSize = p_opt.m_samples; std::vector samples(sampleSize, 0); for (int i = 0; i < p_opt.m_batches; i++) { @@ -479,12 +708,14 @@ namespace SPTAG return false; } emptySet.clear(); - for 
(auto vid : headVectorIDS) { - if (vid >= start && vid < end) emptySet.insert(vid - start); + for (auto& pair : headVectorIDS) { + if (pair.first >= start && pair.first < end) emptySet.insert(pair.first - start); } } else { - emptySet = headVectorIDS; + for (auto& pair : headVectorIDS) { + emptySet.insert(pair.first); + } } int sampleNum = 0; @@ -494,16 +725,15 @@ namespace SPTAG } float acc = 0; -#pragma omp parallel for schedule(dynamic) for (int j = 0; j < sampleNum; j++) { COMMON::Utils::atomic_float_add(&acc, COMMON::TruthSet::CalculateRecall(p_headIndex.get(), fullVectors->GetVector(samples[j]), candidateNum)); } acc = acc / sampleNum; - LOG(Helper::LogLevel::LL_Info, "Batch %d vector(%d,%d) loaded with %d vectors (%zu) HeadIndex acc @%d:%f.\n", i, start, end, fullVectors->Count(), selections.m_selections.size(), candidateNum, acc); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Batch %d vector(%d,%d) loaded with %d vectors (%zu) HeadIndex acc @%d:%f.\n", i, start, end, fullVectors->Count(), selections.m_selections.size(), candidateNum, acc); p_headIndex->ApproximateRNG(fullVectors, emptySet, candidateNum, selections.m_selections.data(), p_opt.m_replicaCount, numThreads, p_opt.m_gpuSSDNumTrees, p_opt.m_gpuSSDLeafSize, p_opt.m_rngFactor, p_opt.m_numGPUs); - LOG(Helper::LogLevel::LL_Info, "Batch %d finished!\n", i); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Batch %d finished!\n", i); for (SizeType j = start; j < end; j++) { replicaCount[j] = 0; @@ -514,6 +744,11 @@ namespace SPTAG selections[vecOffset + resNum].tonode = j; ++replicaCount[j]; } + } else if (!p_opt.m_excludehead) { + selections[vecOffset].node = headVectorIDS[j]; + selections[vecOffset].tonode = j; + ++postingListSize[selections[vecOffset].node]; + ++replicaCount[j]; } } @@ -527,7 +762,7 @@ namespace SPTAG } } auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Searching replicas ended. 
Search Time: %.2lf mins\n", ((double)std::chrono::duration_cast(t2 - t1).count()) / 60.0); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Searching replicas ended. Search Time: %.2lf mins\n", ((double)std::chrono::duration_cast(t2 - t1).count()) / 60.0); if (p_opt.m_batches > 1) { @@ -541,15 +776,18 @@ namespace SPTAG VectorIndex::SortSelections(&selections.m_selections); auto t3 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Time to sort selections:%.2lf sec.\n", ((double)std::chrono::duration_cast(t3 - t2).count()) + ((double)std::chrono::duration_cast(t3 - t2).count()) / 1000); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Time to sort selections:%.2lf sec.\n", ((double)std::chrono::duration_cast(t3 - t2).count()) + ((double)std::chrono::duration_cast(t3 - t2).count()) / 1000); int postingSizeLimit = INT_MAX; if (p_opt.m_postingPageLimit > 0) { + p_opt.m_postingPageLimit = max(p_opt.m_postingPageLimit, static_cast((p_opt.m_postingVectorLimit * vectorInfoSize + PageSize - 1) / PageSize)); + p_opt.m_searchPostingPageLimit = p_opt.m_postingPageLimit; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build index with posting page limit:%d\n", p_opt.m_postingPageLimit); postingSizeLimit = static_cast(p_opt.m_postingPageLimit * PageSize / vectorInfoSize); } - LOG(Helper::LogLevel::LL_Info, "Posting size limit: %d\n", postingSizeLimit); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Posting size limit: %d\n", postingSizeLimit); { std::vector replicaCountDist(p_opt.m_replicaCount + 1, 0); @@ -559,34 +797,59 @@ namespace SPTAG ++replicaCountDist[replicaCount[i]]; } - LOG(Helper::LogLevel::LL_Info, "Before Posting Cut:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Before Posting Cut:\n"); for (int i = 0; i < replicaCountDist.size(); ++i) { - LOG(Helper::LogLevel::LL_Info, "Replica Count Dist: %d, %d\n", i, replicaCountDist[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Replica Count Dist: %d, %d\n", i, replicaCountDist[i]); } } - -#pragma omp parallel for 
schedule(dynamic) - for (int i = 0; i < postingListSize.size(); ++i) { - if (postingListSize[i] <= postingSizeLimit) continue; - - std::size_t selectIdx = std::lower_bound(selections.m_selections.begin(), selections.m_selections.end(), i, Selection::g_edgeComparer) - selections.m_selections.begin(); - - for (size_t dropID = postingSizeLimit; dropID < postingListSize[i]; ++dropID) + std::vector mythreads; + mythreads.reserve(m_opt->m_iSSDNumberOfThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_opt->m_iSSDNumberOfThreads; tid++) { - int tonode = selections.m_selections[selectIdx + dropID].tonode; - --replicaCount[tonode]; + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < postingListSize.size()) + { + if (postingListSize[i] <= postingSizeLimit) + continue; + + std::size_t selectIdx = + std::lower_bound(selections.m_selections.begin(), selections.m_selections.end(), + i, Selection::g_edgeComparer) - + selections.m_selections.begin(); + + for (size_t dropID = postingSizeLimit; dropID < postingListSize[i]; ++dropID) + { + int tonode = selections.m_selections[selectIdx + dropID].tonode; + --replicaCount[tonode]; + } + postingListSize[i] = postingSizeLimit; + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); } - postingListSize[i] = postingSizeLimit; + mythreads.clear(); } - if (p_opt.m_outputEmptyReplicaID) { std::vector replicaCountDist(p_opt.m_replicaCount + 1, 0); auto ptr = SPTAG::f_createIO(); if (ptr == nullptr || !ptr->Initialize("EmptyReplicaID.bin", std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Fail to create EmptyReplicaID.bin!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to create EmptyReplicaID.bin!\n"); return false; } for (int i = 0; i < replicaCount.size(); ++i) @@ -599,21 +862,21 @@ namespace SPTAG { long long vid = i; if (ptr->WriteBinary(sizeof(vid), reinterpret_cast(&vid)) != sizeof(vid)) { - 
LOG(Helper::LogLevel::LL_Error, "Failt to write EmptyReplicaID.bin!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failt to write EmptyReplicaID.bin!"); return false; } } } - LOG(Helper::LogLevel::LL_Info, "After Posting Cut:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After Posting Cut:\n"); for (int i = 0; i < replicaCountDist.size(); ++i) { - LOG(Helper::LogLevel::LL_Info, "Replica Count Dist: %d, %d\n", i, replicaCountDist[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Replica Count Dist: %d, %d\n", i, replicaCountDist[i]); } } auto t4 = std::chrono::high_resolution_clock::now(); - LOG(SPTAG::Helper::LogLevel::LL_Info, "Time to perform posting cut:%.2lf sec.\n", ((double)std::chrono::duration_cast(t4 - t3).count()) + ((double)std::chrono::duration_cast(t4 - t3).count()) / 1000); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Time to perform posting cut:%.2lf sec.\n", ((double)std::chrono::duration_cast(t4 - t3).count()) + ((double)std::chrono::duration_cast(t4 - t3).count()) / 1000); // number of posting lists per file size_t postingFileSize = (postingListSize.size() + p_opt.m_ssdIndexFileNum - 1) / p_opt.m_ssdIndexFileNum; @@ -658,7 +921,7 @@ namespace SPTAG m_pCompressor = std::make_unique(p_opt.m_zstdCompressLevel, p_opt.m_dictBufferCapacity); // train dict if (p_opt.m_enableDictTraining) { - LOG(Helper::LogLevel::LL_Info, "Training dictionary...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Training dictionary...\n"); std::string samplesBuffer(""); std::vector samplesSizes; for (int j = 0; j < curPostingListSizes.size(); j++) { @@ -677,42 +940,79 @@ namespace SPTAG samplesSizes.push_back(postingListFullData.size()); if (samplesBuffer.size() > p_opt.m_minDictTraingBufferSize) break; } - LOG(Helper::LogLevel::LL_Info, "Using the first %zu postingLists to train dictionary... 
\n", samplesSizes.size()); - std::size_t dictSize = m_pCompressor->TrainDict(samplesBuffer, &samplesSizes[0], samplesSizes.size()); - LOG(Helper::LogLevel::LL_Info, "Dictionary trained, dictionary size: %zu \n", dictSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using the first %zu postingLists to train dictionary... \n", samplesSizes.size()); + std::size_t dictSize = m_pCompressor->TrainDict(samplesBuffer, &samplesSizes[0], (unsigned int)samplesSizes.size()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Dictionary trained, dictionary size: %zu \n", dictSize); } } if (p_opt.m_enableDataCompression) { - LOG(Helper::LogLevel::LL_Info, "Getting compressed size of each posting list...\n"); -#pragma omp parallel for schedule(dynamic) - for (int j = 0; j < curPostingListSizes.size(); j++) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Getting compressed size of each posting list...\n"); + std::vector mythreads; + mythreads.reserve(m_opt->m_iSSDNumberOfThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_opt->m_iSSDNumberOfThreads; tid++) { - int postingListId = j + curPostingListOffSet; - // do not compress if no data - if (postingListSize[postingListId] == 0) { - curPostingListBytes[j] = 0; - continue; - } - ValueType* headVector = nullptr; - if (p_opt.m_enableDeltaEncoding) - { - headVector = (ValueType*)p_headIndex->GetSample(postingListId); - } - std::string postingListFullData = GetPostingListFullData( - postingListId, postingListSize[postingListId], selections, fullVectors, p_opt.m_enableDeltaEncoding, p_opt.m_enablePostingListRearrange, headVector); - size_t sizeToCompress = postingListSize[postingListId] * vectorInfoSize; - if (sizeToCompress != postingListFullData.size()) { - LOG(Helper::LogLevel::LL_Error, "Size to compress NOT MATCH! 
PostingListFullData size: %zu sizeToCompress: %zu \n", postingListFullData.size(), sizeToCompress); - } - curPostingListBytes[j] = m_pCompressor->GetCompressedSize(postingListFullData, p_opt.m_enableDictTraining); - if (postingListId % 10000 == 0 || curPostingListBytes[j] > static_cast(p_opt.m_postingPageLimit) * PageSize) { - LOG(Helper::LogLevel::LL_Info, "Posting list %d/%d, compressed size: %d, compression ratio: %.4f\n", postingListId, postingListSize.size(), curPostingListBytes[j], curPostingListBytes[j] / float(sizeToCompress)); - } + mythreads.emplace_back([&, tid]() { + size_t j = 0; + while (true) + { + j = sent.fetch_add(1); + if (j < curPostingListSizes.size()) + { + SizeType postingListId = j + (SizeType)curPostingListOffSet; + // do not compress if no data + if (postingListSize[postingListId] == 0) + { + curPostingListBytes[j] = 0; + continue; + } + ValueType *headVector = nullptr; + if (p_opt.m_enableDeltaEncoding) + { + headVector = (ValueType *)p_headIndex->GetSample(postingListId); + } + std::string postingListFullData = + GetPostingListFullData(postingListId, postingListSize[postingListId], + selections, fullVectors, p_opt.m_enableDeltaEncoding, + p_opt.m_enablePostingListRearrange, headVector); + size_t sizeToCompress = postingListSize[postingListId] * vectorInfoSize; + if (sizeToCompress != postingListFullData.size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Size to compress NOT MATCH! 
PostingListFullData size: %zu " + "sizeToCompress: %zu \n", + postingListFullData.size(), sizeToCompress); + } + curPostingListBytes[j] = m_pCompressor->GetCompressedSize( + postingListFullData, p_opt.m_enableDictTraining); + if (postingListId % 10000 == 0 || + curPostingListBytes[j] > + static_cast(p_opt.m_postingPageLimit) * PageSize) + { + SPTAGLIB_LOG( + Helper::LogLevel::LL_Info, + "Posting list %d/%d, compressed size: %d, compression ratio: %.4f\n", + postingListId, postingListSize.size(), curPostingListBytes[j], + curPostingListBytes[j] / float(sizeToCompress)); + } + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); } - LOG(Helper::LogLevel::LL_Info, "Getted compressed size for all the %d posting lists in SSD Index file %d.\n", curPostingListBytes.size(), i); - LOG(Helper::LogLevel::LL_Info, "Mean compressed size: %.4f \n", std::accumulate(curPostingListBytes.begin(), curPostingListBytes.end(), 0.0) / curPostingListBytes.size()); - LOG(Helper::LogLevel::LL_Info, "Mean compression ratio: %.4f \n", std::accumulate(curPostingListBytes.begin(), curPostingListBytes.end(), 0.0) / (std::accumulate(curPostingListSizes.begin(), curPostingListSizes.end(), 0.0) * vectorInfoSize)); + mythreads.clear(); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Getted compressed size for all the %d posting lists in SSD Index file %d.\n", curPostingListBytes.size(), i); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Mean compressed size: %.4f \n", std::accumulate(curPostingListBytes.begin(), curPostingListBytes.end(), 0.0) / curPostingListBytes.size()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Mean compression ratio: %.4f \n", std::accumulate(curPostingListBytes.begin(), curPostingListBytes.end(), 0.0) / (std::accumulate(curPostingListSizes.begin(), curPostingListSizes.end(), 0.0) * vectorInfoSize)); } else { for (int j = 0; j < curPostingListSizes.size(); j++) @@ -743,27 +1043,126 @@ namespace SPTAG curPostingListOffSet); } + 
p_versionMap.Save(p_opt.m_indexDirectory + FolderSep + p_opt.m_deleteIDFile); + auto t5 = std::chrono::high_resolution_clock::now(); - double elapsedSeconds = std::chrono::duration_cast(t5 - t1).count(); - LOG(Helper::LogLevel::LL_Info, "Total used time: %.2lf minutes (about %.2lf hours).\n", elapsedSeconds / 60.0, elapsedSeconds / 3600.0); - + auto elapsedSeconds = std::chrono::duration_cast(t5 - t1).count(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total used time: %.2lf minutes (about %.2lf hours).\n", elapsedSeconds / 60.0, elapsedSeconds / 3600.0); return true; } virtual bool CheckValidPosting(SizeType postingID) { - bool oneContext = (m_indexFiles.size() == 1); - int fileid = 0; - ListInfo* listInfo; - if (oneContext) { - listInfo = &(m_listInfos[0][postingID]); + return m_listInfos[postingID].listEleCount != 0; + } + + virtual ErrorCode CheckPosting(SizeType postingID, std::vector *visited = nullptr, + ExtraWorkSpace *p_exWorkSpace = nullptr) override + { + if (postingID < 0 || postingID >= m_totalListCount) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[CheckPosting]: Error postingID %d (should be 0 ~ %d)\n", + postingID, m_totalListCount); + return ErrorCode::Key_OverFlow; } - else { - fileid = postingID / m_listPerFile; - listInfo = &(m_listInfos[fileid][postingID % m_listPerFile]); + if (m_listInfos[postingID].listEleCount < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[CheckPosting]: postingID %d has wrong size:%d\n", + postingID, m_listInfos[postingID].listEleCount); + return ErrorCode::Posting_SizeError; } + return ErrorCode::Success; + } - return listInfo->listEleCount != 0; + virtual ErrorCode GetPostingDebug(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr p_index, SizeType vid, std::vector& VIDs, std::shared_ptr& vecs) + { + VIDs.clear(); + + SizeType curPostingID = vid; + ListInfo* listInfo = &(m_listInfos[curPostingID]); + VIDs.resize(listInfo->listEleCount); + ByteArray vector_array = ByteArray::Alloc(sizeof(ValueType) * 
listInfo->listEleCount * m_iDataDimension); + vecs.reset(new BasicVectorSet(vector_array, GetEnumValueType(), m_iDataDimension, listInfo->listEleCount)); + + int fileid = m_oneContext ? 0 : curPostingID / m_listPerFile; + +#ifndef BATCH_READ + Helper::DiskIO* indexFile = m_indexFiles[fileid].get(); +#endif + + size_t totalBytes = (static_cast(listInfo->listPageCount) << PageSizeEx); + +#ifdef ASYNC_READ + auto& request = p_exWorkSpace->m_diskRequests[0]; + request.m_offset = listInfo->listOffset; + request.m_readSize = totalBytes; + request.m_status = (fileid << 16) | (request.m_status & 0xffff); + request.m_payload = (void*)listInfo; + request.m_success = false; + +#ifdef BATCH_READ // async batch read + request.m_callback = [&p_exWorkSpace, &vecs, &VIDs, &p_index, &request, this](bool success) + { + char* buffer = request.m_buffer; + ListInfo* listInfo = (ListInfo*)(request.m_payload); + + // decompress posting list + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression) + { + DecompressPosting(); + } + + for (int i = 0; i < listInfo->listEleCount; i++) + { + uint64_t offsetVectorID, offsetVector; + (this->*m_parsePosting)(offsetVectorID, offsetVector, i, listInfo->listEleCount); + int vectorID = *(reinterpret_cast(p_postingListFullData + offsetVectorID)); + (this->*m_parseEncoding)(p_index, listInfo, (ValueType*)(p_postingListFullData + offsetVector)); + VIDs[i] = vectorID; + auto outVec = vecs->GetVector(i); + memcpy(outVec, (void*)(p_postingListFullData + offsetVector), sizeof(ValueType) * m_iDataDimension); + } + }; +#else // async read + request.m_callback = [&p_exWorkSpace, &request](bool success) + { + p_exWorkSpace->m_processIocp.push(&request); + }; + + ++unprocessed; + if (!(indexFile->ReadFileAsync(request))) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file!\n"); + unprocessed--; + } +#endif +#else // sync read + char* buffer = (char*)((p_exWorkSpace->m_pageBuffers[0]).GetBuffer()); + auto 
numRead = indexFile->ReadBinary(totalBytes, buffer, listInfo->listOffset); + if (numRead != totalBytes) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File %s read bytes, expected: %zu, acutal: %llu.\n", m_extraFullGraphFile.c_str(), totalBytes, numRead); + throw std::runtime_error("File read mismatch"); + } + // decompress posting list + char* p_postingListFullData = buffer + listInfo->pageOffset; + if (m_enableDataCompression) + { + DecompressPosting(); + } + + for (int i = 0; i < listInfo->listEleCount; i++) + { + uint64_t offsetVectorID, offsetVector; + (this->*m_parsePosting)(offsetVectorID, offsetVector, i, listInfo->listEleCount); + int vectorID = *(reinterpret_cast(p_postingListFullData + offsetVectorID)); + (this->*m_parseEncoding)(p_index, listInfo, (ValueType*)(p_postingListFullData + offsetVector)); + VIDs[i] = vectorID; + auto outVec = vecs->GetVector(i); + memcpy(outVec, (void*)(p_postingListFullData + offsetVector), sizeof(ValueType) * m_iDataDimension); + } +#endif + return ErrorCode::Success; } private: @@ -773,8 +1172,6 @@ namespace SPTAG int listEleCount = 0; - std::uint32_t postingID = 0; - std::uint16_t listPageCount = 0; std::uint64_t listOffset = 0; @@ -782,11 +1179,11 @@ namespace SPTAG std::uint16_t pageOffset = 0; }; - int LoadingHeadInfo(const int m_totalListCount, const std::string& p_file, int p_postingPageLimit, std::vector& m_listInfos) + int LoadingHeadInfo(const std::string& p_file, int p_postingPageLimit, std::vector& p_listInfos) { auto ptr = SPTAG::f_createIO(); if (ptr == nullptr || !ptr->Initialize(p_file.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed to open file: %s\n", p_file.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to open file: %s\n", p_file.c_str()); throw std::runtime_error("Failed open file in LoadingHeadInfo"); } m_pCompressor = std::make_unique(); // no need compress level to decompress @@ -796,29 +1193,30 @@ namespace SPTAG int m_listPageOffset; if 
(ptr->ReadBinary(sizeof(m_listCount), reinterpret_cast(&m_listCount)) != sizeof(m_listCount)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } if (ptr->ReadBinary(sizeof(m_totalDocumentCount), reinterpret_cast(&m_totalDocumentCount)) != sizeof(m_totalDocumentCount)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } if (ptr->ReadBinary(sizeof(m_iDataDimension), reinterpret_cast(&m_iDataDimension)) != sizeof(m_iDataDimension)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } if (ptr->ReadBinary(sizeof(m_listPageOffset), reinterpret_cast(&m_listPageOffset)) != sizeof(m_listPageOffset)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } if (m_vectorInfoSize == 0) m_vectorInfoSize = m_iDataDimension * sizeof(ValueType) + sizeof(int); else if (m_vectorInfoSize != m_iDataDimension * sizeof(ValueType) + sizeof(int)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file! DataDimension and ValueType are not match!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file! 
DataDimension and ValueType are not match!\n"); throw std::runtime_error("DataDimension and ValueType don't match in LoadingHeadInfo"); } - m_listInfos.resize(m_listCount); + size_t totalListCount = p_listInfos.size(); + p_listInfos.resize(totalListCount + m_listCount); size_t totalListElementCount = 0; @@ -829,44 +1227,45 @@ namespace SPTAG int pageNum; for (int i = 0; i < m_listCount; ++i) { - m_listInfos[i].postingID = m_totalListCount + i; + ListInfo* listInfo = &(p_listInfos[totalListCount + i]); + if (m_enableDataCompression) { - if (ptr->ReadBinary(sizeof(m_listInfos[i].listTotalBytes), reinterpret_cast(&(m_listInfos[i].listTotalBytes))) != sizeof(m_listInfos[i].listTotalBytes)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + if (ptr->ReadBinary(sizeof(listInfo->listTotalBytes), reinterpret_cast(&(listInfo->listTotalBytes))) != sizeof(listInfo->listTotalBytes)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } } if (ptr->ReadBinary(sizeof(pageNum), reinterpret_cast(&(pageNum))) != sizeof(pageNum)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } - if (ptr->ReadBinary(sizeof(m_listInfos[i].pageOffset), reinterpret_cast(&(m_listInfos[i].pageOffset))) != sizeof(m_listInfos[i].pageOffset)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + if (ptr->ReadBinary(sizeof(listInfo->pageOffset), reinterpret_cast(&(listInfo->pageOffset))) != sizeof(listInfo->pageOffset)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } - if (ptr->ReadBinary(sizeof(m_listInfos[i].listEleCount), reinterpret_cast(&(m_listInfos[i].listEleCount))) != 
sizeof(m_listInfos[i].listEleCount)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + if (ptr->ReadBinary(sizeof(listInfo->listEleCount), reinterpret_cast(&(listInfo->listEleCount))) != sizeof(listInfo->listEleCount)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } - if (ptr->ReadBinary(sizeof(m_listInfos[i].listPageCount), reinterpret_cast(&(m_listInfos[i].listPageCount))) != sizeof(m_listInfos[i].listPageCount)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + if (ptr->ReadBinary(sizeof(listInfo->listPageCount), reinterpret_cast(&(listInfo->listPageCount))) != sizeof(listInfo->listPageCount)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } - m_listInfos[i].listOffset = (static_cast(m_listPageOffset + pageNum) << PageSizeEx); + listInfo->listOffset = (static_cast(m_listPageOffset + pageNum) << PageSizeEx); if (!m_enableDataCompression) { - m_listInfos[i].listTotalBytes = m_listInfos[i].listEleCount * m_vectorInfoSize; - m_listInfos[i].listEleCount = min(m_listInfos[i].listEleCount, (min(static_cast(m_listInfos[i].listPageCount), p_postingPageLimit) << PageSizeEx) / m_vectorInfoSize); - m_listInfos[i].listPageCount = static_cast(ceil((m_vectorInfoSize * m_listInfos[i].listEleCount + m_listInfos[i].pageOffset) * 1.0 / (1 << PageSizeEx))); + listInfo->listTotalBytes = listInfo->listEleCount * m_vectorInfoSize; + listInfo->listEleCount = min(listInfo->listEleCount, (min(static_cast(listInfo->listPageCount), p_postingPageLimit) << PageSizeEx) / m_vectorInfoSize); + listInfo->listPageCount = static_cast(ceil((m_vectorInfoSize * listInfo->listEleCount + listInfo->pageOffset) * 1.0 / (1 << PageSizeEx))); } - totalListElementCount += m_listInfos[i].listEleCount; - int pageCount = m_listInfos[i].listPageCount; + 
totalListElementCount += listInfo->listEleCount; + int pageCount = listInfo->listPageCount; if (pageCount > 1) { ++biglistCount; - biglistElementCount += m_listInfos[i].listEleCount; + biglistElementCount += listInfo->listEleCount; } if (pageCountDist.count(pageCount) == 0) @@ -883,46 +1282,66 @@ namespace SPTAG { size_t dictBufferSize; if (ptr->ReadBinary(sizeof(size_t), reinterpret_cast(&dictBufferSize)) != sizeof(dictBufferSize)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } char* dictBuffer = new char[dictBufferSize]; if (ptr->ReadBinary(dictBufferSize, dictBuffer) != dictBufferSize) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file!\n"); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } try { m_pCompressor->SetDictBuffer(std::string(dictBuffer, dictBufferSize)); } catch (std::runtime_error& err) { - LOG(Helper::LogLevel::LL_Error, "Failed to read head info file: %s \n", err.what()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head info file: %s \n", err.what()); throw std::runtime_error("Failed read file in LoadingHeadInfo"); } delete[] dictBuffer; } - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finish reading header info, list count %d, total doc count %d, dimension %d, list page offset %d.\n", m_listCount, m_totalDocumentCount, m_iDataDimension, m_listPageOffset); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Big page (>4K): list count %zu, total element count %zu.\n", biglistCount, biglistElementCount); - LOG(Helper::LogLevel::LL_Info, "Total Element Count: %llu\n", totalListElementCount); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total Element Count: %llu\n", totalListElementCount); for (auto& ele : 
pageCountDist) { - LOG(Helper::LogLevel::LL_Info, "Page Count Dist: %d %d\n", ele.first, ele.second); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Page Count Dist: %d %d\n", ele.first, ele.second); } return m_listCount; } + inline void ParsePostingListRearrange(uint64_t& offsetVectorID, uint64_t& offsetVector, int i, int eleCount) + { + offsetVectorID = (m_vectorInfoSize - sizeof(int)) * eleCount + sizeof(int) * i; + offsetVector = (m_vectorInfoSize - sizeof(int)) * i; + } + + inline void ParsePostingList(uint64_t& offsetVectorID, uint64_t& offsetVector, int i, int eleCount) + { + offsetVectorID = m_vectorInfoSize * i; + offsetVector = offsetVectorID + sizeof(int); + } + + inline void ParseDeltaEncoding(std::shared_ptr& p_index, ListInfo* p_info, ValueType* vector) + { + ValueType* headVector = (ValueType*)p_index->GetSample((SizeType)(p_info - m_listInfos.data())); + COMMON::SIMDUtils::ComputeSum(vector, headVector, m_iDataDimension); + } + + inline void ParseEncoding(std::shared_ptr& p_index, ListInfo* p_info, ValueType* vector) { } + void SelectPostingOffset( const std::vector& p_postingListBytes, std::unique_ptr& p_postPageNum, @@ -990,7 +1409,7 @@ namespace SPTAG currOffset += iter->rest; if (currOffset > PageSize) { - LOG(Helper::LogLevel::LL_Error, "Crossing extra pages\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Crossing extra pages\n"); throw std::runtime_error("Read too many pages"); } @@ -1006,14 +1425,14 @@ namespace SPTAG } } - LOG(Helper::LogLevel::LL_Info, "TotalPageNumbers: %d, IndexSize: %llu\n", currPageNum, static_cast(currPageNum) * PageSize + currOffset); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "TotalPageNumbers: %d, IndexSize: %llu\n", currPageNum, static_cast(currPageNum) * PageSize + currOffset); } void OutputSSDIndexFile(const std::string& p_outputFile, - bool m_enableDeltaEncoding, - bool m_enablePostingListRearrange, - bool m_enableDataCompression, - bool m_enableDictTraining, + bool p_enableDeltaEncoding, + bool 
p_enablePostingListRearrange, + bool p_enableDataCompression, + bool p_enableDictTraining, size_t p_spacePerVector, const std::vector& p_postingListSizes, const std::vector& p_postingListBytes, @@ -1025,7 +1444,7 @@ namespace SPTAG std::shared_ptr p_fullVectors, size_t p_postingListOffset) { - LOG(Helper::LogLevel::LL_Info, "Start output...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start output...\n"); auto t1 = std::chrono::high_resolution_clock::now(); @@ -1034,12 +1453,12 @@ namespace SPTAG // open file while (retry > 0 && (ptr == nullptr || !ptr->Initialize(p_outputFile.c_str(), std::ios::binary | std::ios::out))) { - LOG(Helper::LogLevel::LL_Error, "Failed open file %s, retrying...\n", p_outputFile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open file %s, retrying...\n", p_outputFile.c_str()); retry--; } if (ptr == nullptr || !ptr->Initialize(p_outputFile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Failed open file %s\n", p_outputFile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open file %s\n", p_outputFile.c_str()); throw std::runtime_error("Failed to open file for SSD index save"); } // meta size of global info @@ -1047,13 +1466,13 @@ namespace SPTAG // meta size of the posting lists listOffset += (sizeof(int) + sizeof(std::uint16_t) + sizeof(int) + sizeof(std::uint16_t)) * p_postingListSizes.size(); // write listTotalBytes only when enabled data compression - if (m_enableDataCompression) + if (p_enableDataCompression) { listOffset += sizeof(size_t) * p_postingListSizes.size(); } // compression dict - if (m_enableDataCompression && m_enableDictTraining) + if (p_enableDataCompression && p_enableDictTraining) { listOffset += sizeof(size_t); listOffset += m_pCompressor->GetDictBuffer().size(); @@ -1075,28 +1494,28 @@ namespace SPTAG // Number of posting lists int i32Val = static_cast(p_postingListSizes.size()); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != 
sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // Number of vectors i32Val = static_cast(p_fullVectors->Count()); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // Vector dimension i32Val = static_cast(p_fullVectors->Dimension()); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // Page offset of list content section i32Val = static_cast(listOffset / PageSize); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } @@ -1122,46 +1541,46 @@ namespace SPTAG } } // Total bytes of the posting list, write only when enabled data compression - if (m_enableDataCompression && ptr->WriteBinary(sizeof(postingListByte), reinterpret_cast(&postingListByte)) != sizeof(postingListByte)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + if (p_enableDataCompression && ptr->WriteBinary(sizeof(postingListByte), reinterpret_cast(&postingListByte)) != sizeof(postingListByte)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // Page number of the posting list if 
(ptr->WriteBinary(sizeof(pageNum), reinterpret_cast(&pageNum)) != sizeof(pageNum)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // Page offset if (ptr->WriteBinary(sizeof(pageOffset), reinterpret_cast(&pageOffset)) != sizeof(pageOffset)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // Number of vectors in the posting list if (ptr->WriteBinary(sizeof(listEleCount), reinterpret_cast(&listEleCount)) != sizeof(listEleCount)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // Page count of the posting list if (ptr->WriteBinary(sizeof(listPageCount), reinterpret_cast(&listPageCount)) != sizeof(listPageCount)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } } // compression dict - if (m_enableDataCompression && m_enableDictTraining) + if (p_enableDataCompression && p_enableDictTraining) { std::string dictBuffer = m_pCompressor->GetDictBuffer(); // dict size size_t dictBufferSize = dictBuffer.size(); if (ptr->WriteBinary(sizeof(size_t), reinterpret_cast(&dictBufferSize)) != sizeof(dictBufferSize)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } // dict if (ptr->WriteBinary(dictBuffer.size(), const_cast(dictBuffer.data())) != dictBuffer.size()) { - 
LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } } @@ -1170,18 +1589,18 @@ namespace SPTAG if (paddingSize > 0) { if (ptr->WriteBinary(paddingSize, reinterpret_cast(paddingVals.get())) != paddingSize) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } } if (static_cast(ptr->TellP()) != listOffset) { - LOG(Helper::LogLevel::LL_Info, "List offset not match!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "List offset not match!\n"); throw std::runtime_error("List offset mismatch"); } - LOG(Helper::LogLevel::LL_Info, "SubIndex Size: %llu bytes, %llu MBytes\n", listOffset, listOffset >> 20); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SubIndex Size: %llu bytes, %llu MBytes\n", listOffset, listOffset >> 20); listOffset = 0; @@ -1192,7 +1611,7 @@ namespace SPTAG std::uint64_t targetOffset = static_cast(p_postPageNum[id]) * PageSize + p_postPageOffset[id]; if (targetOffset < listOffset) { - LOG(Helper::LogLevel::LL_Info, "List offset not match, targetOffset < listOffset!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "List offset not match, targetOffset < listOffset!\n"); throw std::runtime_error("List offset mismatch"); } // write padding vals before the posting list @@ -1200,12 +1619,12 @@ namespace SPTAG { if (targetOffset - listOffset > PageSize) { - LOG(Helper::LogLevel::LL_Error, "Padding size greater than page size!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Padding size greater than page size!\n"); throw std::runtime_error("Padding size mismatch with page size"); } if (ptr->WriteBinary(targetOffset - listOffset, reinterpret_cast(paddingVals.get())) != targetOffset - listOffset) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } @@ -1221,30 +1640,30 @@ namespace SPTAG int postingListId = id + (int)p_postingListOffset; // get posting list full content and write it at once ValueType *headVector = nullptr; - if (m_enableDeltaEncoding) + if (p_enableDeltaEncoding) { headVector = (ValueType *)p_headIndex->GetSample(postingListId); } std::string postingListFullData = GetPostingListFullData( - postingListId, p_postingListSizes[id], p_postingSelections, p_fullVectors, m_enableDeltaEncoding, m_enablePostingListRearrange, headVector); + postingListId, p_postingListSizes[id], p_postingSelections, p_fullVectors, p_enableDeltaEncoding, p_enablePostingListRearrange, headVector); size_t postingListFullSize = p_postingListSizes[id] * p_spacePerVector; if (postingListFullSize != postingListFullData.size()) { - LOG(Helper::LogLevel::LL_Error, "posting list full data size NOT MATCH! postingListFullData.size(): %zu postingListFullSize: %zu \n", postingListFullData.size(), postingListFullSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "posting list full data size NOT MATCH! postingListFullData.size(): %zu postingListFullSize: %zu \n", postingListFullData.size(), postingListFullSize); throw std::runtime_error("Posting list full size mismatch"); } - if (m_enableDataCompression) + if (p_enableDataCompression) { - std::string compressedData = m_pCompressor->Compress(postingListFullData, m_enableDictTraining); + std::string compressedData = m_pCompressor->Compress(postingListFullData, p_enableDictTraining); size_t compressedSize = compressedData.size(); if (compressedSize != p_postingListBytes[id]) { - LOG(Helper::LogLevel::LL_Error, "Compressed size NOT MATCH! compressed size:%zu, pre-calculated compressed size:%zu\n", compressedSize, p_postingListBytes[id]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Compressed size NOT MATCH! 
compressed size:%zu, pre-calculated compressed size:%zu\n", compressedSize, p_postingListBytes[id]); throw std::runtime_error("Compression size mismatch"); } if (ptr->WriteBinary(compressedSize, compressedData.data()) != compressedSize) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } listOffset += compressedSize; @@ -1253,7 +1672,7 @@ namespace SPTAG { if (ptr->WriteBinary(postingListFullSize, postingListFullData.data()) != postingListFullSize) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } listOffset += postingListFullSize; @@ -1275,23 +1694,51 @@ namespace SPTAG { if (ptr->WriteBinary(paddingSize, reinterpret_cast(paddingVals.get())) != paddingSize) { - LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write SSDIndex File!"); throw std::runtime_error("Failed to write SSDIndex File"); } } - LOG(Helper::LogLevel::LL_Info, "Padded Size: %llu, final total size: %llu.\n", paddedSize, listOffset); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Padded Size: %llu, final total size: %llu.\n", paddedSize, listOffset); - LOG(Helper::LogLevel::LL_Info, "Output done...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Output done...\n"); auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Time to write results:%.2lf sec.\n", ((double)std::chrono::duration_cast(t2 - t1).count()) + ((double)std::chrono::duration_cast(t2 - t1).count()) / 1000); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Time to write results:%.2lf sec.\n", ((double)std::chrono::duration_cast(t2 - t1).count()) + ((double)std::chrono::duration_cast(t2 - t1).count()) / 1000); + } + + ErrorCode 
GetWritePosting(ExtraWorkSpace* p_exWorkSpace, SizeType pid, std::string& posting, bool write = false) override { + if (write) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unsupport write\n"); + return ErrorCode::Undefined; + } + ListInfo* listInfo = &(m_listInfos[pid]); + size_t totalBytes = (static_cast(listInfo->listPageCount) << PageSizeEx); + size_t realBytes = listInfo->listEleCount * m_vectorInfoSize; + posting.resize(totalBytes); + int fileid = m_oneContext? 0: pid / m_listPerFile; + Helper::DiskIO* indexFile = m_indexFiles[fileid].get(); + auto numRead = indexFile->ReadBinary(totalBytes, (char*)posting.data(), listInfo->listOffset); + if (numRead != totalBytes) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File %s read bytes, expected: %zu, acutal: %llu.\n", m_extraFullGraphFile.c_str(), totalBytes, numRead); + return ErrorCode::DiskIOFail; + } + char* ptr = (char*)(posting.c_str()); + memcpy(ptr, posting.c_str() + listInfo->pageOffset, realBytes); + posting.resize(realBytes); + return ErrorCode::Success; } private: - + bool m_available = false; + + std::shared_ptr> m_freeWorkSpaceIds; + std::atomic m_workspaceCount = 0; + std::string m_extraFullGraphFile; - std::vector> m_listInfos; + std::vector m_listInfos; + bool m_oneContext; + Options* m_opt; std::vector> m_indexFiles; std::unique_ptr m_pCompressor; @@ -1299,6 +1746,9 @@ namespace SPTAG bool m_enablePostingListRearrange; bool m_enableDataCompression; bool m_enableDictTraining; + + void (ExtraStaticSearcher::*m_parsePosting)(uint64_t&, uint64_t&, int, int); + void (ExtraStaticSearcher::*m_parseEncoding)(std::shared_ptr&, ListInfo*, ValueType*); int m_vectorInfoSize = 0; int m_iDataDimension = 0; @@ -1310,4 +1760,4 @@ namespace SPTAG } // namespace SPANN } // namespace SPTAG -#endif // _SPTAG_SPANN_EXTRASEARCHER_H_ +#endif // _SPTAG_SPANN_EXTRASTATICSEARCHER_H_ diff --git a/AnnService/inc/Core/SPANN/IExtraSearcher.h b/AnnService/inc/Core/SPANN/IExtraSearcher.h index 8b5d5d073..95d171854 100644 --- 
a/AnnService/inc/Core/SPANN/IExtraSearcher.h +++ b/AnnService/inc/Core/SPANN/IExtraSearcher.h @@ -7,12 +7,15 @@ #include "Options.h" #include "inc/Core/VectorIndex.h" +#include "inc/Core/Common/VersionLabel.h" #include "inc/Helper/AsyncFileReader.h" - +#include "inc/Helper/VectorSetReader.h" +#include "inc/Helper/ConcurrentSet.h" #include #include #include #include +#include namespace SPTAG { namespace SPANN { @@ -32,7 +35,11 @@ namespace SPTAG { m_asyncLatency1(0), m_asyncLatency2(0), m_queueLatency(0), - m_sleepLatency(0) + m_sleepLatency(0), + m_compLatency(0), + m_diskReadLatency(0), + m_exSetUpLatency(0), + m_threadID(0) { } @@ -62,69 +69,112 @@ namespace SPTAG { double m_sleepLatency; + double m_compLatency; + + double m_diskReadLatency; + + double m_exSetUpLatency; + std::chrono::steady_clock::time_point m_searchRequestTime; int m_threadID; }; - template - class PageBuffer - { - public: - PageBuffer() - : m_pageBufferSize(0) - { - } - - void ReservePageBuffer(std::size_t p_size) - { - if (m_pageBufferSize < p_size) - { - m_pageBufferSize = p_size; - m_pageBuffer.reset(static_cast(PAGE_ALLOC(sizeof(T) * m_pageBufferSize)), [=](T* ptr) { PAGE_FREE(ptr); }); + struct IndexStats { + std::atomic_uint32_t m_headMiss{ 0 }; + uint32_t m_appendTaskNum{ 0 }; + uint32_t m_splitNum{ 0 }; + uint32_t m_theSameHeadNum{ 0 }; + uint32_t m_reAssignNum{ 0 }; + uint32_t m_garbageNum{ 0 }; + uint64_t m_reAssignScanNum{ 0 }; + uint32_t m_mergeNum{ 0 }; + + //Split + double m_splitCost{ 0 }; + double m_getCost{ 0 }; + double m_putCost{ 0 }; + double m_clusteringCost{ 0 }; + double m_updateHeadCost{ 0 }; + double m_reassignScanCost{ 0 }; + double m_reassignScanIOCost{ 0 }; + + // Append + double m_appendCost{ 0 }; + double m_appendIOCost{ 0 }; + + // reAssign + double m_reAssignCost{ 0 }; + double m_selectCost{ 0 }; + double m_reAssignAppendCost{ 0 }; + + // GC + double m_garbageCost{ 0 }; + + void PrintStat(int finishedInsert, bool cost = false, bool reset = false) { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "After %d insertion, head vectors split %d times, head missing %d times, same head %d times, reassign %d times, reassign scan %ld times, garbage collection %d times, merge %d times\n", + finishedInsert, m_splitNum, m_headMiss.load(), m_theSameHeadNum, m_reAssignNum, m_reAssignScanNum, m_garbageNum, m_mergeNum); + + if (cost) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AppendTaskNum: %d, TotalCost: %.3lf us, PerCost: %.3lf us\n", m_appendTaskNum, m_appendCost, m_appendCost / m_appendTaskNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AppendTaskNum: %d, AppendIO TotalCost: %.3lf us, PerCost: %.3lf us\n", m_appendTaskNum, m_appendIOCost, m_appendIOCost / m_appendTaskNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SplitNum: %d, TotalCost: %.3lf ms, PerCost: %.3lf ms\n", m_splitNum, m_splitCost, m_splitCost / m_splitNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SplitNum: %d, Read TotalCost: %.3lf us, PerCost: %.3lf us\n", m_splitNum, m_getCost, m_getCost / m_splitNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SplitNum: %d, Clustering TotalCost: %.3lf us, PerCost: %.3lf us\n", m_splitNum, m_clusteringCost, m_clusteringCost / m_splitNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SplitNum: %d, UpdateHead TotalCost: %.3lf ms, PerCost: %.3lf ms\n", m_splitNum, m_updateHeadCost, m_updateHeadCost / m_splitNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SplitNum: %d, Write TotalCost: %.3lf us, PerCost: %.3lf us\n", m_splitNum, m_putCost, m_putCost / m_splitNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SplitNum: %d, ReassignScan TotalCost: %.3lf ms, PerCost: %.3lf ms\n", m_splitNum, m_reassignScanCost, m_reassignScanCost / m_splitNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SplitNum: %d, ReassignScanIO TotalCost: %.3lf us, PerCost: %.3lf us\n", m_splitNum, m_reassignScanIOCost, m_reassignScanIOCost / m_splitNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "GCNum: %d, TotalCost: %.3lf us, PerCost: %.3lf us\n", m_garbageNum, 
m_garbageCost, m_garbageCost / m_garbageNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ReassignNum: %d, TotalCost: %.3lf us, PerCost: %.3lf us\n", m_reAssignNum, m_reAssignCost, m_reAssignCost / m_reAssignNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ReassignNum: %d, Select TotalCost: %.3lf us, PerCost: %.3lf us\n", m_reAssignNum, m_selectCost, m_selectCost / m_reAssignNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ReassignNum: %d, ReassignAppend TotalCost: %.3lf us, PerCost: %.3lf us\n", m_reAssignNum, m_reAssignAppendCost, m_reAssignAppendCost / m_reAssignNum); } - } - - T* GetBuffer() - { - return m_pageBuffer.get(); - } - std::size_t GetPageSize() - { - return m_pageBufferSize; + if (reset) { + m_splitNum = 0; + m_headMiss = 0; + m_theSameHeadNum = 0; + m_reAssignNum = 0; + m_reAssignScanNum = 0; + m_mergeNum = 0; + m_garbageNum = 0; + m_appendTaskNum = 0; + m_splitCost = 0; + m_clusteringCost = 0; + m_garbageCost = 0; + m_updateHeadCost = 0; + m_getCost = 0; + m_putCost = 0; + m_reassignScanCost = 0; + m_reassignScanIOCost = 0; + m_appendCost = 0; + m_appendIOCost = 0; + m_reAssignCost = 0; + m_selectCost = 0; + m_reAssignAppendCost = 0; + } } - - private: - std::shared_ptr m_pageBuffer; - - std::size_t m_pageBufferSize; }; - struct ExtraWorkSpace + struct ExtraWorkSpace : public SPTAG::COMMON::IWorkSpace { ExtraWorkSpace() {} - ~ExtraWorkSpace() {} + ~ExtraWorkSpace() { + if (m_callback) { + m_callback(); + } + } ExtraWorkSpace(ExtraWorkSpace& other) { - Initialize(other.m_deduper.MaxCheck(), other.m_deduper.HashTableExponent(), (int)other.m_pageBuffers.size(), (int)(other.m_pageBuffers[0].GetPageSize()), other.m_enableDataCompression); - m_spaceID = g_spaceCount++; + Initialize(other.m_deduper.MaxCheck(), other.m_deduper.HashTableExponent(), (int)other.m_pageBuffers.size(), (int)(other.m_pageBuffers[0].GetPageSize()), other.m_blockIO, other.m_enableDataCompression); } - void Initialize(int p_maxCheck, int p_hashExp, int p_internalResultNum, int 
p_maxPages, bool enableDataCompression) { - m_postingIDs.reserve(p_internalResultNum); + void Initialize(int p_maxCheck, int p_hashExp, int p_internalResultNum, int p_maxPages, bool p_blockIO, bool enableDataCompression) { m_deduper.Init(p_maxCheck, p_hashExp); - m_processIocp.reset(p_internalResultNum); - m_pageBuffers.resize(p_internalResultNum); - for (int pi = 0; pi < p_internalResultNum; pi++) { - m_pageBuffers[pi].ReservePageBuffer(p_maxPages); - } - m_diskRequests.resize(p_internalResultNum); - m_enableDataCompression = enableDataCompression; - if (enableDataCompression) { - m_decompressBuffer.ReservePageBuffer(p_maxPages); - } + Clear(p_internalResultNum, p_maxPages, p_blockIO, enableDataCompression); + m_relaxedMono = false; } void Initialize(va_list& arg) { @@ -132,11 +182,66 @@ namespace SPTAG { int hashExp = va_arg(arg, int); int internalResultNum = va_arg(arg, int); int maxPages = va_arg(arg, int); + bool blockIo = bool(va_arg(arg, int)); bool enableDataCompression = bool(va_arg(arg, int)); - Initialize(maxCheck, hashExp, internalResultNum, maxPages, enableDataCompression); + Initialize(maxCheck, hashExp, internalResultNum, maxPages, blockIo, enableDataCompression); } - static void Reset() { g_spaceCount = 0; } + void Clear(int p_internalResultNum, int p_maxPages, bool p_blockIO, bool enableDataCompression) { + if (p_internalResultNum > m_pageBuffers.size() || p_maxPages > m_pageBuffers[0].GetPageSize()) { + m_postingIDs.reserve(p_internalResultNum); + m_pageBuffers.resize(p_internalResultNum); + for (int pi = 0; pi < p_internalResultNum; pi++) { + m_pageBuffers[pi].ReservePageBuffer(p_maxPages); + } + m_blockIO = p_blockIO; + if (p_blockIO) { + int numPages = (p_maxPages >> PageSizeEx); + m_diskRequests.resize(p_internalResultNum * numPages); + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "WorkSpace Init %d*%d reqs\n", p_internalResultNum, numPages); + for (int pi = 0; pi < p_internalResultNum; pi++) { + for (int pg = 0; pg < numPages; pg++) { + int rid 
= pi * numPages + pg; + auto& req = m_diskRequests[rid]; + + req.m_buffer = (char*)(m_pageBuffers[pi].GetBuffer() + ((std::uint64_t)pg << PageSizeEx)); + req.m_extension = &m_processIocp; +#ifdef _MSC_VER + memset(&(req.myres.m_col), 0, sizeof(OVERLAPPED)); + req.myres.m_col.m_data = (void*)(&req); +#else + memset(&(req.myiocb), 0, sizeof(struct iocb)); + req.myiocb.aio_buf = reinterpret_cast(req.m_buffer); + req.myiocb.aio_data = reinterpret_cast(&req); +#endif + } + } + } + else { + m_diskRequests.resize(p_internalResultNum); + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "WorkSpace Init %d reqs\n", p_internalResultNum); + for (int pi = 0; pi < p_internalResultNum; pi++) { + auto& req = m_diskRequests[pi]; + + req.m_buffer = (char*)(m_pageBuffers[pi].GetBuffer()); + req.m_extension = &m_processIocp; +#ifdef _MSC_VER + memset(&(req.myres.m_col), 0, sizeof(OVERLAPPED)); + req.myres.m_col.m_data = (void*)(&req); +#else + memset(&(req.myiocb), 0, sizeof(struct iocb)); + req.myiocb.aio_buf = reinterpret_cast(req.m_buffer); + req.myiocb.aio_data = reinterpret_cast(&req); +#endif + } + } + } + + m_enableDataCompression = enableDataCompression; + if (enableDataCompression) { + m_decompressBuffer.ReservePageBuffer(p_maxPages); + } + } std::vector m_postingIDs; @@ -144,16 +249,31 @@ namespace SPTAG { Helper::RequestQueue m_processIocp; - std::vector> m_pageBuffers; + std::vector> m_pageBuffers; + + bool m_blockIO = false; + + bool m_enableDataCompression = false; - bool m_enableDataCompression; - PageBuffer m_decompressBuffer; + Helper::PageBuffer m_decompressBuffer; std::vector m_diskRequests; - int m_spaceID; + int m_ri = 0; + + int m_pi = 0; + + int m_offset = 0; + + bool m_loadPosting = false; + + bool m_relaxedMono = false; + + int m_loadedPostingNum = 0; + + std::function m_filterFunc; - static std::atomic_int g_spaceCount; + std::function m_callback; }; class IExtraSearcher @@ -163,26 +283,68 @@ namespace SPTAG { { } - virtual ~IExtraSearcher() + ~IExtraSearcher() { } 
+ virtual bool Available() = 0; - virtual bool LoadIndex(Options& p_options) = 0; + virtual bool LoadIndex(Options& p_options, COMMON::VersionLabel& p_versionMap, COMMON::Dataset& m_vectorTranslateMap, std::shared_ptr m_index) = 0; - virtual void SearchIndex(ExtraWorkSpace* p_exWorkSpace, + virtual ErrorCode SearchIndex(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_queryResults, std::shared_ptr p_index, - SearchStats* p_stats, - std::set* truth = nullptr, - std::map>* found = nullptr) = 0; + SearchStats* p_stats, std::set* truth = nullptr, std::map>* found = nullptr) = 0; + + virtual ErrorCode SearchIterativeNext(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_headResults, + QueryResult& p_queryResults, + std::shared_ptr p_index, const VectorIndex* p_spann) = 0; + + virtual ErrorCode SearchIndexWithoutParsing(ExtraWorkSpace* p_exWorkSpace) = 0; + + virtual ErrorCode SearchNextInPosting(ExtraWorkSpace* p_exWorkSpace, QueryResult& p_headResults, + QueryResult& p_queryResults, + std::shared_ptr& p_index, const VectorIndex* p_spann) = 0; virtual bool BuildIndex(std::shared_ptr& p_reader, std::shared_ptr p_index, - Options& p_opt) = 0; + Options& p_opt, COMMON::VersionLabel& p_versionMap, COMMON::Dataset& p_vectorTranslateMap, SizeType upperBound = -1) = 0; + + virtual void InitWorkSpace(ExtraWorkSpace* p_exWorkSpace, bool clear = false) = 0; + + virtual ErrorCode GetPostingDebug(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr p_index, SizeType vid, std::vector& VIDs, std::shared_ptr& vecs) = 0; + + virtual ErrorCode RefineIndex(std::shared_ptr p_index, bool p_prereassign = true, + std::vector *p_headmapping = nullptr, + std::vector *p_mapping = nullptr) + { + return ErrorCode::Undefined; + } + virtual ErrorCode AddIndex(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr& p_vectorSet, + std::shared_ptr p_index, SizeType p_begin) { return ErrorCode::Undefined; } + virtual ErrorCode DeleteIndex(SizeType p_id) { return ErrorCode::Undefined; } + + virtual bool AllFinished() { 
return false; } + virtual void GetDBStats() { return; } + virtual int64_t GetNumBlocks() { return 0; } + virtual void GetIndexStats(int finishedInsert, bool cost, bool reset) { return; } + virtual void ForceCompaction() { return; } virtual bool CheckValidPosting(SizeType postingID) = 0; + virtual ErrorCode CheckPosting(SizeType postingiD, std::vector *visited = nullptr, + ExtraWorkSpace *p_exWorkSpace = nullptr) = 0; + virtual SizeType SearchVector(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr& p_vectorSet, + std::shared_ptr p_index, int testNum = 64, SizeType VID = -1) { return -1; } + virtual void ForceGC(ExtraWorkSpace* p_exWorkSpace, VectorIndex* p_index) { return; } + + virtual ErrorCode GetWritePosting(ExtraWorkSpace *p_exWorkSpace, SizeType pid, std::string &posting, + bool write = false) + { + return ErrorCode::Undefined; + } + + virtual ErrorCode Checkpoint(std::string prefix) { return ErrorCode::Success; } }; } // SPANN } // SPTAG -#endif // _SPTAG_SPANN_IEXTRASEARCHER_H_ \ No newline at end of file +#endif // _SPTAG_SPANN_IEXTRASEARCHER_H_ diff --git a/AnnService/inc/Core/SPANN/Index.h b/AnnService/inc/Core/SPANN/Index.h index de8983b2a..54b33c2b4 100644 --- a/AnnService/inc/Core/SPANN/Index.h +++ b/AnnService/inc/Core/SPANN/Index.h @@ -13,6 +13,9 @@ #include "inc/Core/Common/QueryResultSet.h" #include "inc/Core/Common/BKTree.h" #include "inc/Core/Common/WorkSpacePool.h" +#include "inc/Core/Common/FineGrainedLock.h" +#include "inc/Core/Common/VersionLabel.h" +#include "inc/Core/Common/PostingSizeRecord.h" #include "inc/Core/Common/Labelset.h" #include "inc/Helper/SimpleIniReader.h" @@ -36,27 +39,39 @@ namespace SPTAG class IniReader; } + namespace SPANN { + template + class SPANNResultIterator; + template class Index : public VectorIndex { private: std::shared_ptr m_index; - std::shared_ptr m_vectorTranslateMap; + COMMON::Dataset m_vectorTranslateMap; std::unordered_map m_headParameters; + COMMON::VersionLabel m_versionMap; std::shared_ptr 
m_extraSearcher; - std::unique_ptr> m_workSpacePool; + std::unique_ptr> m_workSpaceFactory; Options m_options; std::function m_fComputeDistance; int m_iBaseSquare; + std::mutex m_dataAddLock; + std::shared_timed_mutex m_dataDeleteLock; + std::shared_timed_mutex m_checkPointLock; + + + public: Index() { + m_workSpaceFactory = std::make_unique>(); m_fComputeDistance = std::function(COMMON::DistanceCalcSelector(m_options.m_distCalcMethod)); m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() : 1; } @@ -67,9 +82,10 @@ namespace SPTAG inline std::shared_ptr GetDiskIndex() { return m_extraSearcher; } inline Options* GetOptions() { return &m_options; } - inline SizeType GetNumSamples() const { return m_options.m_vectorSize; } + inline SizeType GetNumSamples() const { return m_versionMap.Count(); } inline DimensionType GetFeatureDim() const { return m_pQuantizer ? m_pQuantizer->ReconstructDim() : m_index->GetFeatureDim(); } - + inline SizeType GetValueSize() const { return m_options.m_dim * sizeof(T); } + inline int GetCurrMaxCheck() const { return m_options.m_maxCheck; } inline int GetNumThreads() const { return m_options.m_iSSDNumberOfThreads; } inline DistCalcMethod GetDistCalcMethod() const { return m_options.m_distCalcMethod; } @@ -77,7 +93,7 @@ namespace SPTAG inline VectorValueType GetVectorValueType() const { return GetEnumValueType(); } void SetQuantizer(std::shared_ptr quantizer); - + inline float AccurateDistance(const void* pX, const void* pY) const { if (m_options.m_distCalcMethod == DistCalcMethod::L2) return m_fComputeDistance((const T*)pX, (const T*)pY, m_options.m_dim); @@ -87,7 +103,11 @@ namespace SPTAG return 1.0f - xy / (sqrt(xx) * sqrt(yy)); } inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_options.m_dim); } - inline bool ContainSample(const SizeType idx) const { return idx < m_options.m_vectorSize; } + 
inline float GetDistance(const void* target, const SizeType idx) const { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "GetDistance NOT SUPPORT FOR SPANN"); + return -1; + } + inline bool ContainSample(const SizeType idx) const { return idx >= 0 && idx < m_versionMap.Count() && !m_versionMap.Deleted(idx); } std::shared_ptr> BufferSize() const { @@ -119,24 +139,79 @@ namespace SPTAG ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized = false, bool p_shareOwnership = false); ErrorCode BuildIndex(bool p_normalized = false); ErrorCode SearchIndex(QueryResult &p_query, bool p_searchDeleted = false) const; + + std::shared_ptr GetIterator(const void* p_target, bool p_searchDeleted = false, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const; + ErrorCode SearchIndexIterativeNext(QueryResult& p_results, COMMON::WorkSpace* workSpace, int batch, int& resultCount, bool p_isFirst, bool p_searchDeleted = false) const; + ErrorCode SearchIndexIterativeEnd(std::unique_ptr workSpace) const; + ErrorCode SearchIndexIterativeEnd(std::unique_ptr extraWorkspace) const; + bool SearchIndexIterativeFromNeareast(QueryResult& p_query, COMMON::WorkSpace* p_space, bool p_isFirst, bool p_searchDeleted = false) const; + std::unique_ptr RentWorkSpace(int batch, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const; + ErrorCode SearchIndexIterative(QueryResult& p_headQuery, QueryResult& p_query, COMMON::WorkSpace* p_indexWorkspace, ExtraWorkSpace* p_extraWorkspace, int p_batch, int& resultCount, bool first) const; + + ErrorCode SearchIndexWithFilter(QueryResult& p_query, std::function filterFunc, int maxCheck = 0, bool p_searchDeleted = false) const; + + ErrorCode SearchDiskIndex(QueryResult& p_query, SearchStats* p_stats = nullptr) const; + ErrorCode SearchDiskIndexIterative(QueryResult& p_headQuery, QueryResult& p_query, ExtraWorkSpace* extraWorkspace) const; ErrorCode DebugSearchDiskIndex(QueryResult& p_query, int 
p_subInternalResultNum, int p_internalResultNum, - SearchStats* p_stats = nullptr, std::set* truth = nullptr, std::map>* found = nullptr); + SearchStats* p_stats = nullptr, std::set* truth = nullptr, std::map>* found = nullptr) const; ErrorCode UpdateIndex(); ErrorCode SetParameter(const char* p_param, const char* p_value, const char* p_section = nullptr); std::string GetParameter(const char* p_param, const char* p_section = nullptr) const; inline const void* GetSample(const SizeType idx) const { return nullptr; } - inline SizeType GetNumDeleted() const { return 0; } - inline bool NeedRefine() const { return false; } - + inline SizeType GetNumDeleted() const { return m_versionMap.GetDeleteCount(); } + inline bool NeedRefine() const + { + return m_versionMap.GetDeleteCount() > (size_t)(GetNumSamples() * m_options.m_fDeletePercentageForRefine); + } ErrorCode RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted = false) const { return ErrorCode::Undefined; } ErrorCode SearchTree(QueryResult& p_query) const { return ErrorCode::Undefined; } - ErrorCode AddIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, std::shared_ptr p_metadataSet, bool p_withMetaIndex = false, bool p_normalized = false) { return ErrorCode::Undefined; } - ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum) { return ErrorCode::Undefined; } - ErrorCode DeleteIndex(const SizeType& p_id) { return ErrorCode::Undefined; } - ErrorCode RefineIndex(const std::vector>& p_indexStreams, IAbortOperation* p_abort) { return ErrorCode::Undefined; } + ErrorCode AddIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, std::shared_ptr p_metadataSet, bool p_withMetaIndex = false, bool p_normalized = false); + ErrorCode DeleteIndex(const SizeType& p_id); + + ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum); + ErrorCode RefineIndex(const std::vector> &p_indexStreams, + IAbortOperation *p_abort, std::vector *p_mapping); ErrorCode 
RefineIndex(std::shared_ptr& p_newIndex) { return ErrorCode::Undefined; } + + ErrorCode Check() override; + + ErrorCode SetWorkSpaceFactory(std::unique_ptr> up_workSpaceFactory) + { + SPTAG::COMMON::IWorkSpaceFactory* raw_generic_ptr = up_workSpaceFactory.release(); + if (!raw_generic_ptr) return ErrorCode::Fail; + + + SPTAG::COMMON::IWorkSpaceFactory* raw_specialized_ptr = dynamic_cast*>(raw_generic_ptr); + if (!raw_specialized_ptr) + { + // If it is of type SPTAG::COMMON::WorkSpace, we should pass on to child index + if (!m_index) + { + delete raw_generic_ptr; + return ErrorCode::Fail; + } + else + { + return m_index->SetWorkSpaceFactory(std::unique_ptr>(raw_generic_ptr)); + } + + } + else + { + m_workSpaceFactory = std::unique_ptr>(raw_specialized_ptr); + return ErrorCode::Success; + } + } + + SizeType GetGlobalVID(SizeType vid) + { + return static_cast(*(m_vectorTranslateMap[vid])); + } + + ErrorCode GetPostingDebug(SizeType vid, std::vector& VIDs, std::shared_ptr& vecs); + private: bool CheckHeadIndexType(); void SelectHeadAdjustOptions(int p_vectorCount); @@ -147,6 +222,117 @@ namespace SPTAG bool SelectHeadInternal(std::shared_ptr& p_reader); ErrorCode BuildIndexInternal(std::shared_ptr& p_reader); + + public: + bool AllFinished() { if (m_options.m_storage != Storage::STATIC) return m_extraSearcher->AllFinished(); return true; } + + void GetDBStat() { + if (m_options.m_storage != Storage::STATIC) m_extraSearcher->GetDBStats(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Current Vector Num: %d, Deleted: %d .\n", GetNumSamples(), GetNumDeleted()); + } + + void GetIndexStat(int finishedInsert, bool cost, bool reset) { if (m_options.m_storage != Storage::STATIC) m_extraSearcher->GetIndexStats(finishedInsert, cost, reset); } + + void ForceCompaction() { if (m_options.m_storage == Storage::ROCKSDBIO) m_extraSearcher->ForceCompaction(); } + + void StopMerge() { m_options.m_inPlace = true; } + + void OpenMerge() { m_options.m_inPlace = false; } + + void ForceGC() { + 
auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } + workSpace->m_deduper.clear(); + workSpace->m_postingIDs.clear(); + m_extraSearcher->ForceGC(workSpace.get(), m_index.get()); + } + + ErrorCode Checkpoint() { + /** Lock & wait until all jobs done **/ + while (!AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + + /** Lock **/ + if (m_options.m_persistentBufferPath == "") return ErrorCode::FailedCreateFile; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Locking Index\n"); + std::unique_lock lock(m_checkPointLock); + + // Flush block pool states & block mapping states + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving storage states\n"); + ErrorCode ret; + if ((ret = m_extraSearcher->Checkpoint(m_options.m_persistentBufferPath)) != ErrorCode::Success) + return ret; + + /** Flush the checkpoint file: SPTAG states, block pool states, block mapping states **/ + std::string filename = m_options.m_persistentBufferPath + FolderSep + m_options.m_headIndexFolder; + // Flush SPTAG + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving in-memory index to %s\n", filename.c_str()); + if ((ret = m_index->SaveIndex(filename)) != ErrorCode::Success) + return ret; + return ErrorCode::Success; + } + + ErrorCode AddIndexSPFresh(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, SizeType* VID) { + if (m_options.m_storage == Storage::STATIC || m_extraSearcher == nullptr) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Only Support KV Extra Update\n"); + return ErrorCode::Fail; + } + + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) return ErrorCode::EmptyData; + if (p_dimension != GetFeatureDim()) return ErrorCode::DimensionSizeMismatch; + + std::shared_lock lock(m_checkPointLock); + + SizeType begin; + { + std::lock_guard 
lock(m_dataAddLock); + + begin = m_versionMap.GetVectorNum(); + + if (begin == 0) { return ErrorCode::EmptyIndex; } + + if (m_versionMap.AddBatch(p_vectorNum) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "MemoryOverFlow: VID: %d, Map Size:%d\n", begin, m_versionMap.BufferSize()); + return ErrorCode::MemoryOverFlow; + } + } + for (int i = 0; i < p_vectorNum; i++) VID[i] = begin + i; + + std::shared_ptr vectorSet; + if (m_options.m_distCalcMethod == DistCalcMethod::Cosine) { + ByteArray arr = ByteArray::Alloc(sizeof(T) * p_vectorNum * p_dimension); + memcpy(arr.Data(), p_data, sizeof(T) * p_vectorNum * p_dimension); + vectorSet.reset(new BasicVectorSet(arr, GetEnumValueType(), p_dimension, p_vectorNum)); + int base = COMMON::Utils::GetBase(); + for (SizeType i = 0; i < p_vectorNum; i++) { + COMMON::Utils::Normalize((T*)(vectorSet->GetVector(i)), p_dimension, base); + } + } + else { + vectorSet.reset(new BasicVectorSet(ByteArray((std::uint8_t*)p_data, sizeof(T) * p_vectorNum * p_dimension, false), + GetEnumValueType(), p_dimension, p_vectorNum)); + } + + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } + workSpace->m_deduper.clear(); + workSpace->m_postingIDs.clear(); + return m_extraSearcher->AddIndex(workSpace.get(), vectorSet, m_index, begin); + } }; } // namespace SPANN } // namespace SPTAG diff --git a/AnnService/inc/Core/SPANN/Options.h b/AnnService/inc/Core/SPANN/Options.h index 7c726c0a4..414111a1a 100644 --- a/AnnService/inc/Core/SPANN/Options.h +++ b/AnnService/inc/Core/SPANN/Options.h @@ -43,6 +43,8 @@ namespace SPTAG { bool m_deleteHeadVectors; int m_ssdIndexFileNum; std::string m_quantizerFilePath; + int m_datasetRowsInBlock; + int m_datasetCapacity; // Section 2: for selecting head bool m_selectHead; @@ -68,9 +70,6 @@ namespace SPTAG { 
bool m_recursiveCheckSmallCluster; bool m_printSizeCount; std::string m_selectType; - // Dataset constructor args - int m_datasetRowsInBlock; - int m_datasetCapacity; // Section 3: for build head bool m_buildHead; @@ -97,6 +96,18 @@ namespace SPTAG { float m_rngFactor; int m_samples; bool m_excludehead; + int m_postingVectorLimit; + std::string m_fullDeletedIDFile; + Storage m_storage; + std::string m_KVFile; + std::string m_ssdMappingFile; + std::string m_ssdInfoFile; + std::string m_checksumFile; + bool m_useDirectIO; + bool m_preReassign; + float m_preReassignRatio; + bool m_enableWAL; + bool m_disableCheckpoint; // GPU building int m_gpuSSDNumTrees; @@ -122,6 +133,68 @@ namespace SPTAG { bool m_enableADC; int m_iotimeout; + int m_searchThreadNum; + + // Calculating + std::string m_truthFilePrefix; + bool m_calTruth; + bool m_calAllTruth; + int m_searchTimes; + int m_minInternalResultNum; + int m_stepInternalResultNum; + int m_maxInternalResultNum; + bool m_onlySearchFinalBatch; + + // Updating + bool m_disableReassign; + bool m_searchDuringUpdate; + int m_reassignK; + bool m_recovery; + + // Updating(SPFresh Update Test) + bool m_update; + bool m_inPlace; + bool m_outOfPlace; + float m_latencyLimit; + int m_step; + int m_insertThreadNum; + int m_endVectorNum; + std::string m_persistentBufferPath; + int m_appendThreadNum; + int m_reassignThreadNum; + int m_batch; + std::string m_fullVectorPath; + + // Steady State Update + std::string m_updateFilePrefix; + std::string m_updateMappingPrefix; + int m_days; + int m_deleteQPS; + int m_sampling; + bool m_showUpdateProgress; + int m_mergeThreshold; + bool m_loadAllVectors; + bool m_steadyState; + int m_spdkBatchSize; + bool m_stressTest; + int m_bufferLength; + int m_maxFileSize; + int m_startFileSize; + int m_growthFileSize; + float m_growThreshold; + float m_fDeletePercentageForRefine; + bool m_oneClusterCutMax; + bool m_consistencyCheck; + bool m_checksumCheck; + bool m_checksumInRead; + int m_cacheSize; + int 
m_cacheShards; + + // Iterative + int m_headBatch; + int m_asyncAppendQueueSize; + bool m_allowZeroReplica; + Options() { #define DefineBasicParameter(VarName, VarType, DefaultValue, RepresentStr) \ VarName = DefaultValue; \ @@ -158,7 +231,7 @@ namespace SPTAG { #define DefineBasicParameter(VarName, VarType, DefaultValue, RepresentStr) \ if (Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ { \ - LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ VarType tmp; \ if (Helper::Convert::ConvertStringTo(p_value, tmp)) \ { \ @@ -175,7 +248,7 @@ namespace SPTAG { #define DefineSelectHeadParameter(VarName, VarType, DefaultValue, RepresentStr) \ if (Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ { \ - LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ VarType tmp; \ if (Helper::Convert::ConvertStringTo(p_value, tmp)) \ { \ @@ -192,7 +265,7 @@ namespace SPTAG { #define DefineBuildHeadParameter(VarName, VarType, DefaultValue, RepresentStr) \ if (Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ { \ - LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ VarType tmp; \ if (Helper::Convert::ConvertStringTo(p_value, tmp)) \ { \ @@ -209,7 +282,7 @@ namespace SPTAG { #define DefineSSDParameter(VarName, VarType, DefaultValue, RepresentStr) \ if (Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ { \ - LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ VarType tmp; \ if 
(Helper::Convert::ConvertStringTo(p_value, tmp)) \ { \ diff --git a/AnnService/inc/Core/SPANN/ParameterDefinitionList.h b/AnnService/inc/Core/SPANN/ParameterDefinitionList.h index 7e3704d63..61c752e4c 100644 --- a/AnnService/inc/Core/SPANN/ParameterDefinitionList.h +++ b/AnnService/inc/Core/SPANN/ParameterDefinitionList.h @@ -22,7 +22,7 @@ DefineBasicParameter(m_warmupSize, SPTAG::SizeType, -1, "WarmupSize") DefineBasicParameter(m_warmupDelimiter, std::string, std::string("|"), "WarmupDelimiter") DefineBasicParameter(m_truthPath, std::string, std::string(""), "TruthPath") DefineBasicParameter(m_truthType, SPTAG::TruthFileType, SPTAG::TruthFileType::Undefined, "TruthType") -DefineBasicParameter(m_generateTruth, bool, false, "GenerateTruth") +DefineBasicParameter(m_generateTruth, bool, false, "GenerateTruth") // Mutable DefineBasicParameter(m_indexDirectory, std::string, std::string("SPANN"), "IndexDirectory") DefineBasicParameter(m_headIDFile, std::string, std::string("SPTAGHeadVectorIDs.bin"), "HeadVectorIDs") DefineBasicParameter(m_deleteIDFile, std::string, std::string("DeletedIDs.bin"), "DeletedIDs") @@ -32,7 +32,8 @@ DefineBasicParameter(m_ssdIndex, std::string, std::string("SPTAGFullList.bin"), DefineBasicParameter(m_deleteHeadVectors, bool, false, "DeleteHeadVectors") DefineBasicParameter(m_ssdIndexFileNum, int, 1, "SSDIndexFileNum") DefineBasicParameter(m_quantizerFilePath, std::string, std::string(), "QuantizerFilePath") - +DefineBasicParameter(m_datasetRowsInBlock, int, 1024 * 1024, "DataBlockSize") +DefineBasicParameter(m_datasetCapacity, int, SPTAG::MaxSize, "DataCapacity") #endif #ifdef DefineSelectHeadParameter @@ -44,7 +45,7 @@ DefineSelectHeadParameter(m_iBKTLeafSize, int, 8, "BKTLeafSize") DefineSelectHeadParameter(m_iSamples, int, 1000, "SamplesNumber") DefineSelectHeadParameter(m_fBalanceFactor, float, -1.0F, "BKTLambdaFactor") -DefineSelectHeadParameter(m_iSelectHeadNumberOfThreads, int, 4, "NumberOfThreads") 
+DefineSelectHeadParameter(m_iSelectHeadNumberOfThreads, int, 4, "NumberOfThreads") // Mutable DefineSelectHeadParameter(m_saveBKT, bool, false, "SaveBKT") DefineSelectHeadParameter(m_analyzeOnly, bool, false, "AnalyzeOnly") @@ -61,10 +62,6 @@ DefineSelectHeadParameter(m_headVectorCount, int, 0, "Count") DefineSelectHeadParameter(m_recursiveCheckSmallCluster, bool, true, "RecursiveCheckSmallCluster") DefineSelectHeadParameter(m_printSizeCount, bool, true, "PrintSizeCount") DefineSelectHeadParameter(m_selectType, std::string, "BKT", "SelectHeadType") - -DefineSelectHeadParameter(m_datasetRowsInBlock, int, 1024 * 1024, "DataBlockSize") -DefineSelectHeadParameter(m_datasetCapacity, int, SPTAG::MaxSize, "DataCapacity") - #endif #ifdef DefineBuildHeadParameter @@ -76,7 +73,7 @@ DefineBuildHeadParameter(m_buildHead, bool, false, "isExecute") #ifdef DefineSSDParameter DefineSSDParameter(m_enableSSD, bool, false, "isExecute") DefineSSDParameter(m_buildSsdIndex, bool, false, "BuildSsdIndex") -DefineSSDParameter(m_iSSDNumberOfThreads, int, 16, "NumberOfThreads") +DefineSSDParameter(m_iSSDNumberOfThreads, int, 16, "NumberOfThreads") // Mutable DefineSSDParameter(m_enableDeltaEncoding, bool, false, "EnableDeltaEncoding") DefineSSDParameter(m_enablePostingListRearrange, bool, false, "EnablePostingListRearrange") DefineSSDParameter(m_enableDataCompression, bool, false, "EnableDataCompression") @@ -95,6 +92,20 @@ DefineSSDParameter(m_tmpdir, std::string, std::string("."), "TmpDir") DefineSSDParameter(m_rngFactor, float, 1.0f, "RNGFactor") DefineSSDParameter(m_samples, int, 100, "RecallTestSampleNumber") DefineSSDParameter(m_excludehead, bool, true, "ExcludeHead") +DefineSSDParameter(m_postingVectorLimit, int, 118, "PostingVectorLimit") +DefineSSDParameter(m_fullDeletedIDFile, std::string, std::string("fulldeleted"), "FullDeletedIDFile") +DefineSSDParameter(m_storage, SPTAG::Storage, SPTAG::Storage::STATIC, "Storage") +DefineSSDParameter(m_spdkBatchSize, int, 64, "SpdkBatchSize") 
+DefineSSDParameter(m_KVFile, std::string, std::string("rocksdb"), "KVFile") +DefineSSDParameter(m_ssdMappingFile, std::string, std::string("ssdmapping"), "SsdMappingFile") +DefineSSDParameter(m_ssdInfoFile, std::string, std::string("ssdinfo"), "SsdInfoFile") +DefineSSDParameter(m_checksumFile, std::string, std::string("checksum"), "ChecksumFile") +DefineSSDParameter(m_useDirectIO, bool, false, "UseDirectIO") +DefineSSDParameter(m_preReassign, bool, false, "PreReassign") +DefineSSDParameter(m_preReassignRatio, float, 0.7f, "PreReassignRatio") +DefineSSDParameter(m_bufferLength, int, 3, "BufferLength") +DefineSSDParameter(m_enableWAL, bool, false, "EnableWAL") +DefineSSDParameter(m_disableCheckpoint, bool, false, "DisableCheckpoint") // GPU Building DefineSSDParameter(m_gpuSSDNumTrees, int, 100, "GPUSSDNumTrees") @@ -105,19 +116,101 @@ DefineSSDParameter(m_numGPUs, int, 1, "NumGPUs") DefineSSDParameter(m_searchResult, std::string, std::string(""), "SearchResult") DefineSSDParameter(m_logFile, std::string, std::string(""), "LogFile") DefineSSDParameter(m_qpsLimit, int, 0, "QpsLimit") -DefineSSDParameter(m_resultNum, int, 5, "ResultNum") -DefineSSDParameter(m_truthResultNum, int, -1, "TruthResultNum") -DefineSSDParameter(m_maxCheck, int, 4096, "MaxCheck") +DefineSSDParameter(m_resultNum, int, 5, "ResultNum") // Mutable +DefineSSDParameter(m_truthResultNum, int, -1, "TruthResultNum") // Mutable +DefineSSDParameter(m_maxCheck, int, 4096, "MaxCheck") // Mutable DefineSSDParameter(m_hashExp, int, 4, "HashTableExponent") DefineSSDParameter(m_queryCountLimit, int, (std::numeric_limits::max)(), "QueryCountLimit") DefineSSDParameter(m_maxDistRatio, float, 10000, "MaxDistRatio") -DefineSSDParameter(m_ioThreads, int, 4, "IOThreadsPerHandler") -DefineSSDParameter(m_searchInternalResultNum, int, 64, "SearchInternalResultNum") -DefineSSDParameter(m_searchPostingPageLimit, int, 3, "SearchPostingPageLimit") +DefineSSDParameter(m_ioThreads, int, 4, "IOThreadsPerHandler") // Mutable 
+DefineSSDParameter(m_searchInternalResultNum, int, 64, "SearchInternalResultNum") // Mutable +DefineSSDParameter(m_searchPostingPageLimit, int, 3, "SearchPostingPageLimit") // Mutable DefineSSDParameter(m_rerank, int, 0, "Rerank") DefineSSDParameter(m_enableADC, bool, false, "EnableADC") DefineSSDParameter(m_recall_analysis, bool, false, "RecallAnalysis") DefineSSDParameter(m_debugBuildInternalResultNum, int, 64, "DebugBuildInternalResultNum") -DefineSSDParameter(m_iotimeout, int, 30, "IOTimeout") +DefineSSDParameter(m_iotimeout, int, 30, "IOTimeout") // Mutable + +// Calculating +// TruthFilePrefix +DefineSSDParameter(m_truthFilePrefix, std::string, std::string(""), "TruthFilePrefix") +// CalTruth +DefineSSDParameter(m_calTruth, bool, true, "CalTruth") // Mutable +DefineSSDParameter(m_onlySearchFinalBatch, bool, false, "OnlySearchFinalBatch") +// Search multiple times for stable result +DefineSSDParameter(m_searchTimes, int, 1, "SearchTimes") +// Frontend search threadnum +DefineSSDParameter(m_searchThreadNum, int, 2, "SearchThreadNum") // Mutable +// Show tradeoff of latency and acurracy +DefineSSDParameter(m_minInternalResultNum, int, -1, "MinInternalResultNum") +DefineSSDParameter(m_stepInternalResultNum, int, -1, "StepInternalResultNum") +DefineSSDParameter(m_maxInternalResultNum, int, -1, "MaxInternalResultNum") + +// Updating(SPFresh Update Test) +// For update mode: current only update +DefineSSDParameter(m_update, bool, false, "Update") +// For Test Mode +DefineSSDParameter(m_inPlace, bool, true, "InPlace") +DefineSSDParameter(m_outOfPlace, bool, false, "OutOfPlace") +// latency limit +DefineSSDParameter(m_latencyLimit, float, 10.0, "LatencyLimit") // Mutable +// Update batch size +DefineSSDParameter(m_step, int, 0, "Step") +// Frontend update threadnum +DefineSSDParameter(m_insertThreadNum, int, 1, "InsertThreadNum") // Mutable +// Update limit +DefineSSDParameter(m_endVectorNum, int, -1, "EndVectorNum") +// Persistent buffer path 
+DefineSSDParameter(m_persistentBufferPath, std::string, std::string(""), "PersistentBufferPath") +// Background append threadnum +DefineSSDParameter(m_appendThreadNum, int, 1, "AppendThreadNum") // Mutable +// Background reassign threadnum +DefineSSDParameter(m_reassignThreadNum, int, 0, "ReassignThreadNum") // Mutable +// Background process batch size +DefineSSDParameter(m_batch, int, 1000, "Batch") +// Total Vector Path +DefineSSDParameter(m_fullVectorPath, std::string, std::string(""), "FullVectorPath") +// Steady State: update trace +DefineSSDParameter(m_updateFilePrefix, std::string, std::string(""), "UpdateFilePrefix") +// Steady State: update mapping +DefineSSDParameter(m_updateMappingPrefix, std::string, std::string(""), "UpdateMappingPrefix") +// Steady State: days +DefineSSDParameter(m_days, int, 0, "Days") +// Steady State: deleteQPS +DefineSSDParameter(m_deleteQPS, int, -1, "DeleteQPS") +// Steady State: sampling +DefineSSDParameter(m_sampling, int, -1, "Sampling") +// Steady State: showUpdateProgress +DefineSSDParameter(m_showUpdateProgress, bool, true, "ShowUpdateProgress") +// Steady State: Merge Threshold +DefineSSDParameter(m_mergeThreshold, int, 10, "MergeThreshold") +// Steady State: showUpdateProgress +DefineSSDParameter(m_loadAllVectors, bool, false, "LoadAllVectors") +// Steady State: steady state +DefineSSDParameter(m_steadyState, bool, false, "SteadyState") +// Steady State: stress test +DefineSSDParameter(m_stressTest, bool, false, "StressTest") + +// SPANN +DefineSSDParameter(m_disableReassign, bool, false, "DisableReassign") +DefineSSDParameter(m_searchDuringUpdate, bool, false, "SearchDuringUpdate") +DefineSSDParameter(m_reassignK, int, 0, "ReassignK") +DefineSSDParameter(m_recovery, bool, false, "Recovery") +DefineSSDParameter(m_maxFileSize, int, 300, "MaxFileSizeGB") +DefineSSDParameter(m_startFileSize, int, 10, "StartFileSizeGB") +DefineSSDParameter(m_growthFileSize, int, 10, "GrowthFileSizeGB") +DefineSSDParameter(m_growThreshold, 
float, 0.05, "GrowthThreshold") +DefineSSDParameter(m_fDeletePercentageForRefine, float, 0.4F, "DeletePercentageForRefine") // Mutable +DefineSSDParameter(m_oneClusterCutMax, bool, false, "OneClusterCutMax") // Mutable +DefineSSDParameter(m_consistencyCheck, bool, false, "ConsistencyCheck") // Mutable +DefineSSDParameter(m_checksumCheck, bool, false, "ChecksumCheck") // Mutable +DefineSSDParameter(m_checksumInRead, bool, false, "ChecksumInRead") // Mutable +DefineSSDParameter(m_cacheSize, int, 0, "CacheSizeGB") // Mutable +DefineSSDParameter(m_cacheShards, int, 1, "CacheShards") // Mutable +DefineSSDParameter(m_asyncAppendQueueSize, int, 0, "AsyncAppendQueueSize") // Mutable +DefineSSDParameter(m_allowZeroReplica, bool, false, "AllowZeroReplica") + +// Iterative +DefineSSDParameter(m_headBatch, int, 32, "IterativeSearchHeadBatch") // Mutable #endif diff --git a/AnnService/inc/Core/SPANN/PersistentBuffer.h b/AnnService/inc/Core/SPANN/PersistentBuffer.h new file mode 100644 index 000000000..2fd5bca62 --- /dev/null +++ b/AnnService/inc/Core/SPANN/PersistentBuffer.h @@ -0,0 +1,56 @@ +#include "inc/Helper/KeyValueIO.h" +#include + +namespace SPTAG { + namespace SPANN { + // concurrently safe with RocksDBIO + class PersistentBuffer + { + public: + PersistentBuffer(std::shared_ptr db) : db(db), _size(0) { } + + ~PersistentBuffer() {} + + inline int GetNewAssignmentID() { return _size++; } + + inline int PutAssignment(std::string& assignment) + { + int assignmentID = GetNewAssignmentID(); + db->Put(assignmentID, assignment, MaxTimeout, nullptr); + return assignmentID; + } + + inline bool StartToScan(std::string& assignment) + { + SizeType newSize = 0; + if (db->StartToScan(newSize, &assignment) != ErrorCode::Success) return false; + _size = newSize+1; + return true; + } + + inline bool NextToScan(std::string& assignment) + { + SizeType newSize = 0; + if (db->NextToScan(newSize, &assignment) != ErrorCode::Success) return false; + _size = newSize+1; + return true; + } + + 
inline void ClearPreviousRecord() + { + db->DeleteRange(0, _size.load()); + _size = 0; + } + + inline int StopPB() + { + db->ShutDown(); + return 0; + } + + private: + std::shared_ptr db; + std::atomic_int _size; + }; + } +} \ No newline at end of file diff --git a/AnnService/inc/Core/SPANN/SPANNResultIterator.h b/AnnService/inc/Core/SPANN/SPANNResultIterator.h new file mode 100644 index 000000000..fde992b74 --- /dev/null +++ b/AnnService/inc/Core/SPANN/SPANNResultIterator.h @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef _SPTAG_SPANN_RESULT_ITERATOR_H +#define _SPTAG_SPANN_RESULT_ITERATOR_H + +#include + +#include "inc/Core/SPANN/Index.h" +#include "inc/Core/SearchQuery.h" +#include "inc/Core/ResultIterator.h" +#include "inc/Core/Common/WorkSpace.h" +#include "inc/Core/SPANN/IExtraSearcher.h" + +namespace SPTAG +{ + namespace SPANN + { + template + class SPANNResultIterator : public ResultIterator + { + public: + SPANNResultIterator(const Index* p_spannIndex, const VectorIndex* p_index, const void* p_target, + std::unique_ptr p_extraWorkspace, + int p_batch, int p_maxCheck): ResultIterator(p_index, p_target, false, p_batch, nullptr, p_maxCheck), + m_spannIndex(p_spannIndex), + m_extraWorkspace(std::move(p_extraWorkspace)), + m_errorCode(ErrorCode::Success) + { + m_headQueryResult = std::make_unique(p_target, p_batch, false); + } + + ~SPANNResultIterator() + { + Close(); + } + + virtual std::shared_ptr Next(int batch) + { + if (m_queryResult == nullptr) { + m_queryResult = std::make_unique(m_target, batch, true); + } + else if (batch <= m_queryResult->GetResultNum()) { + m_queryResult->SetResultNum(batch); + } + else { + batch = m_queryResult->GetResultNum(); + } + + m_queryResult->Reset(); + if (m_workspace == nullptr) return m_queryResult; + + int resultCount = 0; + m_errorCode = m_spannIndex->SearchIndexIterative(*m_headQueryResult, *m_queryResult, + (COMMON::WorkSpace*)GetWorkSpace(), 
m_extraWorkspace.get(), batch, resultCount, m_isFirstResult); + m_isFirstResult = false; + + for (int i = 0; i < resultCount; i++) + { + m_queryResult->GetResult(i)->RelaxedMono = m_extraWorkspace->m_relaxedMono; + } + m_queryResult->SetResultNum(resultCount); + return m_queryResult; + } + + virtual bool GetRelaxedMono() + { + if (m_extraWorkspace == nullptr) return false; + + return m_extraWorkspace->m_relaxedMono; + } + + virtual ErrorCode GetErrorCode() + { + return m_errorCode; + } + + virtual void Close() + { + ResultIterator::Close(); + if (m_extraWorkspace != nullptr) { + m_spannIndex->SearchIndexIterativeEnd(std::move(m_extraWorkspace)); + m_extraWorkspace = nullptr; + } + } + + private: + const Index* m_spannIndex; + std::unique_ptr m_headQueryResult; + std::unique_ptr m_extraWorkspace; + ErrorCode m_errorCode; + }; + }// namespace SPANN +} // namespace SPTAG +#endif diff --git a/AnnService/inc/Core/SearchQuery.h b/AnnService/inc/Core/SearchQuery.h index 20a841100..531cca483 100644 --- a/AnnService/inc/Core/SearchQuery.h +++ b/AnnService/inc/Core/SearchQuery.h @@ -22,12 +22,18 @@ class QueryResult : m_target(nullptr), m_resultNum(0), m_withMeta(false), - m_quantizedTarget(nullptr) + m_quantizedTarget(nullptr), + m_quantizedSize(0), + m_scanned(0) { } + QueryResult(int p_resultNum) + { + Init(nullptr, p_resultNum, true); + } - QueryResult(const void* p_target, int p_resultNum, bool p_withMeta) : m_quantizedTarget(nullptr), m_quantizedSize(0) + QueryResult(const void* p_target, int p_resultNum, bool p_withMeta) { Init(p_target, p_resultNum, p_withMeta); } @@ -37,8 +43,9 @@ class QueryResult : m_target(p_target), m_resultNum(p_resultNum), m_withMeta(p_withMeta), - m_quantizedTarget(nullptr), - m_quantizedSize(0) + m_quantizedTarget((void*)p_target), + m_quantizedSize(0), + m_scanned(0) { m_results.Set(p_results, p_resultNum, false); } @@ -51,7 +58,7 @@ class QueryResult { std::copy(p_other.m_results.Data(), p_other.m_results.Data() + m_resultNum, 
m_results.Data()); } - if (p_other.m_quantizedTarget) + if (p_other.m_target != p_other.m_quantizedTarget) { m_quantizedSize = p_other.m_quantizedSize; m_quantizedTarget = ALIGN_ALLOC(m_quantizedSize); @@ -62,19 +69,26 @@ class QueryResult QueryResult& operator=(const QueryResult& p_other) { + if (m_target != m_quantizedTarget) ALIGN_FREE(m_quantizedTarget); + Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta); if (m_resultNum > 0) { std::copy(p_other.m_results.Data(), p_other.m_results.Data() + m_resultNum, m_results.Data()); } - + if (p_other.m_target != p_other.m_quantizedTarget) + { + m_quantizedSize = p_other.m_quantizedSize; + m_quantizedTarget = ALIGN_ALLOC(m_quantizedSize); + std::copy(reinterpret_cast(p_other.m_quantizedTarget), reinterpret_cast(p_other.m_quantizedTarget) + m_quantizedSize, reinterpret_cast(m_quantizedTarget)); + } return *this; } ~QueryResult() { - if (m_quantizedTarget) + if (m_target != m_quantizedTarget) { ALIGN_FREE(m_quantizedTarget); } @@ -86,7 +100,9 @@ class QueryResult m_target = p_target; m_resultNum = p_resultNum; m_withMeta = p_withMeta; - m_quantizedTarget = nullptr; + m_quantizedTarget = (void*)p_target; + m_quantizedSize = 0; + m_scanned = 0; m_results = Array::Alloc(p_resultNum); } @@ -98,22 +114,49 @@ class QueryResult } + inline void SetResultNum(int p_resultNum) + { + m_resultNum = p_resultNum; + } + + inline const void* GetTarget() { return m_target; } + inline const void* GetQuantizedTarget() + { + return m_quantizedTarget; + } + + inline void SetTarget(const void* p_target) { - m_target = p_target; - if (m_quantizedTarget) + if (m_target != m_quantizedTarget) { ALIGN_FREE(m_quantizedTarget); } - m_quantizedTarget = nullptr; + m_target = p_target; + m_quantizedTarget = (void*)p_target; m_quantizedSize = 0; } + + + inline bool HasQuantizedTarget() + { + return m_target != m_quantizedTarget; + } + + + inline void CleanQuantizedTarget() + { + if (m_target != m_quantizedTarget) { + 
ALIGN_FREE(m_quantizedTarget); + m_quantizedTarget = (void*)m_target; + } + } inline BasicResult* GetResult(int i) const @@ -174,14 +217,15 @@ class QueryResult } } - inline void ClearTmp() + inline void SetScanned(int p_scanned) { - if (m_quantizedTarget) { - ALIGN_FREE(m_quantizedTarget); - m_quantizedTarget = nullptr; - } + m_scanned = p_scanned; } + inline int GetScanned() + { + return m_scanned; + } iterator begin() { @@ -219,6 +263,8 @@ class QueryResult bool m_withMeta; Array m_results; + + int m_scanned; }; } // namespace SPTAG diff --git a/AnnService/inc/Core/SearchResult.h b/AnnService/inc/Core/SearchResult.h index a84d4910d..8c458fe3b 100644 --- a/AnnService/inc/Core/SearchResult.h +++ b/AnnService/inc/Core/SearchResult.h @@ -67,12 +67,14 @@ namespace SPTAG SizeType VID; float Dist; ByteArray Meta; + bool RelaxedMono; - BasicResult() : VID(-1), Dist(MaxDist) {} + BasicResult() : VID(-1), Dist(MaxDist), RelaxedMono(false) {} - BasicResult(SizeType p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {} + BasicResult(SizeType p_vid, float p_dist) : VID(p_vid), Dist(p_dist), RelaxedMono(false) {} - BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta) : VID(p_vid), Dist(p_dist), Meta(p_meta) {} + BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta) : VID(p_vid), Dist(p_dist), Meta(p_meta), RelaxedMono(false) {} + BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta, bool p_relaxedMono) : VID(p_vid), Dist(p_dist), Meta(p_meta), RelaxedMono(p_relaxedMono) {} }; } // namespace SPTAG diff --git a/AnnService/inc/Core/VectorIndex.h b/AnnService/inc/Core/VectorIndex.h index 19896251c..953bf9618 100644 --- a/AnnService/inc/Core/VectorIndex.h +++ b/AnnService/inc/Core/VectorIndex.h @@ -4,17 +4,22 @@ #ifndef _SPTAG_VECTORINDEX_H_ #define _SPTAG_VECTORINDEX_H_ +#include #include "Common.h" +#include "Common/WorkSpace.h" +#include "inc/Helper/DiskIO.h" #include "SearchQuery.h" #include "VectorSet.h" #include "MetadataSet.h" #include 
"inc/Helper/SimpleIniReader.h" -#include #include "inc/Core/Common/IQuantizer.h" +class ResultIterator; + namespace SPTAG { +extern std::shared_ptr(*f_createIO)(); class IAbortOperation { public: @@ -31,19 +36,35 @@ class VectorIndex virtual ErrorCode BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized = false, bool p_shareOwnership = false) = 0; virtual ErrorCode AddIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, std::shared_ptr p_metadataSet, bool p_withMetaIndex = false, bool p_normalized = false) = 0; + virtual ErrorCode AddIndexId(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, int& beginHead, int& endHead) { return ErrorCode::Undefined; } + virtual ErrorCode AddIndexIdx(SizeType begin, SizeType end) { return ErrorCode::Undefined; } + virtual ErrorCode DeleteIndex(const void* p_vectors, SizeType p_vectorNum) = 0; virtual ErrorCode SearchIndex(QueryResult& p_results, bool p_searchDeleted = false) const = 0; + virtual std::shared_ptr GetIterator(const void* p_target, bool p_searchDeleted = false, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const = 0; + + virtual ErrorCode SearchIndexIterativeNext(QueryResult& p_query, COMMON::WorkSpace* workSpace, int p_batch, int& resultCount, bool p_isFirst, bool p_searchDeleted) const = 0; + + virtual ErrorCode SearchIndexIterativeEnd(std::unique_ptr workSpace) const = 0; + + virtual bool SearchIndexIterativeFromNeareast(QueryResult& p_query, COMMON::WorkSpace* p_space, bool p_isFirst, bool p_searchDeleted = false) const = 0; + + virtual std::unique_ptr RentWorkSpace(int batch, std::function p_filterFunc = nullptr, int p_maxCheck = 0) const = 0; + virtual ErrorCode RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted = false) const = 0; + virtual ErrorCode SearchIndexWithFilter(QueryResult& p_query, std::function filterFunc, int maxCheck = 0, bool p_searchDeleted = false) const = 0; + virtual ErrorCode 
SearchTree(QueryResult &p_query) const = 0; virtual ErrorCode RefineIndex(std::shared_ptr& p_newIndex) = 0; virtual float AccurateDistance(const void* pX, const void* pY) const = 0; virtual float ComputeDistance(const void* pX, const void* pY) const = 0; + virtual float GetDistance(const void* target, const SizeType idx) const = 0; virtual const void* GetSample(const SizeType idx) const = 0; virtual bool ContainSample(const SizeType idx) const = 0; virtual bool NeedRefine() const = 0; @@ -113,6 +134,28 @@ class VectorIndex virtual ErrorCode LoadQuantizer(std::string p_quantizerFile); + virtual std::shared_ptr GetQuantizer() { + return m_pQuantizer; + } + + virtual ErrorCode QuantizeVector(const void* p_data, SizeType p_num, ByteArray p_out) { + if (m_pQuantizer != nullptr && p_out.Length() >= m_pQuantizer->GetNumSubvectors() * (size_t)p_num) { + for (int i = 0; i < p_num; i++) + m_pQuantizer->QuantizeVector(((std::uint8_t*)p_data) + i * (size_t)(m_pQuantizer->ReconstructSize()), p_out.Data() + i * (size_t)(m_pQuantizer->GetNumSubvectors()), false); + return ErrorCode::Success; + } + return ErrorCode::Fail; + } + + virtual ErrorCode ReconstructVector(const void* p_data, SizeType p_num, ByteArray p_out) { + if (m_pQuantizer != nullptr && p_out.Length() >= m_pQuantizer->ReconstructSize() * (size_t)p_num) { + for (int i = 0; i < p_num; i++) + m_pQuantizer->ReconstructVector(((std::uint8_t*)p_data) + i * (size_t)(m_pQuantizer->GetNumSubvectors()), p_out.Data() + i * (size_t)(m_pQuantizer->ReconstructSize())); + return ErrorCode::Success; + } + return ErrorCode::Fail; + } + static std::shared_ptr CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype); static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr& p_vectorIndex); @@ -125,6 +168,8 @@ class VectorIndex static std::uint64_t EstimatedMemoryUsage(std::uint64_t p_vectorCount, DimensionType p_dimension, VectorValueType p_valuetype, SizeType p_vectorsInBlock, SizeType p_maxmeta, 
IndexAlgoType p_algo, int p_treeNumber, int p_neighborhoodSize); + virtual std::shared_ptr Clone(std::string p_clone); + virtual std::shared_ptr> BufferSize() const = 0; virtual std::shared_ptr> GetIndexFiles() const = 0; @@ -141,7 +186,9 @@ class VectorIndex virtual ErrorCode DeleteIndex(const SizeType& p_id) = 0; - virtual ErrorCode RefineIndex(const std::vector>& p_indexStreams, IAbortOperation* p_abort) = 0; + virtual ErrorCode RefineIndex(const std::vector>& p_indexStreams, IAbortOperation* p_abort, std::vector* p_mapping) = 0; + + virtual ErrorCode SetWorkSpaceFactory(std::unique_ptr> up_workSpaceFactory) = 0; inline bool HasMetaMapping() const { return nullptr != m_pMetaToVec; } @@ -151,7 +198,14 @@ class VectorIndex void BuildMetaMapping(bool p_checkDeleted = true); -private: + virtual ErrorCode Check() + { + return ErrorCode::Undefined; + } + + virtual std::string GetPriorityID(int queryID) const { return ""; } + + private: ErrorCode LoadIndexConfig(Helper::IniReader& p_reader); ErrorCode SaveIndexConfig(std::shared_ptr p_configOut); diff --git a/AnnService/inc/Helper/ArgumentsParser.h b/AnnService/inc/Helper/ArgumentsParser.h index 716e5abe9..6f97d31c6 100644 --- a/AnnService/inc/Helper/ArgumentsParser.h +++ b/AnnService/inc/Helper/ArgumentsParser.h @@ -118,7 +118,7 @@ class ArgumentsParser std::size_t padding = 40; if (!m_representStringShort.empty()) { - LOG(Helper::LogLevel::LL_Empty, "%s", m_representStringShort.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "%s", m_representStringShort.c_str()); padding -= m_representStringShort.size(); } @@ -126,26 +126,26 @@ class ArgumentsParser { if (!m_representStringShort.empty()) { - LOG(Helper::LogLevel::LL_Empty, ", "); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, ", "); padding -= 2; } - LOG(Helper::LogLevel::LL_Empty, "%s", m_representString.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "%s", m_representString.c_str()); padding -= m_representString.size(); } if (m_followedValue) { - 
LOG(Helper::LogLevel::LL_Empty, " "); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, " "); padding -= 8; } while (padding-- > 0) { - LOG(Helper::LogLevel::LL_Empty, " "); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, " "); } - LOG(Helper::LogLevel::LL_Empty, "%s", m_description.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "%s", m_description.c_str()); } diff --git a/AnnService/inc/Helper/AsyncFileReader.h b/AnnService/inc/Helper/AsyncFileReader.h index 8a46f94a4..46f63f2f2 100644 --- a/AnnService/inc/Helper/AsyncFileReader.h +++ b/AnnService/inc/Helper/AsyncFileReader.h @@ -15,60 +15,60 @@ #include #include -#define ASYNC_READ 1 - #ifdef _MSC_VER #include #include -#define BATCH_READ 1 #else -#define BATCH_READ 1 #include #include #include +#ifdef URING +#include +#endif +#ifdef NUMA +#include +#endif #endif +#define ASYNC_READ 1 +#define BATCH_READ 1 + namespace SPTAG { namespace Helper { - -#ifdef _MSC_VER - namespace DiskUtils + struct AsyncReadRequest { + std::uint64_t m_offset; + std::uint64_t m_readSize; + char* m_buffer; + std::function m_callback; + int m_status; - struct PrioritizedDiskFileReaderResource; + // Carry items like counter for callback to process. 
+ void* m_payload; + bool m_success; - struct CallbackOverLapped : public OVERLAPPED - { - PrioritizedDiskFileReaderResource* const c_registeredResource; + // Carry exension metadata needed by some DiskIO implementations + void* m_extension; + void* m_ctrl; - void* m_data; - - CallbackOverLapped(PrioritizedDiskFileReaderResource* p_registeredResource) - : c_registeredResource(p_registeredResource), - m_data(nullptr) - { - } - }; - - - struct PrioritizedDiskFileReaderResource - { - CallbackOverLapped m_col; +#ifdef _MSC_VER + DiskUtils::PrioritizedDiskFileReaderResource myres; +#else + struct iocb myiocb; +#endif - PrioritizedDiskFileReaderResource() - : m_col(this) - { - } - }; - } + AsyncReadRequest() : m_offset(0), m_readSize(0), m_buffer(nullptr), m_status(0), m_payload(nullptr), m_success(false), m_extension(nullptr) {} + }; + void SetThreadAffinity(int threadID, std::thread& thread, NumaStrategy socketStrategy = NumaStrategy::LOCAL, OrderStrategy idStrategy = OrderStrategy::ASC); +#ifdef _MSC_VER class HandleWrapper { public: HandleWrapper(HANDLE p_handle) : m_handle(p_handle) {} - HandleWrapper(HandleWrapper&& p_right) : m_handle(std::move(p_right.m_handle)) {} + HandleWrapper(HandleWrapper&& p_right) noexcept: m_handle(std::move(p_right.m_handle)) {} HandleWrapper() : m_handle(INVALID_HANDLE_VALUE) {} ~HandleWrapper() {} @@ -105,8 +105,6 @@ namespace SPTAG m_handle.Reset(::CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0)); } - void reset(int capacity) {} - void push(AsyncReadRequest* j) { ::PostQueuedCompletionStatus(m_handle.GetHandle(), 0, @@ -128,6 +126,20 @@ namespace SPTAG return true; } + bool try_pop(AsyncReadRequest*& j) { + DWORD cBytes; + ULONG_PTR key; + OVERLAPPED* ol; + BOOL ret = ::GetQueuedCompletionStatus(m_handle.GetHandle(), + &cBytes, + &key, + &ol, + 0); + if (FALSE == ret || nullptr == ol) return false; + j = reinterpret_cast(ol); + return true; + } + private: HandleWrapper m_handle; }; @@ -139,46 +151,114 @@ namespace SPTAG 
virtual ~AsyncFileIO() { ShutDown(); } + virtual bool Available() + { + return !m_shutdown; + } + + virtual bool NewFile(const char* filePath, uint64_t maxFileSize) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO: Create a file\n"); + m_fileHandle.Reset(::CreateFile(filePath, + GENERIC_WRITE, + 0, + NULL, + CREATE_ALWAYS, + FILE_ATTRIBUTE_NORMAL, + NULL)); + if (!m_fileHandle.IsValid()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileIO::InitializeFileIo failed\n"); + return false; + } + + LARGE_INTEGER pos = { 0 }; + pos.QuadPart = maxFileSize; + if (!SetFilePointerEx(m_fileHandle.GetHandle(), pos, NULL, FILE_BEGIN) || + !SetEndOfFile(m_fileHandle.GetHandle())) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileIO::InitializeFileIo failed: fallocate failed\n"); + m_fileHandle.Close(); + return false; + } + + m_fileHandle.Close(); + m_currSize = (uint64_t)filesize(filePath); + if (m_currSize != maxFileSize) SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot fallocate enough space actural size (%llu) < max size (%llu)\n", m_currSize, maxFileSize); + return true; + } + + virtual bool ExpandFile(uint64_t expandSize) { + LARGE_INTEGER new_pos = { 0 }; + new_pos.QuadPart = m_currSize + expandSize; + if (!SetFilePointerEx(m_fileHandle.GetHandle(), new_pos, NULL, FILE_BEGIN) || + !SetEndOfFile(m_fileHandle.GetHandle())) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "AsyncFileReader:[ExpandFile] fallocate failed at offset %lld for size %lld bytes\n", + static_cast(m_currSize), + static_cast(expandSize)); + m_fileHandle.Close(); + return false; + } + m_currSize += expandSize; + return true; + } + virtual bool Initialize(const char* filePath, int openMode, - std::uint64_t maxIOSize = (1 << 20), + std::uint64_t maxNumBlocks = (1 << 20), std::uint32_t maxReadRetries = 2, std::uint32_t maxWriteRetries = 2, - std::uint16_t threadPoolSize = 4) + std::uint16_t threadPoolSize = 4, + std::uint64_t maxFileSize = (300ULL << 30)) { + if (!fileexists(filePath)) { + if 
(openMode == GENERIC_READ) { + SPTAGLIB_LOG(LogLevel::LL_Error, "Failed to open file handle: %s\n", filePath); + return false; + } + if (!NewFile(filePath, maxFileSize)) return false; + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::Open a file\n"); + m_currSize = filesize(filePath); + if (openMode == GENERIC_READ || m_currSize >= maxFileSize) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::InitializeFileIo: file has been created with enough space.\n"); + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::Open failed! currSize(%llu) < maxFileSize(%llu)\n", m_currSize, maxFileSize); + if (!NewFile(filePath, maxFileSize)) return false; + } + } + m_fileHandle.Reset(::CreateFile(filePath, - GENERIC_READ, + GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_FLAG_NO_BUFFERING | FILE_FLAG_OVERLAPPED, NULL)); - if (!m_fileHandle.IsValid()) return false; + if (!m_fileHandle.IsValid()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileIO::Open failed for %s!\n", filePath); + return false; + } m_diskSectorSize = static_cast(GetSectorSize(filePath)); - LOG(LogLevel::LL_Info, "Success open file handle: %s DiskSectorSize: %u\n", filePath, m_diskSectorSize); - - PreAllocQueryContext(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::InitializeFileIo: file %s opened, DiskSectorSize=%u threads=%d maxNumBlocks=%d\n", filePath, m_diskSectorSize, threadPoolSize, maxNumBlocks); -#ifndef BATCH_READ + m_shutdown = false; int iocpThreads = threadPoolSize; m_fileIocp.Reset(::CreateIoCompletionPort(m_fileHandle.GetHandle(), NULL, NULL, iocpThreads)); for (int i = 0; i < iocpThreads; ++i) { - m_fileIocpThreads.emplace_back(std::thread(std::bind(&AsyncFileIO::ListionIOCP, this))); + m_fileIocpThreads.emplace_back(std::thread(std::bind(&AsyncFileIO::ListionIOCP, this, i))); } return m_fileIocp.IsValid(); -#else - return true; -#endif } virtual std::uint64_t ReadBinary(std::uint64_t readSize, char* buffer, 
std::uint64_t offset = UINT64_MAX) { - ResourceType* resource = GetResource(); + ResourceType resource; - DiskUtils::CallbackOverLapped& col = resource->m_col; - memset(&col, 0, sizeof(col)); + DiskUtils::CallbackOverLapped& col = resource.m_col; + memset(&col, 0, sizeof(OVERLAPPED)); col.Offset = (offset & 0xffffffff); col.OffsetHigh = (offset >> 32); col.m_data = nullptr; @@ -189,14 +269,12 @@ namespace SPTAG nullptr, &col) && GetLastError() != ERROR_IO_PENDING) { if (col.hEvent) CloseHandle(col.hEvent); - ReturnResource(resource); return 0; } DWORD cBytes = 0; ::GetOverlappedResult(m_fileHandle.GetHandle(), &col, &cBytes, TRUE); if (col.hEvent) CloseHandle(col.hEvent); - ReturnResource(resource); return cBytes; } @@ -217,76 +295,126 @@ namespace SPTAG virtual bool ReadFileAsync(AsyncReadRequest& readRequest) { - ResourceType* resource = GetResource(); - - DiskUtils::CallbackOverLapped& col = resource->m_col; - memset(&col, 0, sizeof(col)); + DiskUtils::CallbackOverLapped& col = readRequest.myres.m_col; col.Offset = (readRequest.m_offset & 0xffffffff); col.OffsetHigh = (readRequest.m_offset >> 32); - col.m_data = (void*)&readRequest; if (!::ReadFile(m_fileHandle.GetHandle(), readRequest.m_buffer, - static_cast(readRequest.m_readSize), + static_cast(ROUND_UP(readRequest.m_readSize, PageSize)), nullptr, &col) && GetLastError() != ERROR_IO_PENDING) { - ReturnResource(resource); return false; } return true; } - virtual bool BatchReadFile(AsyncReadRequest* readRequests, std::uint32_t requestCount) + virtual void Wait(AsyncReadRequest& readRequest) { - std::vector waitResources; - for (std::uint32_t i = 0; i < requestCount; i++) { - AsyncReadRequest* readRequest = &(readRequests[i]); - - ResourceType* resource = GetResource(); - DiskUtils::CallbackOverLapped& col = resource->m_col; - memset(&col, 0, sizeof(col)); - col.Offset = (readRequest->m_offset & 0xffffffff); - col.OffsetHigh = (readRequest->m_offset >> 32); - col.m_data = (void*)readRequest; - - col.hEvent = 
CreateEvent(NULL, TRUE, FALSE, NULL); - if (!::ReadFile(m_fileHandle.GetHandle(), - readRequest->m_buffer, - static_cast(readRequest->m_readSize), - nullptr, - &col) && GetLastError() != ERROR_IO_PENDING) - { - if (col.hEvent) CloseHandle(col.hEvent); - ReturnResource(resource); - LOG(Helper::LogLevel::LL_Error, "Cannot read request %d error:%u\n", i, GetLastError()); + // currently not used anywhere, effective only when ASYNC_READ is defined and BATCH_READ is not defined + throw std::runtime_error("Not implemented"); + } + + virtual std::uint32_t BatchReadFile(AsyncReadRequest* readRequests, std::uint32_t requestCount, const std::chrono::microseconds& timeout, int batchSize = -1) + { + if (requestCount <= 0) return 0; + + auto t1 = std::chrono::high_resolution_clock::now(); + if (batchSize < 0 || batchSize > requestCount) batchSize = requestCount; + + RequestQueue* completeQueue = reinterpret_cast(readRequests[0].m_extension); + uint32_t batchTotalDone = 0, skip = 0; + uint32_t currId = 0; + AsyncReadRequest* req; + while (currId < requestCount) { + int totalSubmitted = 0, totalDone = 0; + uint32_t batchStart = currId; + while (currId < requestCount && totalSubmitted < batchSize) { + AsyncReadRequest* readRequest = &(readRequests[currId++]); + if (readRequest->m_readSize == 0) { + skip++; + continue; + } + if (ReadFileAsync(*readRequest)) totalSubmitted++; + } - else { - waitResources.push_back(resource); + while (totalDone < totalSubmitted) { + if (!completeQueue->pop(req)) break; + totalDone++; + if (req->m_callback) req->m_callback(true); + } + + batchTotalDone += totalDone; + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) { + if (batchTotalDone < requestCount - skip) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "AsyncFileReader::ReadBlocks (batch[%u:%u]) : timeout, continue for next batch...\n", batchStart, currId); + } + //break; } } + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, 
"AsyncFileReader::ReadBlocks: finish\n"); + return batchTotalDone; + } - for (int i = 0; i < waitResources.size(); i++) { - DWORD readbytes = 0; - DiskUtils::CallbackOverLapped* col = &(waitResources[i]->m_col); - AsyncReadRequest* readRequest = (AsyncReadRequest*)(col->m_data); - - if (!::GetOverlappedResult(m_fileHandle.GetHandle(), col, &readbytes, TRUE) || readbytes != readRequest->m_readSize) { - LOG(Helper::LogLevel::LL_Error, "Cannot wait request %d readbytes:%llu expect:%llu error:%u\n", i, readbytes, readRequest->m_readSize, GetLastError()); + virtual std::uint32_t BatchWriteFile(AsyncReadRequest* readRequests, std::uint32_t requestCount, const std::chrono::microseconds& timeout, int batchSize = -1) + { + if (requestCount <= 0) return 0; + + auto t1 = std::chrono::high_resolution_clock::now(); + if (batchSize < 0 || batchSize > requestCount) batchSize = requestCount; + + RequestQueue* completeQueue = reinterpret_cast(readRequests[0].m_extension); + uint32_t batchTotalDone = 0; + uint32_t currId = 0; + AsyncReadRequest* req; + while (currId < requestCount) { + int totalSubmitted = 0, totalDone = 0; + uint32_t batchStart = currId; + while (currId < requestCount && totalSubmitted < batchSize) { + AsyncReadRequest* readRequest = &(readRequests[currId++]); + + DiskUtils::CallbackOverLapped& col = readRequest->myres.m_col; + col.Offset = (readRequest->m_offset & 0xffffffff); + col.OffsetHigh = (readRequest->m_offset >> 32); + + if (!::WriteFile(m_fileHandle.GetHandle(), + readRequest->m_buffer, + static_cast(PageSize), + nullptr, + &col) && GetLastError() != ERROR_IO_PENDING) + { + continue; + } + totalSubmitted++; } - else { - readRequest->m_success = true; + while (totalDone < totalSubmitted) { + if (!completeQueue->pop(req)) break; + totalDone++; + if (req->m_callback) req->m_callback(true); + } + + batchTotalDone += totalDone; + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) { + if (batchTotalDone < 
requestCount) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "AsyncFileReader::WriteBlocks (batch[%u:%u]) : timeout, continue for next batch...\n", batchStart, currId); + } + //break; } - if (col->hEvent) CloseHandle(col->hEvent); - ReturnResource(waitResources[i]); } - return true; + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileReader::WriteBlocks: %d reqs finish\n", requestCount); + return batchTotalDone; } virtual std::uint64_t TellP() { return 0; } virtual void ShutDown() { + if (m_shutdown) return; + + m_shutdown = true; m_fileHandle.Close(); m_fileIocp.Close(); @@ -297,15 +425,7 @@ namespace SPTAG th.join(); } } - - ResourceType* res = nullptr; - while (m_resources.try_pop(res)) - { - if (res != nullptr) - { - delete res; - } - } + m_fileIocpThreads.clear(); } private: @@ -388,14 +508,17 @@ namespace SPTAG // Display the error message and exit the process - LOG(Helper::LogLevel::LL_Error, "Failed with: %s\n", (char*)lpMsgBuf); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed with: %s\n", (char*)lpMsgBuf); LocalFree(lpMsgBuf); - ExitProcess(dw); + ExitProcess(dw); + ShutDown(); } - void ListionIOCP() + void ListionIOCP(int i) { + SetThreadAffinity(i, m_fileIocpThreads[i], NumaStrategy::SCATTER, OrderStrategy::DESC); // avoid IO threads overlap with search threads + DWORD cBytes; ULONG_PTR key; OVERLAPPED* ol; @@ -410,144 +533,140 @@ namespace SPTAG INFINITE); if (FALSE == ret || nullptr == ol) { - //We exit the thread + //SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ListionIOCP thread failed!\n"); return; } col = (DiskUtils::CallbackOverLapped*)ol; auto req = static_cast(col->m_data); - ReturnResource(col->c_registeredResource); if (nullptr != req) { - req->m_callback(req); + reinterpret_cast(req->m_extension)->push(req); } } } - ResourceType* GetResource() - { - ResourceType* ret = nullptr; - if (m_resources.try_pop(ret)) - { - } - else - { - ret = new ResourceType(); - } - - return ret; - } - - void ReturnResource(ResourceType* p_res) 
- { - if (p_res != nullptr) - { - m_resources.push(p_res); - } - } - - void PreAllocQueryContext() - { - const size_t num = 64 * 64; - typedef ResourceType* ResourcePtr; - - ResourcePtr* contextArray = new ResourcePtr[num]; - - for (int i = 0; i < num; ++i) - { - contextArray[i] = GetResource(); - } - - for (int i = 0; i < num; ++i) - { - ReturnResource(contextArray[i]); - contextArray[i] = nullptr; - } - - delete[] contextArray; - - } - private: + bool m_shutdown = true; + HandleWrapper m_fileHandle; + uint64_t m_currSize; + HandleWrapper m_fileIocp; std::vector m_fileIocpThreads; - uint32_t m_diskSectorSize; - - Helper::Concurrent::ConcurrentQueue m_resources; + uint32_t m_diskSectorSize = 512; }; #else extern struct timespec AIOTimeout; - class RequestQueue + using RequestQueue = Helper::Concurrent::ConcurrentQueue; + + class AsyncFileIO : public DiskIO { public: + AsyncFileIO(DiskIOScenario scenario = DiskIOScenario::DIS_UserRead) {} - RequestQueue() :m_front(0), m_end(0), m_capacity(0) {} + virtual ~AsyncFileIO() { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileReader: Destroying fd=%d!\n", m_fileHandle); + ShutDown(); + } - ~RequestQueue() {} + virtual bool Available() + { + return !m_shutdown; + } - void reset(int capacity) { - if (capacity > m_capacity) { - m_capacity = capacity + 1; - m_queue.reset(new AsyncReadRequest * [m_capacity]); + virtual bool NewFile(const char* filePath, uint64_t maxFileSize) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO: Create a file\n"); + m_fileHandle = open(filePath, O_CREAT | O_WRONLY, 0666); + if (m_fileHandle == -1) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileIO::InitializeFileIo failed\n"); + return false; } - } - void push(AsyncReadRequest* j) - { - m_queue[m_end++] = j; - if (m_end == m_capacity) m_end = 0; + if (fallocate(m_fileHandle, 0, 0, maxFileSize) == -1) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileIO::InitializeFileIo failed: fallocate failed\n"); + m_fileHandle = -1; + 
return false; + } + close(m_fileHandle); + m_currSize = filesize(filePath); + if (m_currSize != maxFileSize) SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot fallocate enough space actural size (%llu) < max size (%llu)\n", m_currSize, maxFileSize); + return true; } - bool pop(AsyncReadRequest*& j) - { - while (m_front == m_end) usleep(AIOTimeout.tv_nsec / 1000); - j = m_queue[m_front++]; - if (m_front == m_capacity) m_front = 0; + virtual bool ExpandFile(uint64_t expandSize) { + if (fallocate(m_fileHandle, 0, m_currSize, expandSize) != 0) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "AsyncFileReader:[ExpandFile] fallocate failed at offset %lld for size %lld bytes: %s\n", + static_cast(m_currSize), + static_cast(expandSize), + strerror(errno)); + return false; + } + m_currSize += expandSize; return true; } - protected: - int m_front, m_end, m_capacity; - std::unique_ptr m_queue; - }; - - class AsyncFileIO : public DiskIO - { - public: - AsyncFileIO(DiskIOScenario scenario = DiskIOScenario::DIS_UserRead) {} - - virtual ~AsyncFileIO() { ShutDown(); } - virtual bool Initialize(const char* filePath, int openMode, - std::uint64_t maxIOSize = (1 << 20), + std::uint64_t maxNumBlocks = (1 << 20), std::uint32_t maxReadRetries = 2, std::uint32_t maxWriteRetries = 2, - std::uint16_t threadPoolSize = 4) + std::uint16_t threadPoolSize = 4, + std::uint64_t maxFileSize = (300ULL << 30)) { - m_fileHandle = open(filePath, O_RDONLY | O_DIRECT); - if (m_fileHandle <= 0) { - LOG(LogLevel::LL_Error, "Failed to create file handle: %s\n", filePath); + if (!fileexists(filePath)) { + if (openMode == (O_RDONLY | O_DIRECT)) { + SPTAGLIB_LOG(LogLevel::LL_Error, "Failed to open file handle: %s\n", filePath); + return false; + } + if (!NewFile(filePath, maxFileSize)) return false; + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::Open a file\n"); + m_currSize = filesize(filePath); + if (openMode == (O_RDONLY | O_DIRECT) || m_currSize >= maxFileSize) { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::InitializeFileIo: file has been created with enough space.\n"); + } else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::Open failed! currSize(%llu) < maxFileSize(%llu)\n", m_currSize, maxFileSize); + if (!NewFile(filePath, maxFileSize)) return false; + } + } + m_fileHandle = open(filePath, O_RDWR | O_DIRECT); + if (m_fileHandle == -1) { + auto err_str = strerror(errno); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileIO::Open failed: %s\n", err_str); return false; } - + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::InitializeFileIo: file %s opened, fd=%d threads=%d maxNumBlocks=%d\n", filePath, m_fileHandle, threadPoolSize, maxNumBlocks); m_iocps.resize(threadPoolSize); +#ifdef URING + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileIO::InitializeFileIo: using io uring for read!\n"); + m_uring.resize(threadPoolSize); +#endif memset(m_iocps.data(), 0, sizeof(aio_context_t) * threadPoolSize); for (int i = 0; i < threadPoolSize; i++) { - auto ret = syscall(__NR_io_setup, (int)maxIOSize, &(m_iocps[i])); + auto ret = syscall(__NR_io_setup, (int)maxNumBlocks, &(m_iocps[i])); if (ret < 0) { - LOG(LogLevel::LL_Error, "Cannot setup aio: %s\n", strerror(errno)); + SPTAGLIB_LOG(LogLevel::LL_Error, "Cannot setup aio: %s\n", strerror(errno)); + return false; + } +#ifdef URING + ret = io_uring_queue_init((int)maxNumBlocks, &m_uring[i], 0); + if (ret < 0) + { + SPTAGLIB_LOG(LogLevel::LL_Error, "Cannot setup io_uring: %s\n", strerror(-ret)); return false; } +#endif } + m_shutdown = false; #ifndef BATCH_READ - m_shutdown = false; for (int i = 0; i < threadPoolSize; ++i) { m_fileIocpThreads.emplace_back(std::thread(std::bind(&AsyncFileIO::ListionIOCP, this, i))); @@ -576,19 +695,220 @@ namespace SPTAG return 0; } + virtual std::uint32_t BatchReadFile(AsyncReadRequest* readRequests, std::uint32_t requestCount, const std::chrono::microseconds& timeout, int batchSize = -1) + { + if (requestCount <= 0) 
return 0; + + auto t1 = std::chrono::high_resolution_clock::now(); + if (batchSize < 0 || batchSize > requestCount) batchSize = requestCount; + std::vector iocbs(batchSize); + std::vector events(batchSize); + if (readRequests[0].m_status < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "The iocp number is not enough! Please increase iothread number.\n"); + return 0; + } + int iocp = readRequests[0].m_status % m_iocps.size(); + //struct timespec timeout_ts {0, 0}; + //while (syscall(__NR_io_getevents, m_iocps[iocp], batchSize, batchSize, events.data(), &timeout_ts) > 0); + + uint32_t realCount = 0; + for (int i = 0; i < requestCount; i++) { + AsyncReadRequest* readRequest = readRequests + i; +#ifndef URING + struct iocb* myiocb = &(readRequest->myiocb); + memset(myiocb, 0, sizeof(struct iocb)); + myiocb->aio_data = reinterpret_cast(&readRequest); + myiocb->aio_lio_opcode = IOCB_CMD_PREAD; + myiocb->aio_fildes = m_fileHandle; + myiocb->aio_buf = (std::uint64_t)(readRequest->m_buffer); + myiocb->aio_nbytes = PageSize; + myiocb->aio_offset = static_cast(readRequest->m_offset); +#endif + if (readRequest->m_readSize > 0) realCount++; + } + uint32_t batchTotalDone = 0; + uint32_t reqidx = 0; + for (int currSubIoStartId = 0; currSubIoStartId < realCount; currSubIoStartId += batchSize) { + int currSubIoEndId = (currSubIoStartId + batchSize) > realCount ? 
realCount : currSubIoStartId + batchSize; + int totalToSubmit = currSubIoEndId - currSubIoStartId; + int totalSubmitted = 0, totalDone = 0; + for (int i = 0; i < totalToSubmit; i++) { + while (reqidx < requestCount && readRequests[reqidx].m_readSize == 0) reqidx++; + if (reqidx >= requestCount) SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ReadBlocks: error reqidx(%d) >= requestCount(%d)\n", reqidx, requestCount); + else { +#ifdef URING + auto sqe = io_uring_get_sqe(&m_uring[iocp]); + if (!sqe) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ReadBlocks: io_uring_get_sqe failed\n"); + return 0; + } + io_uring_prep_read(sqe, m_fileHandle, readRequests[reqidx].m_buffer, PageSize, + readRequests[reqidx].m_offset); + io_uring_sqe_set_data(sqe, &(readRequests[reqidx])); +#else + iocbs[i] = &(readRequests[reqidx].myiocb); +#endif + reqidx++; + } + } + + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileReader::ReadBlocks: iocp:%d totalToSubmit:%d\n", iocp, totalToSubmit); +#ifdef URING + if (io_uring_submit(&m_uring[iocp]) < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ReadBlocks: io_uring_submit failed!\n"); + return 0; + } + struct io_uring_cqe *cqe; + for (auto i = 0; i < totalToSubmit; i++) + { + int ret = io_uring_wait_cqe(&m_uring[iocp], &cqe); + if (ret < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ReadBlocks: io_uring_wait_cqe failed: %s\n", strerror(-ret)); + break; + } + ret = cqe->res; + io_uring_cqe_seen(&m_uring[iocp], cqe); + if (ret < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ReadBlocks: io_uring_wait_cqe failed with res=%d: %s\n", ret, strerror(-ret)); + continue; + } + // AsyncReadRequest* readRequest = + // reinterpret_cast(io_uring_cqe_get_data(cqe)); + totalDone++; + } +#else + while (totalDone < totalToSubmit) { + // Submit all I/Os + if (totalSubmitted < totalToSubmit) { + int s = syscall(__NR_io_submit, m_iocps[iocp], totalToSubmit - 
totalSubmitted, iocbs.data() + totalSubmitted); + if (s > 0) { + totalSubmitted += s; + } else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ReadBlocks: io_submit failed\n"); + return batchTotalDone; + } + } + int wait = totalSubmitted - totalDone; + auto d = syscall(__NR_io_getevents, m_iocps[iocp], wait, wait, events.data() + totalDone, &AIOTimeout); + if (d >= 0) { + totalDone += d; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::ReadBlocks: io_getevents failed\n"); + return batchTotalDone; + } + } +#endif + batchTotalDone += totalDone; + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) { + if (batchTotalDone < requestCount) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "AsyncFileReader::ReadBlocks (batch[%d:%d]) : timeout, continue for next batch...\n", currSubIoStartId, currSubIoEndId); + } + //break; + } + } + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileReader::ReadBlocks: finish\n"); + return batchTotalDone; + } + + virtual std::uint32_t BatchWriteFile(AsyncReadRequest* readRequests, std::uint32_t requestCount, const std::chrono::microseconds& timeout, int batchSize = -1) + { + if (requestCount <= 0) return 0; + + auto t1 = std::chrono::high_resolution_clock::now(); + if (batchSize < 0 || batchSize > requestCount) batchSize = requestCount; + + std::vector iocbs(batchSize); + std::vector events(batchSize); + if (readRequests[0].m_status < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "The iocp number is not enough! 
Please increase iothread number.\n"); + return 0; + } + int iocp = readRequests[0].m_status % m_iocps.size(); + //struct timespec timeout_ts {0, 0}; + //while (syscall(__NR_io_getevents, m_iocps[iocp], batchSize, batchSize, events.data(), &timeout_ts) > 0); + + + for (int i = 0; i < requestCount; i++) { + AsyncReadRequest* readRequest = readRequests + i; + struct iocb* myiocb = &(readRequest->myiocb); + memset(myiocb, 0, sizeof(struct iocb)); + myiocb->aio_data = reinterpret_cast(&readRequest); + myiocb->aio_lio_opcode = IOCB_CMD_PWRITE; + myiocb->aio_fildes = m_fileHandle; + myiocb->aio_buf = (std::uint64_t)(readRequest->m_buffer); + //myiocb->aio_nbytes = readRequest->m_readSize; + myiocb->aio_nbytes = PageSize; + myiocb->aio_offset = static_cast(readRequest->m_offset); + } + + uint32_t batchTotalDone = 0; + for (int currSubIoStartId = 0; currSubIoStartId < requestCount; currSubIoStartId += batchSize) { + int currSubIoEndId = (currSubIoStartId + batchSize) > requestCount ? requestCount : currSubIoStartId + batchSize; + int totalToSubmit = currSubIoEndId - currSubIoStartId; + int totalSubmitted = 0, totalDone = 0; + for (int i = 0; i < totalToSubmit; i++) { + iocbs[i] = &(readRequests[currSubIoStartId + i].myiocb); + } + + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileReader::WriteBlocks: iocp:%d totalToSubmit:%d\n", iocp, totalToSubmit); + while (totalDone < totalToSubmit) { + // Submit all I/Os + if (totalSubmitted < totalToSubmit) { + int s = syscall(__NR_io_submit, m_iocps[iocp], totalToSubmit - totalSubmitted, iocbs.data() + totalSubmitted); + if (s > 0) { + totalSubmitted += s; + } else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "AsyncFileReader::WriteBlocks: io_submit failed\n"); + return batchTotalDone; + } + } + int wait = totalSubmitted - totalDone; + auto d = syscall(__NR_io_getevents, m_iocps[iocp], wait, wait, events.data() + totalDone, &AIOTimeout); + if (d >= 0) { + totalDone += d; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, 
"AsyncFileReader::WriteBlocks: io_getevents failed\n"); + return batchTotalDone; + } + } + batchTotalDone += totalDone; + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) { + if (batchTotalDone < requestCount) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "AsyncFileReader::WriteBlocks (batch[%d:%d]) : timeout, continue for next batch...\n", currSubIoStartId, currSubIoEndId); + } + //break; + } + } + //SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileReader::WriteBlocks: %d reqs finish\n", requestCount); + return batchTotalDone; + } + virtual bool ReadFileAsync(AsyncReadRequest& readRequest) { - struct iocb myiocb = { 0 }; - myiocb.aio_data = reinterpret_cast(&readRequest); - myiocb.aio_lio_opcode = IOCB_CMD_PREAD; - myiocb.aio_fildes = m_fileHandle; - myiocb.aio_buf = (std::uint64_t)(readRequest.m_buffer); - myiocb.aio_nbytes = readRequest.m_readSize; - myiocb.aio_offset = static_cast(readRequest.m_offset); - - struct iocb* iocbs[1] = { &myiocb }; + struct iocb *myiocb = &(readRequest.myiocb); + myiocb->aio_data = reinterpret_cast(&readRequest); + myiocb->aio_lio_opcode = IOCB_CMD_PREAD; + myiocb->aio_fildes = m_fileHandle; + myiocb->aio_buf = (std::uint64_t)(readRequest.m_buffer); + myiocb->aio_nbytes = readRequest.m_readSize; + myiocb->aio_offset = static_cast(readRequest.m_offset); + + int iocp = (readRequest.m_status & 0xffff) % m_iocps.size(); + struct iocb* iocbs[1] = { myiocb }; int curTry = 0, maxTry = 10; - while (curTry < maxTry && syscall(__NR_io_submit, m_iocps[(readRequest.m_status & 0xffff) % m_iocps.size()], 1, iocbs) < 1) { + while (curTry < maxTry && syscall(__NR_io_submit, m_iocps[iocp], 1, iocbs) < 1) { usleep(AIOTimeout.tv_nsec / 1000); curTry++; } @@ -600,10 +920,16 @@ namespace SPTAG virtual void ShutDown() { + if (m_shutdown) return; + + m_shutdown = true; for (int i = 0; i < m_iocps.size(); i++) syscall(__NR_io_destroy, m_iocps[i]); +#ifdef URING + for (int i = 0; i < m_uring.size(); 
i++) io_uring_queue_exit(&m_uring[i]); +#endif close(m_fileHandle); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "AsyncFileReader: ShutDown!\n"); #ifndef BATCH_READ - m_shutdown = true; for (auto& th : m_fileIocpThreads) { if (th.joinable()) @@ -614,7 +940,7 @@ namespace SPTAG #endif } - aio_context_t& GetIOCP(int i) { return m_iocps[i]; } + aio_context_t& GetIOCP(int i) { return m_iocps[i % m_iocps.size()]; } int GetFileHandler() { return m_fileHandle; } @@ -631,22 +957,28 @@ namespace SPTAG AsyncReadRequest* req = reinterpret_cast((events[r].data)); if (nullptr != req) { - req->m_callback(req); + req->m_callback(true); } } } } - bool m_shutdown; - std::vector m_fileIocpThreads; #endif + bool m_shutdown = true; + int m_fileHandle; - std::vector m_iocps; + uint64_t m_currSize; + + std::vector m_iocps; + +#ifdef URING + std::vector m_uring; +#endif }; #endif - void BatchReadFileAsync(std::vector>& handlers, AsyncReadRequest* readRequests, int num); + bool BatchReadFileAsync(std::vector>& handlers, AsyncReadRequest* readRequests, int num); } } diff --git a/AnnService/inc/Helper/ConcurrentSet.h b/AnnService/inc/Helper/ConcurrentSet.h index 1dfdf4473..49210a29e 100644 --- a/AnnService/inc/Helper/ConcurrentSet.h +++ b/AnnService/inc/Helper/ConcurrentSet.h @@ -5,15 +5,25 @@ #define _SPTAG_HELPER_CONCURRENTSET_H_ #ifndef _MSC_VER +#ifdef TBB +#include +#include +#include +#include +#else +#include #include #include #include #include +#endif // TBB #else #include #include #include -#endif +#include +#endif // _MSC_VER + namespace SPTAG { namespace Helper @@ -21,9 +31,24 @@ namespace SPTAG namespace Concurrent { #ifndef _MSC_VER +#ifdef TBB + template + using ConcurrentSet = tbb::concurrent_unordered_set; + + template + using ConcurrentMap = tbb::concurrent_unordered_map; + + template + using ConcurrentQueue = tbb::concurrent_queue; + + template + using ConcurrentPriorityQueue = tbb::concurrent_priority_queue; +#else template class ConcurrentSet { + typedef typename 
std::unordered_set::iterator iterator; + public: ConcurrentSet() { m_lock.reset(new std::shared_timed_mutex); } @@ -41,10 +66,10 @@ namespace SPTAG return m_data.count(key); } - void insert(const T& key) + std::pair insert(const T& key) { std::unique_lock lock(*m_lock); - m_data.insert(key); + return m_data.insert(key); } private: @@ -56,6 +81,7 @@ namespace SPTAG class ConcurrentMap { typedef typename std::unordered_map::iterator iterator; + public: ConcurrentMap(int capacity = 8) { m_lock.reset(new std::shared_timed_mutex); m_data.reserve(capacity); } @@ -79,6 +105,19 @@ namespace SPTAG return m_data[k]; } + size_t unsafe_erase(const K& k) + { + std::unique_lock lock(*m_lock); + return m_data.erase(k); + } + + template + std::pair insert(P&& v) + { + std::unique_lock lock(*m_lock); + return m_data.insert(v); + } + private: std::unique_ptr m_lock; std::unordered_map m_data; @@ -102,8 +141,9 @@ namespace SPTAG bool try_pop(T& j) { std::lock_guard lock(m_lock); - if (m_queue.empty()) return false; - + if (m_queue.empty()) { + return false; + } j = m_queue.front(); m_queue.pop(); return true; @@ -113,6 +153,39 @@ namespace SPTAG std::mutex m_lock; std::queue m_queue; }; + + template + class ConcurrentPriorityQueue + { + public: + ConcurrentPriorityQueue() {} + ~ConcurrentPriorityQueue() {} + + size_type size() const { + std::lock_guard lock(m_lock); + return m_queue.size(); + } + + void push(const T& value) { + std::lock_guard lock(m_lock); + m_queue.push(value); + } + + bool try_pop(T& value) { + std::lock_guard lock(m_lock); + if (m_queue.empty()) { + return false; + } + value = m_queue.top(); + m_queue.pop(); + return true; + } + + private: + std::mutex m_lock; + std::priority_queue m_queue; + }; +#endif // TBB #else template using ConcurrentSet = Concurrency::concurrent_unordered_set; @@ -122,8 +195,11 @@ namespace SPTAG template using ConcurrentQueue = Concurrency::concurrent_queue; + + template + using ConcurrentPriorityQueue = 
Concurrency::concurrent_priority_queue; #endif } } } -#endif // _SPTAG_HELPER_CONCURRENTSET_H_ \ No newline at end of file +#endif // _SPTAG_HELPER_CONCURRENTSET_H_ diff --git a/AnnService/inc/Helper/DiskIO.h b/AnnService/inc/Helper/DiskIO.h index dd2ba887e..3267afb17 100644 --- a/AnnService/inc/Helper/DiskIO.h +++ b/AnnService/inc/Helper/DiskIO.h @@ -8,6 +8,17 @@ #include #include #include +#include "inc/Core/Common.h" + +#ifdef SPDK +extern "C" { +#include "spdk/env.h" +#include "spdk/event.h" +#include "spdk/log.h" +#include "spdk/thread.h" +#include "spdk/bdev.h" +} +#endif namespace SPTAG { @@ -18,27 +29,96 @@ namespace SPTAG DIS_BulkRead = 0, DIS_UserRead, DIS_HighPriorityUserRead, + DIS_BulkWrite, + DIS_UserWrite, + DIS_HighPriorityUserWrite, DIS_Count }; - struct AsyncReadRequest +#ifdef _MSC_VER + namespace DiskUtils { - std::uint64_t m_offset; - std::uint64_t m_readSize; - char* m_buffer; - std::function m_callback; - int m_status; - // Carry items like counter for callback to process. 
- void* m_payload; - bool m_success; + struct PrioritizedDiskFileReaderResource; - // Carry exension metadata needed by some DiskIO implementations - void* m_extension; + struct CallbackOverLapped : public OVERLAPPED + { + PrioritizedDiskFileReaderResource* const c_registeredResource; + + void* m_data; + + CallbackOverLapped(PrioritizedDiskFileReaderResource* p_registeredResource) + : c_registeredResource(p_registeredResource), + m_data(nullptr) + { + } + }; + + + struct PrioritizedDiskFileReaderResource + { + CallbackOverLapped m_col; + + PrioritizedDiskFileReaderResource() + : m_col(this) + { + } + }; + } +#endif - AsyncReadRequest() : m_offset(0), m_readSize(0), m_buffer(nullptr), m_status(0), m_payload(nullptr), m_success(false), m_extension(nullptr) {} + template + class PageBuffer + { + public: + PageBuffer() + : m_pageBufferSize(0) + { + } + + void ReservePageBuffer(std::size_t p_size) + { + if (m_pageBufferSize < p_size) + { + m_pageBufferSize = p_size; +#ifdef SPDK + m_pageBuffer.reset(static_cast(spdk_dma_zmalloc(sizeof(T) * m_pageBufferSize, PageSize, NULL)), [=](T* ptr) { spdk_free(ptr); }); +#else + m_pageBuffer.reset(static_cast(BLOCK_ALLOC(sizeof(T) * m_pageBufferSize, PageSize)), [=](T* ptr) { BLOCK_FREE(ptr, PageSize); }); +#endif + } + } + + T* GetBuffer() + { + return m_pageBuffer.get(); + } + + std::size_t GetPageSize() + { + return m_pageBufferSize; + } + + void SetAvailableSize(std::size_t p_size) + { + m_availableSize = p_size; + } + + std::size_t GetAvailableSize() + { + return m_availableSize; + } + + private: + std::shared_ptr m_pageBuffer; + + std::size_t m_pageBufferSize; + + std::size_t m_availableSize = 0; }; + struct AsyncReadRequest; + class DiskIO { public: @@ -51,7 +131,8 @@ namespace SPTAG std::uint64_t maxIOSize = (1 << 20), std::uint32_t maxReadRetries = 2, std::uint32_t maxWriteRetries = 2, - std::uint16_t threadPoolSize = 4) = 0; + std::uint16_t threadPoolSize = 4, + std::uint64_t maxFileSize = (300ULL << 30)) = 0; virtual 
std::uint64_t ReadBinary(std::uint64_t readSize, char* buffer, std::uint64_t offset = UINT64_MAX) = 0; @@ -62,14 +143,23 @@ namespace SPTAG virtual std::uint64_t WriteString(const char* buffer, std::uint64_t offset = UINT64_MAX) = 0; virtual bool ReadFileAsync(AsyncReadRequest& readRequest) { return false; } + + // interface method for waiting for async read to complete when underlying callback support is not available. + virtual void Wait(AsyncReadRequest& readRequest) { return; } - virtual bool BatchReadFile(AsyncReadRequest* readRequests, std::uint32_t requestCount) { return false; } + virtual std::uint32_t BatchReadFile(AsyncReadRequest* readRequests, std::uint32_t requestCount, const std::chrono::microseconds& timeout, int batchSize = -1) { return false; } + + virtual std::uint32_t BatchWriteFile(AsyncReadRequest* readRequests, std::uint32_t requestCount, const std::chrono::microseconds& timeout, int batchSize = -1) { return false; } virtual bool BatchCleanRequests(SPTAG::Helper::AsyncReadRequest* readRequests, std::uint32_t requestCount) { return false; } + virtual bool ExpandFile(uint64_t expandSize) { return false; } + virtual std::uint64_t TellP() = 0; virtual void ShutDown() = 0; + + virtual bool Available() = 0; }; class SimpleFileIO : public DiskIO @@ -79,12 +169,18 @@ namespace SPTAG virtual ~SimpleFileIO() { ShutDown(); } + virtual bool Available() + { + return (m_handle != nullptr) && m_handle->is_open(); + } + virtual bool Initialize(const char* filePath, int openMode, // Max read/write buffer size. 
std::uint64_t maxIOSize = (1 << 20), std::uint32_t maxReadRetries = 2, std::uint32_t maxWriteRetries = 2, - std::uint16_t threadPoolSize = 4) + std::uint16_t threadPoolSize = 4, + std::uint64_t maxFileSize = (300ULL << 30)) { m_handle.reset(new std::fstream(filePath, (std::ios::openmode)openMode)); return m_handle->is_open(); @@ -167,6 +263,7 @@ namespace SPTAG streambuf(char* buffer, size_t size) { setg(buffer, buffer, buffer + size); + setp(buffer, buffer + size); } std::uint64_t tellp() @@ -183,12 +280,18 @@ namespace SPTAG ShutDown(); } + virtual bool Available() + { + return m_handle != nullptr; + } + virtual bool Initialize(const char* filePath, int openMode, // Max read/write buffer size. std::uint64_t maxIOSize = (1 << 20), std::uint32_t maxReadRetries = 2, std::uint32_t maxWriteRetries = 2, - std::uint16_t threadPoolSize = 4) + std::uint16_t threadPoolSize = 4, + std::uint64_t maxFileSize = (300ULL << 30)) { if (filePath != nullptr) m_handle.reset(new streambuf((char*)filePath, maxIOSize)); diff --git a/AnnService/inc/Helper/KeyValueIO.h b/AnnService/inc/Helper/KeyValueIO.h new file mode 100644 index 000000000..4bda94a91 --- /dev/null +++ b/AnnService/inc/Helper/KeyValueIO.h @@ -0,0 +1,63 @@ +#ifndef _SPTAG_HELPER_KEYVALUEIO_H_ +#define _SPTAG_HELPER_KEYVALUEIO_H_ + +#include "inc/Core/Common.h" +#include "inc/Helper/DiskIO.h" +#include +#include + +namespace SPTAG +{ + namespace Helper + { + class KeyValueIO { + public: + KeyValueIO() {} + + virtual ~KeyValueIO() {} + + virtual void ShutDown() = 0; + + virtual ErrorCode Get(const std::string& key, std::string* value, const std::chrono::microseconds& timeout, std::vector* reqs) { return ErrorCode::Undefined; } + + virtual ErrorCode Get(const SizeType key, std::string* value, const std::chrono::microseconds& timeout, std::vector* reqs) = 0; + + virtual ErrorCode MultiGet(const std::vector& keys, std::vector>& values, const std::chrono::microseconds& timeout, std::vector* reqs) = 0; + + virtual ErrorCode 
MultiGet(const std::vector& keys, std::vector* values, const std::chrono::microseconds& timeout, std::vector* reqs) = 0; + + virtual ErrorCode MultiGet(const std::vector& keys, std::vector* values, const std::chrono::microseconds &timeout, std::vector* reqs) { return ErrorCode::Undefined; } + + virtual ErrorCode Put(const std::string& key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) { return ErrorCode::Undefined; } + + virtual ErrorCode Put(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) = 0; + + virtual ErrorCode Merge(const SizeType key, const std::string& value, const std::chrono::microseconds& timeout, std::vector* reqs) = 0; + + virtual ErrorCode Delete(SizeType key) = 0; + + virtual ErrorCode DeleteRange(SizeType start, SizeType end) {return ErrorCode::Undefined;} + + virtual void ForceCompaction() {} + + virtual void GetStat() {} + + virtual int64_t GetNumBlocks() { return 0; } + + virtual bool Available() { return false; } + + virtual ErrorCode Check(const SizeType key, int size, std::vector *visited) + { + return ErrorCode::Undefined; + } + + virtual ErrorCode Checkpoint(std::string prefix) {return ErrorCode::Undefined;} + + virtual ErrorCode StartToScan(SizeType& key, std::string* value) {return ErrorCode::Undefined;} + + virtual ErrorCode NextToScan(SizeType& key, std::string* value) {return ErrorCode::Undefined;} + }; + } +} + +#endif \ No newline at end of file diff --git a/AnnService/inc/Helper/LockFree.h b/AnnService/inc/Helper/LockFree.h index 5661b4701..f5a66513d 100644 --- a/AnnService/inc/Helper/LockFree.h +++ b/AnnService/inc/Helper/LockFree.h @@ -32,7 +32,7 @@ namespace SPTAG ~LockFreeVector() { - for (T* ptr : m_blocks) delete ptr; + for (T* ptr : m_blocks) delete[] ptr; m_blocks.clear(); } diff --git a/AnnService/inc/Helper/Logging.h b/AnnService/inc/Helper/Logging.h index 3456d8e0a..221f4dfa5 100644 --- a/AnnService/inc/Helper/Logging.h +++ 
b/AnnService/inc/Helper/Logging.h @@ -8,6 +8,8 @@ #include #include #include +#include +#include #pragma warning(disable:4996) @@ -33,6 +35,41 @@ namespace SPTAG virtual void Logging(const char* title, LogLevel level, const char* file, int line, const char* func, const char* format, ...) = 0; }; + class LoggerHolder + { +#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) || __cplusplus >= 202002L) + private: + std::atomic> m_logger; + public: + LoggerHolder(std::shared_ptr logger) : m_logger(logger) {} + + void SetLogger(std::shared_ptr p_logger) + { + m_logger = p_logger; + } + + std::shared_ptr GetLogger() + { + return m_logger; + } +#else + private: + std::shared_ptr m_logger; + public: + LoggerHolder(std::shared_ptr logger) : m_logger(logger) {} + + void SetLogger(std::shared_ptr p_logger) + { + std::atomic_store(&m_logger, p_logger); + } + + std::shared_ptr GetLogger() + { + return std::atomic_load(&m_logger); + } +#endif + }; + class SimpleLogger : public Logger { public: diff --git a/AnnService/inc/Helper/SimpleIniReader.h b/AnnService/inc/Helper/SimpleIniReader.h index a5fbceb51..a82290c04 100644 --- a/AnnService/inc/Helper/SimpleIniReader.h +++ b/AnnService/inc/Helper/SimpleIniReader.h @@ -5,6 +5,7 @@ #define _SPTAG_HELPER_INIREADER_H_ #include "inc/Core/Common.h" +#include "inc/Helper/DiskIO.h" #include "StringConvert.h" #include diff --git a/AnnService/inc/Helper/StringConvert.h b/AnnService/inc/Helper/StringConvert.h index 1ec2bb021..2e37ee8fa 100644 --- a/AnnService/inc/Helper/StringConvert.h +++ b/AnnService/inc/Helper/StringConvert.h @@ -95,7 +95,7 @@ inline bool ConvertStringToUnsignedInt(const char* p_str, DataType& p_value) return false; } - if (val < (std::numeric_limits::min)() || val >(std::numeric_limits::max)()) + if (val >(std::numeric_limits::max)()) { return false; } @@ -353,6 +353,68 @@ inline bool ConvertStringTo(const char* p_str, QuantizerType& p_v return false; } +template <> +inline bool ConvertStringTo(const char* p_str, 
NumaStrategy& p_value) +{ + if (nullptr == p_str) + { + return false; + } + +#define DefineNumaStrategy(Name) \ + else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \ + { \ + p_value = NumaStrategy::Name; \ + return true; \ + } \ + +#include "inc/Core/DefinitionList.h" +#undef DefineNumaStrategy + + return false; +} + +template <> +inline bool ConvertStringTo(const char* p_str, OrderStrategy& p_value) +{ + if (nullptr == p_str) + { + return false; + } + +#define DefineOrderStrategy(Name) \ + else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \ + { \ + p_value = OrderStrategy::Name; \ + return true; \ + } \ + +#include "inc/Core/DefinitionList.h" +#undef DefineOrderStrategy + + return false; +} + +template <> +inline bool ConvertStringTo(const char* p_str, Storage& p_value) +{ + if (nullptr == p_str) + { + return false; + } + +#define DefineStorage(Name) \ + else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \ + { \ + p_value = Storage::Name; \ + return true; \ + } \ + +#include "inc/Core/DefinitionList.h" +#undef DefineStorage + + return false; +} // Specialization of ConvertToString<>(). 
template<> @@ -484,6 +546,63 @@ inline std::string ConvertToString(const TruthFileType& p_value) return "Undefined"; } +template <> +inline std::string ConvertToString(const NumaStrategy& p_value) +{ + switch (p_value) + { +#define DefineNumaStrategy(Name) \ + case NumaStrategy::Name: \ + return #Name; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineNumaStrategy + + default: + break; + } + + return "Undefined"; +} + +template <> +inline std::string ConvertToString(const OrderStrategy& p_value) +{ + switch (p_value) + { +#define DefineOrderStrategy(Name) \ + case OrderStrategy::Name: \ + return #Name; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineOrderStrategy + + default: + break; + } + + return "Undefined"; +} + +template <> +inline std::string ConvertToString(const Storage& p_value) +{ + switch (p_value) + { +#define DefineStorage(Name) \ + case Storage::Name: \ + return #Name; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineStorage + + default: + break; + } + + return "Undefined"; +} + template <> inline std::string ConvertToString(const ErrorCode& p_value) { diff --git a/AnnService/inc/Helper/ThreadPool.h b/AnnService/inc/Helper/ThreadPool.h index aa9d0c171..96d781833 100644 --- a/AnnService/inc/Helper/ThreadPool.h +++ b/AnnService/inc/Helper/ThreadPool.h @@ -4,6 +4,7 @@ #ifndef _SPTAG_HELPER_THREADPOOL_H_ #define _SPTAG_HELPER_THREADPOOL_H_ +#include #include #include #include @@ -34,6 +35,8 @@ namespace SPTAG public: virtual ~Job() {} virtual void exec(IAbortOperation* p_abort) = 0; + + virtual void exec(void* p_workspace, IAbortOperation* p_abort) = 0; }; ThreadPool() {} @@ -55,16 +58,18 @@ namespace SPTAG Job *j; while (get(j)) { - try + try { j->exec(&m_abort); } - catch (std::exception& e) { - LOG(Helper::LogLevel::LL_Error, "ThreadPool: exception in %s %s\n", typeid(*j).name(), e.what()); + catch (std::exception &e) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ThreadPool: exception in %s %s\n", + typeid(*j).name(), e.what()); 
} - delete j; - } + currentJobs--; + } }); } } @@ -84,9 +89,11 @@ namespace SPTAG while (m_jobs.empty() && !m_abort.ShouldAbort()) m_cond.wait(lock); if (!m_abort.ShouldAbort()) { j = m_jobs.front(); + currentJobs++; m_jobs.pop(); + return true; } - return !m_abort.ShouldAbort(); + return false; } size_t jobsize() @@ -95,7 +102,12 @@ namespace SPTAG return m_jobs.size(); } + inline uint32_t runningJobs() { return currentJobs; } + + inline bool allClear() { return currentJobs == 0 && jobsize() == 0; } + protected: + std::atomic_uint32_t currentJobs{ 0 }; std::queue m_jobs; Abort m_abort; std::mutex m_lock; diff --git a/AnnService/inc/Quantizer/Training.h b/AnnService/inc/Quantizer/Training.h index 9c65c71e7..7fbb7e0d5 100644 --- a/AnnService/inc/Quantizer/Training.h +++ b/AnnService/inc/Quantizer/Training.h @@ -63,61 +63,89 @@ std::unique_ptr TrainPQQuantizer(std::shared_ptr options, { SizeType numCentroids = 256; if (raw_vectors->Dimension() % options->m_quantizedDim != 0) { - LOG(Helper::LogLevel::LL_Error, "Only n_codebooks that divide dimension are supported.\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Only n_codebooks that divide dimension are supported.\n"); + return nullptr; } DimensionType subdim = raw_vectors->Dimension() / options->m_quantizedDim; auto codebooks = std::make_unique(numCentroids * raw_vectors->Dimension()); - LOG(Helper::LogLevel::LL_Info, "Begin Training Quantizer Codebooks.\n"); -#pragma omp parallel for - for (int codebookIdx = 0; codebookIdx < options->m_quantizedDim; codebookIdx++) { - LOG(Helper::LogLevel::LL_Info, "Training Codebook %d.\n", codebookIdx); - auto kargs = COMMON::KmeansArgs(numCentroids, subdim, raw_vectors->Count(), options->m_threadNum, DistCalcMethod::L2, nullptr); - auto dset = COMMON::Dataset(raw_vectors->Count(), subdim, blockRows, raw_vectors->Count()); - - for (int vectorIdx = 0; vectorIdx < raw_vectors->Count(); vectorIdx++) { - auto raw_addr = reinterpret_cast(raw_vectors->GetVector(vectorIdx)) 
+ (codebookIdx * subdim); - auto dset_addr = dset[vectorIdx]; - for (int k = 0; k < subdim; k++) { - dset_addr[k] = raw_addr[k]; - } - } - - std::vector localindices; - localindices.resize(dset.R()); - for (SizeType il = 0; il < localindices.size(); il++) localindices[il] = il; - - auto nclusters = COMMON::KmeansClustering(dset, localindices, 0, dset.R(), kargs, options->m_trainingSamples, options->m_KmeansLambda, options->m_debug, nullptr); - - std::vector reverselocalindex; - reverselocalindex.resize(dset.R()); - for (SizeType il = 0; il < reverselocalindex.size(); il++) - { - reverselocalindex[localindices[il]] = il; - } - - for (int vectorIdx = 0; vectorIdx < raw_vectors->Count(); vectorIdx++) { - auto localidx = reverselocalindex[vectorIdx]; - auto quan_addr = reinterpret_cast(quantized_vectors->GetVector(vectorIdx)); - quan_addr[codebookIdx] = kargs.label[localidx]; - - } - - for (int j = 0; j < numCentroids; j++) { - std::cout << kargs.counts[j] << '\t'; - } - std::cout << std::endl; - - T* cb = codebooks.get() + (numCentroids * subdim * codebookIdx); - for (int i = 0; i < numCentroids; i++) - { - for (int j = 0; j < subdim; j++) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Training Quantizer Codebooks.\n"); + std::vector mythreads; + mythreads.reserve(options->m_threadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < options->m_threadNum; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t codebookIdx = 0; + while (true) { - cb[i * subdim + j] = kargs.centers[i * subdim + j]; + codebookIdx = sent.fetch_add(1); + if (codebookIdx < options->m_quantizedDim) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Training Codebook %d.\n", codebookIdx); + auto kargs = COMMON::KmeansArgs(numCentroids, subdim, raw_vectors->Count(), options->m_threadNum, + DistCalcMethod::L2, nullptr); + auto dset = COMMON::Dataset(raw_vectors->Count(), subdim, blockRows, raw_vectors->Count()); + + for (int vectorIdx = 0; vectorIdx < raw_vectors->Count(); vectorIdx++) 
+ { + auto raw_addr = + reinterpret_cast(raw_vectors->GetVector(vectorIdx)) + (codebookIdx * subdim); + auto dset_addr = dset[vectorIdx]; + for (int k = 0; k < subdim; k++) + { + dset_addr[k] = raw_addr[k]; + } + } + + std::vector localindices; + localindices.resize(dset.R()); + for (SizeType il = 0; il < localindices.size(); il++) + localindices[il] = il; + + // auto nclusters = COMMON::KmeansClustering(dset, localindices, 0, dset.R(), kargs, + // options->m_trainingSamples, options->m_KmeansLambda, options->m_debug, nullptr); + + std::vector reverselocalindex; + reverselocalindex.resize(dset.R()); + for (SizeType il = 0; il < reverselocalindex.size(); il++) + { + reverselocalindex[localindices[il]] = il; + } + + for (int vectorIdx = 0; vectorIdx < raw_vectors->Count(); vectorIdx++) + { + auto localidx = reverselocalindex[vectorIdx]; + auto quan_addr = reinterpret_cast(quantized_vectors->GetVector(vectorIdx)); + quan_addr[codebookIdx] = kargs.label[localidx]; + } + + for (int j = 0; j < numCentroids; j++) + { + std::cout << kargs.counts[j] << '\t'; + } + std::cout << std::endl; + + T *cb = codebooks.get() + (numCentroids * subdim * codebookIdx); + for (int i = 0; i < numCentroids; i++) + { + for (int j = 0; j < subdim; j++) + { + cb[i * subdim + j] = kargs.centers[i * subdim + j]; + } + } + } + else + { + return; + } } - } + }); } - + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); return codebooks; -} \ No newline at end of file +} diff --git a/AnnService/inc/SPFresh/SPFresh.h b/AnnService/inc/SPFresh/SPFresh.h new file mode 100644 index 000000000..7390472ad --- /dev/null +++ b/AnnService/inc/SPFresh/SPFresh.h @@ -0,0 +1,1246 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "inc/Core/Common.h" +#include "inc/Core/Common/TruthSet.h" +#include "inc/Core/SPANN/Index.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/VectorSetReader.h" +#include + +#include +#include +#include + +using namespace SPTAG; + +namespace SPTAG { + namespace SSDServing { + namespace SPFresh { + + typedef std::chrono::steady_clock SteadClock; + + double getMsInterval(std::chrono::steady_clock::time_point start, std::chrono::steady_clock::time_point end) { + return (std::chrono::duration_cast(end - start).count() * 1.0) / 1000.0; + } + + double getSecInterval(std::chrono::steady_clock::time_point start, std::chrono::steady_clock::time_point end) { + return (std::chrono::duration_cast(end - start).count() * 1.0) / 1000.0; + } + + double getMinInterval(std::chrono::steady_clock::time_point start, std::chrono::steady_clock::time_point end) { + return (std::chrono::duration_cast(end - start).count() * 1.0) / 60.0; + } + + /// Clock class + class StopWSPFresh { + private: + std::chrono::steady_clock::time_point time_begin; + public: + StopWSPFresh() { + time_begin = std::chrono::steady_clock::now(); + } + + double getElapsedMs() { + std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); + return getMsInterval(time_begin, time_end); + } + + double getElapsedSec() { + std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); + return getSecInterval(time_begin, time_end); + } + + double getElapsedMin() { + std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); + return getMinInterval(time_begin, time_end); + } + + void reset() { + time_begin = std::chrono::steady_clock::now(); + } + }; + + template + void OutputResult(const std::string& p_output, std::vector& p_results, int p_resultNum) + { + if (!p_output.empty()) + { + auto ptr = f_createIO(); + if (ptr == nullptr || 
!ptr->Initialize(p_output.c_str(), std::ios::binary | std::ios::out)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed create file: %s\n", p_output.c_str()); + exit(1); + } + int32_t i32Val = static_cast(p_results.size()); + if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + exit(1); + } + i32Val = p_resultNum; + if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + exit(1); + } + + float fVal = 0; + for (size_t i = 0; i < p_results.size(); ++i) + { + for (int j = 0; j < p_resultNum; ++j) + { + i32Val = p_results[i].GetResult(j)->VID; + if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + exit(1); + } + + fVal = p_results[i].GetResult(j)->Dist; + if (ptr->WriteBinary(sizeof(fVal), reinterpret_cast(&fVal)) != sizeof(fVal)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + exit(1); + } + } + } + } + } + + void ShowMemoryStatus(std::shared_ptr vectorSet, double second) + { + int tSize = 0, resident = 0, share = 0; + std::ifstream buffer("/proc/self/statm"); + buffer >> tSize >> resident >> share; + buffer.close(); +#ifndef _MSC_VER + long page_size_kb = sysconf(_SC_PAGE_SIZE) / 1024; // in case x86-64 is configured to use 2MB pages +#else + SYSTEM_INFO sysInfo; + GetSystemInfo(&sysInfo); + long page_size_kb = sysInfo.dwPageSize / 1024; +#endif + long rss = resident * page_size_kb; + long vector_size; + if (vectorSet != nullptr) + vector_size = vectorSet->PerVectorDataSize() * (vectorSet->Count() / 1024); + else + vector_size = 0; + long vector_size_mb = vector_size / 1024; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"Current time: %.0lf. 
RSS : %ld MB, Vector Set Size : %ld MB, True Size: %ld MB\n", second, rss / 1024, vector_size_mb, rss / 1024 - vector_size_mb); + } + + template + void PrintPercentiles(const std::vector& p_values, std::function p_get, const char* p_format, bool reverse=false) + { + double sum = 0; + std::vector collects; + collects.reserve(p_values.size()); + for (const auto& v : p_values) + { + T tmp = p_get(v); + sum += tmp; + collects.push_back(tmp); + } + if (reverse) { + std::sort(collects.begin(), collects.end(), std::greater()); + } + else { + std::sort(collects.begin(), collects.end()); + } + if (reverse) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Avg\t50tiles\t90tiles\t95tiles\t99tiles\t99.9tiles\tMin\n"); + } + else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Avg\t50tiles\t90tiles\t95tiles\t99tiles\t99.9tiles\tMax\n"); + } + + std::string formatStr("%.3lf"); + for (int i = 1; i < 7; ++i) + { + formatStr += '\t'; + formatStr += p_format; + } + + formatStr += '\n'; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + formatStr.c_str(), + sum / collects.size(), + collects[static_cast(collects.size() * 0.50)], + collects[static_cast(collects.size() * 0.90)], + collects[static_cast(collects.size() * 0.95)], + collects[static_cast(collects.size() * 0.99)], + collects[static_cast(collects.size() * 0.999)], + collects[static_cast(collects.size() - 1)]); + } + + template + static float CalculateRecallSPFresh(VectorIndex* index, std::vector& results, const std::vector>& truth, int K, int truthK, std::shared_ptr querySet, std::shared_ptr vectorSet, SizeType NumQuerys, std::ofstream* log = nullptr, bool debug = false) + { + float meanrecall = 0, minrecall = MaxDist, maxrecall = 0, stdrecall = 0; + std::vector thisrecall(NumQuerys, 0); + std::unique_ptr visited(new bool[K]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start Calculating Recall\n"); + for (SizeType i = 0; i < NumQuerys; i++) + { + memset(visited.get(), 0, K * sizeof(bool)); + for (SizeType id : truth[i]) + { + for (int j = 
0; j < K; j++) + { + if (visited[j] || results[i].GetResult(j)->VID < 0) continue; + if (results[i].GetResult(j)->VID == id) + { + thisrecall[i] += 1; + visited[j] = true; + break; + } else if (vectorSet != nullptr) { + float dist = results[i].GetResult(j)->Dist; + float truthDist = COMMON::DistanceUtils::ComputeDistance((const T*)querySet->GetVector(i), (const T*)vectorSet->GetVector(id), vectorSet->Dimension(), index->GetDistCalcMethod()); + if (index->GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine && fabs(dist - truthDist) < Epsilon) { + thisrecall[i] += 1; + visited[j] = true; + break; + } + else if (index->GetDistCalcMethod() == SPTAG::DistCalcMethod::L2 && fabs(dist - truthDist) < Epsilon * (dist + Epsilon)) { + thisrecall[i] += 1; + visited[j] = true; + break; + } + } + } + } + thisrecall[i] /= truthK; + meanrecall += thisrecall[i]; + if (thisrecall[i] < minrecall) minrecall = thisrecall[i]; + if (thisrecall[i] > maxrecall) maxrecall = thisrecall[i]; + + if (debug) { + std::string ll("recall:" + std::to_string(thisrecall[i]) + "\ngroundtruth:"); + std::vector truthvec; + for (SizeType id : truth[i]) { + float truthDist = 0.0; + if (vectorSet != nullptr) { + truthDist = COMMON::DistanceUtils::ComputeDistance((const T*)querySet->GetVector(i), (const T*)vectorSet->GetVector(id), querySet->Dimension(), index->GetDistCalcMethod()); + } + truthvec.emplace_back(id, truthDist); + } + std::sort(truthvec.begin(), truthvec.end()); + for (int j = 0; j < truthvec.size(); j++) + ll += std::to_string(truthvec[j].node) + "@" + std::to_string(truthvec[j].distance) + ","; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%s\n", ll.c_str()); + ll = "ann:"; + for (int j = 0; j < K; j++) + ll += std::to_string(results[i].GetResult(j)->VID) + "@" + std::to_string(results[i].GetResult(j)->Dist) + ","; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%s\n", ll.c_str()); + } + } + meanrecall /= NumQuerys; + for (SizeType i = 0; i < NumQuerys; i++) + { + stdrecall += (thisrecall[i] - 
meanrecall) * (thisrecall[i] - meanrecall); + } + stdrecall = std::sqrt(stdrecall / NumQuerys); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "stdrecall: %.6lf, maxrecall: %.2lf, minrecall: %.2lf\n", stdrecall, maxrecall, minrecall); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nRecall Distribution:\n"); + PrintPercentiles(thisrecall, + [](const float recall) -> float + { + return recall; + }, + "%.3lf", true); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall%d@%d: %f\n", K, truthK, meanrecall); + + if (log) (*log) << meanrecall << " " << stdrecall << " " << minrecall << " " << maxrecall << std::endl; + return meanrecall; + } + + template + double SearchSequential(SPANN::Index* p_index, + int p_numThreads, + std::vector& p_results, + std::vector& p_stats, + int p_maxQueryCount, int p_internalResultNum) + { + int numQueries = min(static_cast(p_results.size()), p_maxQueryCount); + + std::atomic_size_t queriesSent(0); + + std::vector threads; + + StopWSPFresh sw; + + auto func = [&]() + { + StopWSPFresh threadws; + size_t index = 0; + while (true) + { + index = queriesSent.fetch_add(1); + if (index < numQueries) + { + double startTime = threadws.getElapsedMs(); + p_index->GetMemoryIndex()->SearchIndex(p_results[index]); + double endTime = threadws.getElapsedMs(); + + p_stats[index].m_totalLatency = endTime - startTime; + + p_index->SearchDiskIndex(p_results[index], &(p_stats[index])); + double exEndTime = threadws.getElapsedMs(); + + p_stats[index].m_exLatency = exEndTime - endTime; + p_stats[index].m_totalLatency = p_stats[index].m_totalSearchLatency = exEndTime - startTime; + } + else + { + return; + } + } + }; + for (int i = 0; i < p_numThreads; i++) { threads.emplace_back(func); } + for (auto& thread : threads) { thread.join(); } + + auto sendingCost = sw.getElapsedSec(); + + return numQueries / sendingCost; + } + + template + void PrintStats(std::vector& stats) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nEx Elements Count:\n"); + PrintPercentiles(stats, + 
[](const SPANN::SearchStats& ss) -> double + { + return ss.m_totalListElementsCount; + }, + "%.3lf"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nHead Latency Distribution:\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> double + { + return ss.m_totalSearchLatency - ss.m_exLatency; + }, + "%.3lf"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nSetup Latency Distribution:\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> double + { + return ss.m_exSetUpLatency; + }, + "%.3lf"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nComp Latency Distribution:\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> double + { + return ss.m_compLatency; + }, + "%.3lf"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nSPDK Latency Distribution:\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> double + { + return ss.m_diskReadLatency; + }, + "%.3lf"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nEx Latency Distribution:\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> double + { + return ss.m_exLatency; + }, + "%.3lf"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nTotal Latency Distribution:\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> double + { + return ss.m_totalSearchLatency; + }, + "%.3lf"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nTotal Disk Page Access Distribution(KB):\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> int + { + return ss.m_diskAccessCount; + }, + "%4d"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nTotal Disk IO Distribution:\n"); + PrintPercentiles(stats, + [](const SPANN::SearchStats& ss) -> int + { + return ss.m_diskIOCount; + }, + "%4d"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\n"); + } + + void ResetStats(std::vector& totalStats) + { + for (int i = 0; i < totalStats.size(); i++) + { + totalStats[i].m_totalListElementsCount = 0; + totalStats[i].m_exLatency = 0; + 
totalStats[i].m_totalSearchLatency = 0; + totalStats[i].m_diskAccessCount = 0; + totalStats[i].m_diskIOCount = 0; + totalStats[i].m_compLatency = 0; + totalStats[i].m_diskReadLatency = 0; + totalStats[i].m_exSetUpLatency = 0; + } + } + + void AddStats(std::vector& totalStats, std::vector& addedStats) + { + for (int i = 0; i < totalStats.size(); i++) + { + totalStats[i].m_totalListElementsCount += addedStats[i].m_totalListElementsCount; + totalStats[i].m_exLatency += addedStats[i].m_exLatency; + totalStats[i].m_totalSearchLatency += addedStats[i].m_totalSearchLatency; + totalStats[i].m_diskAccessCount += addedStats[i].m_diskAccessCount; + totalStats[i].m_diskIOCount += addedStats[i].m_diskIOCount; + totalStats[i].m_compLatency += addedStats[i].m_compLatency; + totalStats[i].m_diskReadLatency += addedStats[i].m_diskReadLatency; + totalStats[i].m_exSetUpLatency += addedStats[i].m_exSetUpLatency; + } + } + + void AvgStats(std::vector& totalStats, int avgStatsNum) + { + for (int i = 0; i < totalStats.size(); i++) + { + totalStats[i].m_totalListElementsCount /= avgStatsNum; + totalStats[i].m_exLatency /= avgStatsNum; + totalStats[i].m_totalSearchLatency /= avgStatsNum; + totalStats[i].m_diskAccessCount /= avgStatsNum; + totalStats[i].m_diskIOCount /= avgStatsNum; + totalStats[i].m_compLatency /= avgStatsNum; + totalStats[i].m_diskReadLatency /= avgStatsNum; + totalStats[i].m_exSetUpLatency /= avgStatsNum; + } + } + + std::string convertFloatToString(const float value, const int precision = 0) + { + std::stringstream stream{}; + stream< LoadVectorSet(SPANN::Options& p_opts, int numThreads) + { + std::shared_ptr vectorSet; + if (p_opts.m_loadAllVectors) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading VectorSet...\n"); + if (!p_opts.m_fullVectorPath.empty() && fileexists(p_opts.m_fullVectorPath.c_str())) { + std::shared_ptr vectorOptions(new Helper::ReaderOptions(p_opts.m_valueType, p_opts.m_dim, p_opts.m_vectorType, p_opts.m_vectorDelimiter)); + auto 
vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); + if (ErrorCode::Success == vectorReader->LoadFile(p_opts.m_fullVectorPath)) + { + vectorSet = vectorReader->GetVectorSet(); + if (p_opts.m_distCalcMethod == DistCalcMethod::Cosine) vectorSet->Normalize(numThreads); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nLoad VectorSet(%d,%d).\n", vectorSet->Count(), vectorSet->Dimension()); + } + } + } else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Reduce memory usage\n"); + } + return vectorSet; + } + + std::shared_ptr LoadUpdateVectors(SPANN::Options& p_opts, std::vector& insertSet, SizeType updateSize) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load Update Vectors\n"); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(p_opts.m_fullVectorPath.c_str(), std::ios::binary | std::ios::in)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", p_opts.m_fullVectorPath.c_str()); + throw std::runtime_error("Failed read file"); + } + + SizeType row; + DimensionType col; + if (ptr->ReadBinary(sizeof(SizeType), (char*)&row) != sizeof(SizeType)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + throw std::runtime_error("Failed read file"); + } + if (ptr->ReadBinary(sizeof(DimensionType), (char*)&col) != sizeof(DimensionType)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + throw std::runtime_error("Failed read file"); + } + + std::uint64_t totalRecordVectorBytes = ((std::uint64_t)GetValueTypeSize(p_opts.m_valueType)) * updateSize * col; + ByteArray vectorSet; + if (totalRecordVectorBytes > 0) { + vectorSet = ByteArray::Alloc(totalRecordVectorBytes); + char* vecBuf = reinterpret_cast(vectorSet.Data()); + std::uint64_t readSize = ((std::uint64_t)GetValueTypeSize(p_opts.m_valueType)) * col; + for (int i = 0; i < updateSize; i++) { + std::uint64_t offset = ((std::uint64_t)GetValueTypeSize(p_opts.m_valueType)) * insertSet[i] * col + sizeof(SizeType) + 
sizeof(DimensionType); + if (ptr->ReadBinary(readSize, vecBuf + i*readSize, offset) != readSize) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + throw std::runtime_error("Failed read file"); + } + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load Vector(%d,%d)\n",updateSize, col); + return std::make_shared(vectorSet, + p_opts.m_valueType, + col, + updateSize); + + } + + std::shared_ptr LoadQuerySet(SPANN::Options& p_opts) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading QuerySet...\n"); + std::shared_ptr queryOptions(new Helper::ReaderOptions(p_opts.m_valueType, p_opts.m_dim, p_opts.m_queryType, p_opts.m_queryDelimiter)); + auto queryReader = Helper::VectorSetReader::CreateInstance(queryOptions); + if (ErrorCode::Success != queryReader->LoadFile(p_opts.m_queryPath)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + exit(1); + } + return queryReader->GetVectorSet(); + } + + void LoadTruth(SPANN::Options& p_opts, std::vector>& truth, int numQueries, std::string truthfilename, int truthK) + { + auto ptr = f_createIO(); + if (p_opts.m_update) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading TruthFile...: %s\n", truthfilename.c_str()); + + if (ptr == nullptr || !ptr->Initialize(truthfilename.c_str(), std::ios::in | std::ios::binary)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open truth file: %s\n", truthfilename.c_str()); + exit(1); + } + } else { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading TruthFile...: %s\n", p_opts.m_truthPath.c_str()); + + if (ptr == nullptr || !ptr->Initialize(p_opts.m_truthPath.c_str(), std::ios::in | std::ios::binary)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open truth file: %s\n", truthfilename.c_str()); + exit(1); + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "K: %d, TruthResultNum: %d\n", p_opts.m_resultNum, p_opts.m_truthResultNum); + COMMON::TruthSet::LoadTruth(ptr, truth, numQueries, p_opts.m_truthResultNum, 
p_opts.m_resultNum, p_opts.m_truthType); + char tmp[4]; + if (ptr->ReadBinary(4, tmp) == 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number is larger than query number(%d)!\n", numQueries); + } + } + + template + void StableSearch(SPANN::Index* p_index, + int numThreads, + std::shared_ptr querySet, + std::shared_ptr vectorSet, + int avgStatsNum, + int queryCountLimit, + int internalResultNum, + std::string& truthFileName, + SPANN::Options& p_opts, + double second = 0, + bool showStatus = true) + { + if (avgStatsNum == 0) return; + int numQueries = querySet->Count(); + + std::vector results(numQueries, QueryResult(NULL, internalResultNum, false)); + + if (showStatus) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Searching: numThread: %d, numQueries: %d, searchTimes: %d.\n", numThreads, numQueries, avgStatsNum); + std::vector stats(numQueries); + std::vector TotalStats(numQueries); + ResetStats(TotalStats); + double totalQPS = 0; + for (int i = 0; i < avgStatsNum; i++) + { + for (int j = 0; j < numQueries; ++j) + { + results[j].SetTarget(reinterpret_cast(querySet->GetVector(j))); + results[j].Reset(); + } + totalQPS += SearchSequential(p_index, numThreads, results, stats, queryCountLimit, internalResultNum); + //PrintStats(stats); + AddStats(TotalStats, stats); + } + if (showStatus) SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Current time: %.0lf, Searching Times: %d, AvgQPS: %.2lf.\n", second, avgStatsNum, totalQPS/avgStatsNum); + + AvgStats(TotalStats, avgStatsNum); + + if (showStatus) PrintStats(TotalStats); + + if (p_opts.m_calTruth) + { + if (p_opts.m_searchResult.empty()) { + std::vector> truth; + int K = p_opts.m_resultNum; + int truthK = p_opts.m_resultNum; + // float MRR, recall; + LoadTruth(p_opts, truth, numQueries, truthFileName, truthK); + CalculateRecallSPFresh((p_index->GetMemoryIndex()).get(), results, truth, K, truthK, querySet, vectorSet, numQueries); + // recall = COMMON::TruthSet::CalculateRecall((p_index->GetMemoryIndex()).get(), results, 
truth, K, truthK, querySet, vectorSet, numQueries, nullptr, false, &MRR); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall%d@%d: %f MRR@%d: %f\n", truthK, K, recall, K, MRR); + } else { + OutputResult(p_opts.m_searchResult + std::to_string(second), results, p_opts.m_resultNum); + } + } + } + + void LoadUpdateMapping(std::string fileName, std::vector& reverseIndices) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading %s\n", fileName.c_str()); + + int vectorNum; + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(fileName.c_str(), std::ios::in | std::ios::binary)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open trace file: %s\n", fileName.c_str()); + exit(1); + } + + if (ptr->ReadBinary(4, (char *)&vectorNum) != 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + } + + reverseIndices.clear(); + reverseIndices.resize(vectorNum); + + if (ptr->ReadBinary(vectorNum * 4, (char*)reverseIndices.data()) != vectorNum * 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "update mapping Error!\n"); + exit(1); + } + } + + void LoadUpdateTrace(std::string fileName, SizeType& updateSize, std::vector& insertSet, std::vector& deleteSet) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading %s\n", fileName.c_str()); + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(fileName.c_str(), std::ios::in | std::ios::binary)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open trace file: %s\n", fileName.c_str()); + exit(1); + } + + int tempSize; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading Size\n"); + + if (ptr->ReadBinary(4, (char *)&tempSize) != 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Update Size Error!\n"); + } + + updateSize = tempSize; + + deleteSet.clear(); + deleteSet.resize(updateSize); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading deleteSet\n"); + + if (ptr->ReadBinary(updateSize * 4, (char*)deleteSet.data()) != updateSize * 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, 
"Delete Set Error!\n"); + exit(1); + } + + insertSet.clear(); + insertSet.resize(updateSize); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading insertSet\n"); + + if (ptr->ReadBinary(updateSize * 4, (char*)insertSet.data()) != updateSize * 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Insert Set Error!\n"); + exit(1); + } + } + + void LoadUpdateTraceStressTest(std::string fileName, SizeType& updateSize, std::vector& insertSet) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading %s\n", fileName.c_str()); + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(fileName.c_str(), std::ios::in | std::ios::binary)) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open trace file: %s\n", fileName.c_str()); + exit(1); + } + + int tempSize; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading Size\n"); + + if (ptr->ReadBinary(4, (char *)&tempSize) != 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Update Size Error!\n"); + } + + updateSize = tempSize; + + insertSet.clear(); + insertSet.resize(updateSize); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading insertSet\n"); + + if (ptr->ReadBinary(updateSize * 4, (char*)insertSet.data()) != updateSize * 4) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Insert Set Error!\n"); + exit(1); + } + } + + template + void InsertVectorsBySet(SPANN::Index* p_index, + int insertThreads, + std::shared_ptr vectorSet, + std::vector& insertSet, + std::vector& mapping, + int updateSize, + SPANN::Options& p_opts) + { + StopWSPFresh sw; + std::vector threads; + std::vector latency_vector(updateSize); + + std::atomic_size_t vectorsSent(0); + + auto func = [&]() + { + size_t index = 0; + while (true) + { + index = vectorsSent.fetch_add(1); + if (index < updateSize) + { + if ((index & ((1 << 8) - 1)) == 0 && p_opts.m_showUpdateProgress) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Insert: Sent %.2lf%%...\n", index * 100.0 / updateSize); + } + // std::vector meta; + // std::vector metaoffset; + // std::string a = 
std::to_string(insertSet[index]); + // metaoffset.push_back((std::uint64_t)meta.size()); + // for (size_t j = 0; j < a.length(); j++) + // meta.push_back(a[j]); + // metaoffset.push_back((std::uint64_t)meta.size()); + // std::shared_ptr metaset(new SPTAG::MemMetadataSet( + // SPTAG::ByteArray((std::uint8_t*)meta.data(), meta.size() * sizeof(char), false), + // SPTAG::ByteArray((std::uint8_t*)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), + // 1)); + if (p_opts.m_stressTest) p_index->DeleteIndex(mapping[insertSet[index]]); + auto insertBegin = std::chrono::high_resolution_clock::now(); + if (p_opts.m_loadAllVectors) + p_index->AddIndexSPFresh(vectorSet->GetVector(insertSet[index]), 1, p_opts.m_dim, &mapping[insertSet[index]]); + else + p_index->AddIndexSPFresh(vectorSet->GetVector(index), 1, p_opts.m_dim, &mapping[insertSet[index]]); + auto insertEnd = std::chrono::high_resolution_clock::now(); + latency_vector[index] = std::chrono::duration_cast(insertEnd - insertBegin).count(); + } + else + { + return; + } + } + }; + for (int j = 0; j < insertThreads; j++) { threads.emplace_back(func); } + for (auto& thread : threads) { thread.join(); } + + double sendingCost = sw.getElapsedSec(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Insert: Finish sending in %.3lf seconds, sending throughput is %.2lf , insertion count %u.\n", + sendingCost, + updateSize / sendingCost, + static_cast(updateSize)); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"Insert: During Update\n"); + + while(!p_index->AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + double syncingCost = sw.getElapsedSec(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Insert: Finish syncing in %.3lf seconds, actuall throughput is %.2lf, insertion count %u.\n", + syncingCost, + updateSize / syncingCost, + static_cast(updateSize)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Insert Latency Distribution:\n"); + PrintPercentiles(latency_vector, + [](const double& ss) 
-> double + { + return ss; + }, + "%.3lf"); + } + + template + void DeleteVectorsBySet(SPANN::Index* p_index, + int deleteThreads, + std::shared_ptr vectorSet, + std::vector& deleteSet, + std::vector& mapping, + int updateSize, + SPANN::Options& p_opts, + int batch) + { + int avgQPS = p_opts.m_deleteQPS / deleteThreads; + if (p_opts.m_deleteQPS == -1) avgQPS = -1; + std::vector latency_vector(updateSize); + std::vector threads; + StopWSPFresh sw; + std::atomic_size_t vectorsSent(0); + auto func = [&]() + { + int deleteCount = 0; + while (true) + { + deleteCount++; + size_t index = 0; + index = vectorsSent.fetch_add(1); + if (index < updateSize) + { + if ((index & ((1 << 8) - 1)) == 0 && p_opts.m_showUpdateProgress) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Delete: Sent %.2lf%%...\n", index * 100.0 / updateSize); + } + auto deleteBegin = std::chrono::high_resolution_clock::now(); + p_index->DeleteIndex(mapping[deleteSet[index]]); + // p_index->DeleteIndex(vectorSet->GetVector(deleteSet[index]), deleteSet[index]); + // std::vector meta; + // std::string a = std::to_string(deleteSet[index]); + // for (size_t j = 0; j < a.length(); j++) + // meta.push_back(a[j]); + // ByteArray metarr = SPTAG::ByteArray((std::uint8_t*)meta.data(), meta.size() * sizeof(char), false); + + // if (p_index->VectorIndex::DeleteIndex(metarr) == ErrorCode::VectorNotFound) { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"VID meta no found: %d\n", deleteSet[index]); + // exit(1); + // } + + auto deleteEnd = std::chrono::high_resolution_clock::now(); + latency_vector[index] = std::chrono::duration_cast(deleteEnd - deleteBegin).count(); + if (avgQPS != -1 && deleteCount == avgQPS) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + deleteCount = 0; + } + } + else + { + return; + } + } + }; + for (int j = 0; j < deleteThreads; j++) { threads.emplace_back(func); } + for (auto& thread : threads) { thread.join(); } + double sendingCost = sw.getElapsedSec(); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Delete: Finish sending in %.3lf seconds, sending throughput is %.2lf , deletion count %u.\n", + sendingCost, + updateSize / sendingCost, + static_cast(updateSize)); + } + + template + void SteadyStateSPFresh(SPANN::Index* p_index) + { + SPANN::Options& p_opts = *(p_index->GetOptions()); + int days = p_opts.m_days; + if (days == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Need to input update days\n"); + exit(1); + } + StopWSPFresh sw; + + int numThreads = p_opts.m_searchThreadNum; + int internalResultNum = p_opts.m_searchInternalResultNum; + int searchTimes = p_opts.m_searchTimes; + + auto vectorSet = LoadVectorSet(p_opts, numThreads); + + auto querySet = LoadQuerySet(p_opts); + + int curCount = p_index->GetNumSamples(); + + bool calTruthOrigin = p_opts.m_calTruth; + + p_index->ForceCompaction(); + + p_index->GetDBStat(); + + if (!p_opts.m_onlySearchFinalBatch) { + if (p_opts.m_maxInternalResultNum != -1) + { + for (int iterInternalResultNum = p_opts.m_minInternalResultNum; iterInternalResultNum <= p_opts.m_maxInternalResultNum; iterInternalResultNum += p_opts.m_stepInternalResultNum) + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, iterInternalResultNum, p_opts.m_truthPath, p_opts, sw.getElapsedSec()); + } + } + else + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, internalResultNum, p_opts.m_truthPath, p_opts, sw.getElapsedSec()); + } + } + // exit(1); + + ShowMemoryStatus(vectorSet, sw.getElapsedSec()); + p_index->GetDBStat(); + + int insertThreads = p_opts.m_insertThreadNum; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Updating: numThread: %d, total days: %d.\n", insertThreads, days); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start updating...\n"); + + int updateSize; + std::vector insertSet; + std::vector deleteSet; + std::vector mapping; + if (p_opts.m_endVectorNum == -1) p_opts.m_endVectorNum = curCount; + 
mapping.resize(p_opts.m_endVectorNum); + for (int i = 0; i < p_opts.m_endVectorNum; i++) { + mapping[i] = i; + } + + for (int i = 0; i < days; i++) + { + + std::string traceFileName = p_opts.m_updateFilePrefix + std::to_string(i); + if (!p_opts.m_stressTest) LoadUpdateTrace(traceFileName, updateSize, insertSet, deleteSet); + else LoadUpdateTraceStressTest(traceFileName, updateSize, insertSet); + if (!p_opts.m_loadAllVectors) { + vectorSet = LoadUpdateVectors(p_opts, insertSet, updateSize); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Updating day: %d: numThread: %d, updateSize: %d,total days: %d.\n", i, insertThreads, updateSize, days); + + int sampleSize = updateSize / p_opts.m_sampling; + + int nextSamplePoint = sampleSize + updateSize * i; + + bool showStatus = false; + + std::future delete_future; + if (!p_opts.m_stressTest) { + delete_future = + std::async(std::launch::async, DeleteVectorsBySet, p_index, + 1, vectorSet, std::ref(deleteSet), std::ref(mapping), updateSize, std::ref(p_opts), i); + } + + std::future_status delete_status; + + std::future insert_future = + std::async(std::launch::async, InsertVectorsBySet, p_index, + insertThreads, vectorSet, std::ref(insertSet), std::ref(mapping), updateSize, std::ref(p_opts)); + + std::future_status insert_status; + + std::string tempFileName; + p_opts.m_calTruth = false; + do { + insert_status = insert_future.wait_for(std::chrono::seconds(2)); + if (!p_opts.m_stressTest) delete_status = delete_future.wait_for(std::chrono::seconds(2)); + else delete_status = std::future_status::ready; + if (insert_status == std::future_status::timeout || delete_status == std::future_status::timeout) { + if (p_index->GetNumDeleted() >= nextSamplePoint) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Samppling Size: %d\n", nextSamplePoint); + showStatus = true; + nextSamplePoint += sampleSize; + ShowMemoryStatus(vectorSet, sw.getElapsedSec()); + p_index->GetIndexStat(-1, false, false); + } else { + showStatus = false; + } + 
p_index->GetDBStat(); + if(p_opts.m_searchDuringUpdate) StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, internalResultNum, tempFileName, p_opts, sw.getElapsedSec(), showStatus); + p_index->GetDBStat(); + } + } while (insert_status != std::future_status::ready || delete_status != std::future_status::ready); + + curCount += updateSize; + + p_index->GetIndexStat(updateSize, true, true); + p_index->GetDBStat(); + + ShowMemoryStatus(vectorSet, sw.getElapsedSec()); + + std::string truthFileName; + + if (!p_opts.m_stressTest) truthFileName = p_opts.m_truthFilePrefix + std::to_string(i); + else truthFileName = p_opts.m_truthPath; + + p_index->Checkpoint(); + + p_opts.m_calTruth = calTruthOrigin; + if (p_opts.m_onlySearchFinalBatch && days - 1 != i) continue; + p_index->StopMerge(); + if (p_opts.m_maxInternalResultNum != -1) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Latency & Recall Tradeoff\n"); + for (int iterInternalResultNum = p_opts.m_minInternalResultNum; iterInternalResultNum <= p_opts.m_maxInternalResultNum; iterInternalResultNum += p_opts.m_stepInternalResultNum) + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, iterInternalResultNum, truthFileName, p_opts, sw.getElapsedSec()); + } + } + else + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, internalResultNum, truthFileName, p_opts, sw.getElapsedSec()); + } + p_index->OpenMerge(); + } + } + + template + void InsertVectors(SPANN::Index* p_index, + int insertThreads, + std::shared_ptr vectorSet, + int curCount, + int step, + SPANN::Options& p_opts) + { + StopWSPFresh sw; + std::vector threads; + + std::atomic_size_t vectorsSent(0); + std::vector latency_vector(step); + + auto func = [&]() + { + size_t index = 0; + while (true) + { + index = vectorsSent.fetch_add(1); + if (index < step) + { + if ((index & ((1 << 14) - 1)) == 0) + { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Sent %.2lf%%...\n", index * 100.0 / step); + } + auto insertBegin = std::chrono::high_resolution_clock::now(); + p_index->AddIndex(vectorSet->GetVector((SizeType)(index + curCount)), 1, p_opts.m_dim, nullptr); + auto insertEnd = std::chrono::high_resolution_clock::now(); + latency_vector[index] = std::chrono::duration_cast(insertEnd - insertBegin).count(); + } + else + { + return; + } + } + }; + for (int j = 0; j < insertThreads; j++) { threads.emplace_back(func); } + for (auto& thread : threads) { thread.join(); } + + double sendingCost = sw.getElapsedSec(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Finish sending in %.3lf seconds, sending throughput is %.2lf , insertion count %u.\n", + sendingCost, + step/ sendingCost, + static_cast(step)); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info,"During Update\n"); + + while(!p_index->AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + double syncingCost = sw.getElapsedSec(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Finish syncing in %.3lf seconds, actuall throughput is %.2lf, insertion count %u.\n", + syncingCost, + step / syncingCost, + static_cast(step)); + PrintPercentiles(latency_vector, + [](const double& ss) -> double + { + return ss; + }, + "%.3lf"); + } + + template + void UpdateSPFresh(SPANN::Index* p_index) + { + SPANN::Options& p_opts = *(p_index->GetOptions()); + int step = p_opts.m_step; + if (step == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Incremental Test Error, Need to set step.\n"); + exit(1); + } + StopWSPFresh sw; + + int numThreads = p_opts.m_searchThreadNum; + int internalResultNum = p_opts.m_searchInternalResultNum; + int searchTimes = p_opts.m_searchTimes; + + auto vectorSet = LoadVectorSet(p_opts, numThreads); + + auto querySet = LoadQuerySet(p_opts); + + int curCount = p_index->GetNumSamples(); + int insertCount = vectorSet->Count() - curCount; + + bool calTruthOrigin = p_opts.m_calTruth; + + if 
(p_opts.m_endVectorNum != -1) + { + insertCount = p_opts.m_endVectorNum - curCount; + } + + p_index->ForceCompaction(); + + p_index->GetDBStat(); + + if (!p_opts.m_onlySearchFinalBatch) { + if (p_opts.m_maxInternalResultNum != -1) + { + for (int iterInternalResultNum = p_opts.m_minInternalResultNum; iterInternalResultNum <= p_opts.m_maxInternalResultNum; iterInternalResultNum += p_opts.m_stepInternalResultNum) + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, iterInternalResultNum, p_opts.m_truthPath, p_opts, sw.getElapsedSec()); + } + } + else + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, internalResultNum, p_opts.m_truthPath, p_opts, sw.getElapsedSec()); + } + } + + ShowMemoryStatus(vectorSet, sw.getElapsedSec()); + p_index->GetDBStat(); + + int batch; + if (step == 0) { + batch = 0; + } else { + batch = insertCount / step; + } + + int finishedInsert = 0; + int insertThreads = p_opts.m_insertThreadNum; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Updating: numThread: %d, step: %d, insertCount: %d, totalBatch: %d.\n", insertThreads, step, insertCount, batch); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start updating...\n"); + for (int i = 0; i < 1; i++) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Updating Batch %d: numThread: %d, step: %d.\n", i, insertThreads, step); + + std::future insert_future = + std::async(std::launch::async, InsertVectors, p_index, + insertThreads, vectorSet, curCount, step, std::ref(p_opts)); + + std::future_status insert_status; + + std::string tempFileName; + p_opts.m_calTruth = false; + do { + insert_status = insert_future.wait_for(std::chrono::seconds(3)); + if (insert_status == std::future_status::timeout) { + ShowMemoryStatus(vectorSet, sw.getElapsedSec()); + p_index->GetIndexStat(-1, false, false); + p_index->GetDBStat(); + if(p_opts.m_searchDuringUpdate) StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, 
p_opts.m_queryCountLimit, internalResultNum, tempFileName, p_opts, sw.getElapsedSec()); + } + }while (insert_status != std::future_status::ready); + + curCount += step; + finishedInsert += step; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total Vector num %d \n", curCount); + + p_index->GetIndexStat(finishedInsert, true, true); + + ShowMemoryStatus(vectorSet, sw.getElapsedSec()); + + std::string truthFileName = p_opts.m_truthFilePrefix + std::to_string(i); + + p_opts.m_calTruth = calTruthOrigin; + if (p_opts.m_onlySearchFinalBatch && batch - 1 != i) continue; + // p_index->ForceGC(); + // p_index->ForceCompaction(); + p_index->StopMerge(); + if (p_opts.m_maxInternalResultNum != -1) + { + for (int iterInternalResultNum = p_opts.m_minInternalResultNum; iterInternalResultNum <= p_opts.m_maxInternalResultNum; iterInternalResultNum += p_opts.m_stepInternalResultNum) + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, iterInternalResultNum, truthFileName, p_opts, sw.getElapsedSec()); + } + } + else + { + StableSearch(p_index, numThreads, querySet, vectorSet, searchTimes, p_opts.m_queryCountLimit, internalResultNum, truthFileName, p_opts, sw.getElapsedSec()); + } + p_index->OpenMerge(); + } + } + + int UpdateTest(const char* storePath) { + + std::shared_ptr index; + + if (index->LoadIndex(storePath, index) != ErrorCode::Success) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to load index.\n"); + return 1; + } + + SPANN::Options* opts = nullptr; + + #define DefineVectorValueType(Name, Type) \ + if (index->GetVectorValueType() == VectorValueType::Name) { \ + opts = ((SPANN::Index*)index.get())->GetOptions(); \ + } \ + + #include "inc/Core/DefinitionList.h" + #undef DefineVectorValueType + + #define DefineVectorValueType(Name, Type) \ + if (opts->m_valueType == VectorValueType::Name) { \ + if (opts->m_steadyState) SteadyStateSPFresh((SPANN::Index*)(index.get())); \ + else UpdateSPFresh((SPANN::Index*)(index.get())); \ + } 
\ + + #include "inc/Core/DefinitionList.h" + #undef DefineVectorValueType + + return 0; + } + } + } +} diff --git a/AnnService/inc/SSDServing/SSDIndex.h b/AnnService/inc/SSDServing/SSDIndex.h index 4cd62a891..cf3597739 100644 --- a/AnnService/inc/SSDServing/SSDIndex.h +++ b/AnnService/inc/SSDServing/SSDIndex.h @@ -3,12 +3,10 @@ #pragma once #include - #include "inc/Core/Common.h" #include "inc/Core/Common/DistanceUtils.h" #include "inc/Core/Common/QueryResultSet.h" #include "inc/Core/SPANN/Index.h" -#include "inc/Core/SPANN/ExtraFullGraphSearcher.h" #include "inc/Helper/VectorSetReader.h" #include "inc/Helper/StringConvert.h" #include "inc/SSDServing/Utils.h" @@ -24,17 +22,17 @@ namespace SPTAG { { auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(p_output.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Failed create file: %s\n", p_output.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed create file: %s\n", p_output.c_str()); return ErrorCode::FailedCreateFile; } int32_t i32Val = static_cast(p_results.size()); if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); return ErrorCode::DiskIOFail; } i32Val = p_resultNum; if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); return ErrorCode::DiskIOFail; } @@ -45,13 +43,13 @@ namespace SPTAG { { i32Val = p_results[i].GetResult(j)->VID; if (ptr->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); return ErrorCode::DiskIOFail; } fVal = 
p_results[i].GetResult(j)->Dist; if (ptr->WriteBinary(sizeof(fVal), reinterpret_cast(&fVal)) != sizeof(fVal)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); return ErrorCode::DiskIOFail; } } @@ -75,7 +73,7 @@ namespace SPTAG { std::sort(collects.begin(), collects.end()); - LOG(Helper::LogLevel::LL_Info, "Avg\t50tiles\t90tiles\t95tiles\t99tiles\t99.9tiles\tMax\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Avg\t50tiles\t90tiles\t95tiles\t99tiles\t99.9tiles\tMax\n"); std::string formatStr("%.3lf"); for (int i = 1; i < 7; ++i) @@ -86,7 +84,7 @@ namespace SPTAG { formatStr += '\n'; - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, formatStr.c_str(), sum / collects.size(), collects[static_cast(collects.size() * 0.50)], @@ -110,56 +108,59 @@ namespace SPTAG { std::atomic_size_t queriesSent(0); std::vector threads; - - LOG(Helper::LogLevel::LL_Info, "Searching: numThread: %d, numQueries: %d.\n", p_numThreads, numQueries); + threads.reserve(p_numThreads); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Searching: numThread: %d, numQueries: %d.\n", p_numThreads, numQueries); Utils::StopW sw; - auto func = [&]() - { - Utils::StopW threadws; - size_t index = 0; - while (true) + for (int i = 0; i < p_numThreads; i++) { threads.emplace_back([&, i]() { - index = queriesSent.fetch_add(1); - if (index < numQueries) + NumaStrategy ns = (p_index->GetDiskIndex() != nullptr) ? NumaStrategy::SCATTER : NumaStrategy::LOCAL; // Only for SPANN, we need to avoid IO threads overlap with search threads. 
+ Helper::SetThreadAffinity(i, threads[i], ns, OrderStrategy::ASC); + + Utils::StopW threadws; + size_t index = 0; + while (true) { - if ((index & ((1 << 14) - 1)) == 0) + index = queriesSent.fetch_add(1); + if (index < numQueries) { - LOG(Helper::LogLevel::LL_Info, "Sent %.2lf%%...\n", index * 100.0 / numQueries); - } + if ((index & ((1 << 14) - 1)) == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Sent %.2lf%%...\n", index * 100.0 / numQueries); + } - double startTime = threadws.getElapsedMs(); - p_index->GetMemoryIndex()->SearchIndex(p_results[index]); - double endTime = threadws.getElapsedMs(); - p_index->DebugSearchDiskIndex(p_results[index], p_internalResultNum, p_internalResultNum, &(p_stats[index])); - double exEndTime = threadws.getElapsedMs(); - p_results[index].ClearTmp(); + double startTime = threadws.getElapsedMs(); + p_index->GetMemoryIndex()->SearchIndex(p_results[index]); + double endTime = threadws.getElapsedMs(); + p_index->SearchDiskIndex(p_results[index], &(p_stats[index])); + double exEndTime = threadws.getElapsedMs(); - p_stats[index].m_exLatency = exEndTime - endTime; - p_stats[index].m_totalLatency = p_stats[index].m_totalSearchLatency = exEndTime - startTime; - } - else - { - return; + p_stats[index].m_exLatency = exEndTime - endTime; + p_stats[index].m_totalLatency = p_stats[index].m_totalSearchLatency = exEndTime - startTime; + } + else + { + return; + } } - } - }; - - for (int i = 0; i < p_numThreads; i++) { threads.emplace_back(func); } + }); + } for (auto& thread : threads) { thread.join(); } double sendingCost = sw.getElapsedSec(); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finish sending in %.3lf seconds, actuallQPS is %.2lf, query count %u.\n", sendingCost, numQueries / sendingCost, static_cast(numQueries)); + + for (int i = 0; i < numQueries; i++) { p_results[i].CleanQuantizedTarget(); } } template - void Search(SPANN::Index* p_index) + ErrorCode Search(SPANN::Index* p_index) { SPANN::Options& 
p_opts = *(p_index->GetOptions()); std::string outputFile = p_opts.m_searchResult; @@ -171,25 +172,25 @@ namespace SPTAG { p_index->m_pQuantizer->SetEnableADC(p_opts.m_enableADC); } -// if (!p_opts.m_logFile.empty()) -// { -// g_pLogger.reset(new Helper::FileLogger(Helper::LogLevel::LL_Info, p_opts.m_logFile.c_str())); -// } -// int numThreads = p_opts.m_iSSDNumberOfThreads; - int numThreads = 1; + if (!p_opts.m_logFile.empty()) + { + SetLogger(std::make_shared(Helper::LogLevel::LL_Info, p_opts.m_logFile.c_str())); + } + int numThreads = p_opts.m_searchThreadNum; int internalResultNum = p_opts.m_searchInternalResultNum; int K = p_opts.m_resultNum; int truthK = (p_opts.m_truthResultNum <= 0) ? K : p_opts.m_truthResultNum; + ErrorCode ret; if (!warmupFile.empty()) { - LOG(Helper::LogLevel::LL_Info, "Start loading warmup query set...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading warmup query set...\n"); std::shared_ptr queryOptions(new Helper::ReaderOptions(p_opts.m_valueType, p_opts.m_dim, p_opts.m_warmupType, p_opts.m_warmupDelimiter)); auto queryReader = Helper::VectorSetReader::CreateInstance(queryOptions); - if (ErrorCode::Success != queryReader->LoadFile(p_opts.m_warmupPath)) + if (ErrorCode::Success != (ret = queryReader->LoadFile(p_opts.m_warmupPath))) { - LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + return ret; } auto warmupQuerySet = queryReader->GetVectorSet(); int warmupNumQueries = warmupQuerySet->Count(); @@ -202,18 +203,18 @@ namespace SPTAG { warmupResults[i].Reset(); } - LOG(Helper::LogLevel::LL_Info, "Start warmup...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start warmup...\n"); SearchSequential(p_index, numThreads, warmupResults, warmpUpStats, p_opts.m_queryCountLimit, internalResultNum); - LOG(Helper::LogLevel::LL_Info, "\nFinish warmup...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nFinish warmup...\n"); } - 
LOG(Helper::LogLevel::LL_Info, "Start loading QuerySet...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading QuerySet...\n"); std::shared_ptr queryOptions(new Helper::ReaderOptions(p_opts.m_valueType, p_opts.m_dim, p_opts.m_queryType, p_opts.m_queryDelimiter)); auto queryReader = Helper::VectorSetReader::CreateInstance(queryOptions); - if (ErrorCode::Success != queryReader->LoadFile(p_opts.m_queryPath)) + if (ErrorCode::Success != (ret = queryReader->LoadFile(p_opts.m_queryPath))) { - LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + return ret; } auto querySet = queryReader->GetVectorSet(); int numQueries = querySet->Count(); @@ -227,11 +228,11 @@ namespace SPTAG { } - LOG(Helper::LogLevel::LL_Info, "Start ANN Search...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start ANN Search...\n"); SearchSequential(p_index, numThreads, results, stats, p_opts.m_queryCountLimit, internalResultNum); - LOG(Helper::LogLevel::LL_Info, "\nFinish ANN Search...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nFinish ANN Search...\n"); std::shared_ptr vectorSet; @@ -242,12 +243,12 @@ namespace SPTAG { { vectorSet = vectorReader->GetVectorSet(); if (p_opts.m_distCalcMethod == DistCalcMethod::Cosine) vectorSet->Normalize(numThreads); - LOG(Helper::LogLevel::LL_Info, "\nLoad VectorSet(%d,%d).\n", vectorSet->Count(), vectorSet->Dimension()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nLoad VectorSet(%d,%d).\n", vectorSet->Count(), vectorSet->Dimension()); } } if (p_opts.m_rerank > 0 && vectorSet != nullptr) { - LOG(Helper::LogLevel::LL_Info, "\n Begin rerank...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\n Begin rerank...\n"); for (int i = 0; i < results.size(); i++) { for (int j = 0; j < K; j++) @@ -266,25 +267,25 @@ namespace SPTAG { std::vector> truth; if (!truthFile.empty()) { - LOG(Helper::LogLevel::LL_Info, "Start loading TruthFile...\n"); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading TruthFile...\n"); auto ptr = f_createIO(); if (ptr == nullptr || !ptr->Initialize(truthFile.c_str(), std::ios::in | std::ios::binary)) { - LOG(Helper::LogLevel::LL_Error, "Failed open truth file: %s\n", truthFile.c_str()); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open truth file: %s\n", truthFile.c_str()); + return ErrorCode::FailedOpenFile; } int originalK = truthK; COMMON::TruthSet::LoadTruth(ptr, truth, numQueries, originalK, truthK, p_opts.m_truthType); char tmp[4]; if (ptr->ReadBinary(4, tmp) == 4) { - LOG(Helper::LogLevel::LL_Error, "Truth number is larger than query number(%d)!\n", numQueries); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number is larger than query number(%d)!\n", numQueries); } recall = COMMON::TruthSet::CalculateRecall((p_index->GetMemoryIndex()).get(), results, truth, K, truthK, querySet, vectorSet, numQueries, nullptr, false, &MRR); - LOG(Helper::LogLevel::LL_Info, "Recall%d@%d: %f MRR@%d: %f\n", truthK, K, recall, K, MRR); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall%d@%d: %f MRR@%d: %f\n", truthK, K, recall, K, MRR); } - LOG(Helper::LogLevel::LL_Info, "\nEx Elements Count:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nEx Elements Count:\n"); PrintPercentiles(stats, [](const SPANN::SearchStats& ss) -> double { @@ -292,7 +293,7 @@ namespace SPTAG { }, "%.3lf"); - LOG(Helper::LogLevel::LL_Info, "\nHead Latency Distribution:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nHead Latency Distribution:\n"); PrintPercentiles(stats, [](const SPANN::SearchStats& ss) -> double { @@ -300,7 +301,7 @@ namespace SPTAG { }, "%.3lf"); - LOG(Helper::LogLevel::LL_Info, "\nEx Latency Distribution:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nEx Latency Distribution:\n"); PrintPercentiles(stats, [](const SPANN::SearchStats& ss) -> double { @@ -308,7 +309,7 @@ namespace SPTAG { }, "%.3lf"); - LOG(Helper::LogLevel::LL_Info, "\nTotal Latency Distribution:\n"); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nTotal Latency Distribution:\n"); PrintPercentiles(stats, [](const SPANN::SearchStats& ss) -> double { @@ -316,7 +317,7 @@ namespace SPTAG { }, "%.3lf"); - LOG(Helper::LogLevel::LL_Info, "\nTotal Disk Page Access Distribution:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nTotal Disk Page Access Distribution:\n"); PrintPercentiles(stats, [](const SPANN::SearchStats& ss) -> int { @@ -324,7 +325,7 @@ namespace SPTAG { }, "%4d"); - LOG(Helper::LogLevel::LL_Info, "\nTotal Disk IO Distribution:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nTotal Disk IO Distribution:\n"); PrintPercentiles(stats, [](const SPANN::SearchStats& ss) -> int { @@ -332,21 +333,21 @@ namespace SPTAG { }, "%4d"); - LOG(Helper::LogLevel::LL_Info, "\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\n"); if (!outputFile.empty()) { - LOG(Helper::LogLevel::LL_Info, "Start output to %s\n", outputFile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start output to %s\n", outputFile.c_str()); OutputResult(outputFile, results, K); } - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall@%d: %f MRR@%d: %f\n", K, recall, K, MRR); - LOG(Helper::LogLevel::LL_Info, "\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\n"); if (p_opts.m_recall_analysis) { - LOG(Helper::LogLevel::LL_Info, "Start recall analysis...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start recall analysis...\n"); std::shared_ptr headIndex = p_index->GetMemoryIndex(); SizeType sampleSize = numQueries < 100 ? 
numQueries : 100; @@ -364,118 +365,184 @@ namespace SPTAG { std::vector postingCut(sampleSize, 0); for (int i = 0; i < sampleSize; i++) samples[i] = COMMON::Utils::rand(numQueries); -#pragma omp parallel for schedule(dynamic) - for (int i = 0; i < sampleSize; i++) + std::vector mythreads; + mythreads.reserve(p_opts.m_iSSDNumberOfThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < p_opts.m_iSSDNumberOfThreads; tid++) { - COMMON::QueryResultSet queryANNHeads((const ValueType*)(querySet->GetVector(samples[i])), max(K, internalResultNum)); - headIndex->SearchIndex(queryANNHeads); - float queryANNHeadsLongestDist = queryANNHeads.GetResult(internalResultNum - 1)->Dist; - - COMMON::QueryResultSet queryBFHeads((const ValueType*)(querySet->GetVector(samples[i])), max(sampleK, internalResultNum)); - for (SizeType y = 0; y < headIndex->GetNumSamples(); y++) - { - float dist = headIndex->ComputeDistance(queryBFHeads.GetQuantizedTarget(), headIndex->GetSample(y)); - queryBFHeads.AddPoint(y, dist); - } - queryBFHeads.SortResult(); - - { - std::vector visited(internalResultNum, false); - for (SizeType y = 0; y < internalResultNum; y++) + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) { - for (SizeType z = 0; z < internalResultNum; z++) + i = sent.fetch_add(1); + if (i < sampleSize) { - if (visited[z]) continue; - - if (fabs(queryANNHeads.GetResult(z)->Dist - queryBFHeads.GetResult(y)->Dist) < sampleE) + COMMON::QueryResultSet queryANNHeads( + (const ValueType *)(querySet->GetVector(samples[i])), + max(K, internalResultNum)); + headIndex->SearchIndex(queryANNHeads); + float queryANNHeadsLongestDist = + queryANNHeads.GetResult(internalResultNum - 1)->Dist; + + COMMON::QueryResultSet queryBFHeads( + (const ValueType *)(querySet->GetVector(samples[i])), + max(sampleK, internalResultNum)); + for (SizeType y = 0; y < headIndex->GetNumSamples(); y++) { - queryHeadRecalls[i] += 1; - visited[z] = true; - break; + float dist = 
headIndex->ComputeDistance(queryBFHeads.GetQuantizedTarget(), + headIndex->GetSample(y)); + queryBFHeads.AddPoint(y, dist); } - } - } - } - - std::map> tmpFound; // headID->truths - p_index->DebugSearchDiskIndex(queryBFHeads, internalResultNum, sampleK, nullptr, &truth[samples[i]], &tmpFound); - - for (SizeType z = 0; z < K; z++) { - truthRecalls[i] += truth[samples[i]].count(queryBFHeads.GetResult(z)->VID); - } - - for (SizeType z = 0; z < K; z++) { - truth[samples[i]].erase(results[samples[i]].GetResult(z)->VID); - } + queryBFHeads.SortResult(); - for (std::map>::iterator it = tmpFound.begin(); it != tmpFound.end(); it++) { - float q2truthposting = headIndex->ComputeDistance(querySet->GetVector(samples[i]), headIndex->GetSample(it->first)); - for (auto vid : it->second) { - if (!truth[samples[i]].count(vid)) continue; - - if (q2truthposting < queryANNHeadsLongestDist) shouldSelect[i] += 1; - else { - shouldSelectLong[i] += 1; - - std::set nearQuerySelectedHeads; - float v2vhead = headIndex->ComputeDistance(vectorSet->GetVector(vid), headIndex->GetSample(it->first)); - for (SizeType z = 0; z < internalResultNum; z++) { - if (queryANNHeads.GetResult(z)->VID < 0) break; - float v2qhead = headIndex->ComputeDistance(vectorSet->GetVector(vid), headIndex->GetSample(queryANNHeads.GetResult(z)->VID)); - if (v2qhead < v2vhead) { - nearQuerySelectedHeads.insert(queryANNHeads.GetResult(z)->VID); + { + std::vector visited(internalResultNum, false); + for (SizeType y = 0; y < internalResultNum; y++) + { + for (SizeType z = 0; z < internalResultNum; z++) + { + if (visited[z]) + continue; + + if (fabs(queryANNHeads.GetResult(z)->Dist - + queryBFHeads.GetResult(y)->Dist) < sampleE) + { + queryHeadRecalls[i] += 1; + visited[z] = true; + break; + } + } } } - if (nearQuerySelectedHeads.size() == 0) continue; - - nearQueryHeads[i] += 1; - COMMON::QueryResultSet annTruthHead((const ValueType*)(vectorSet->GetVector(vid)), p_opts.m_debugBuildInternalResultNum); - 
headIndex->SearchIndex(annTruthHead); + std::map> tmpFound; // headID->truths + p_index->DebugSearchDiskIndex(queryBFHeads, internalResultNum, sampleK, nullptr, + &truth[samples[i]], &tmpFound); - bool found = false; - for (SizeType z = 0; z < annTruthHead.GetResultNum(); z++) { - if (nearQuerySelectedHeads.count(annTruthHead.GetResult(z)->VID)) { - found = true; - break; - } + for (SizeType z = 0; z < K; z++) + { + truthRecalls[i] += truth[samples[i]].count(queryBFHeads.GetResult(z)->VID); } - if (!found) { - annNotFound[i] += 1; - continue; + for (SizeType z = 0; z < K; z++) + { + truth[samples[i]].erase(results[samples[i]].GetResult(z)->VID); } - // RNG rule and posting cut - std::set replicas; - for (SizeType z = 0; z < annTruthHead.GetResultNum() && replicas.size() < p_opts.m_replicaCount; z++) { - BasicResult* item = annTruthHead.GetResult(z); - if (item->VID < 0) break; - - bool good = true; - for (auto r : replicas) { - if (p_opts.m_rngFactor * headIndex->ComputeDistance(headIndex->GetSample(r), headIndex->GetSample(item->VID)) < item->Dist) { - good = false; - break; + for (std::map>::iterator it = tmpFound.begin(); + it != tmpFound.end(); it++) + { + float q2truthposting = headIndex->ComputeDistance( + querySet->GetVector(samples[i]), headIndex->GetSample(it->first)); + for (auto vid : it->second) + { + if (!truth[samples[i]].count(vid)) + continue; + + if (q2truthposting < queryANNHeadsLongestDist) + shouldSelect[i] += 1; + else + { + shouldSelectLong[i] += 1; + + std::set nearQuerySelectedHeads; + float v2vhead = headIndex->ComputeDistance( + vectorSet->GetVector(vid), headIndex->GetSample(it->first)); + for (SizeType z = 0; z < internalResultNum; z++) + { + if (queryANNHeads.GetResult(z)->VID < 0) + break; + float v2qhead = headIndex->ComputeDistance( + vectorSet->GetVector(vid), + headIndex->GetSample(queryANNHeads.GetResult(z)->VID)); + if (v2qhead < v2vhead) + { + nearQuerySelectedHeads.insert(queryANNHeads.GetResult(z)->VID); + } + } + if 
(nearQuerySelectedHeads.size() == 0) + continue; + + nearQueryHeads[i] += 1; + + COMMON::QueryResultSet annTruthHead( + (const ValueType *)(vectorSet->GetVector(vid)), + p_opts.m_debugBuildInternalResultNum); + headIndex->SearchIndex(annTruthHead); + + bool found = false; + for (SizeType z = 0; z < annTruthHead.GetResultNum(); z++) + { + if (nearQuerySelectedHeads.count(annTruthHead.GetResult(z)->VID)) + { + found = true; + break; + } + } + + if (!found) + { + annNotFound[i] += 1; + continue; + } + + // RNG rule and posting cut + std::set replicas; + for (SizeType z = 0; z < annTruthHead.GetResultNum() && + replicas.size() < p_opts.m_replicaCount; + z++) + { + BasicResult *item = annTruthHead.GetResult(z); + if (item->VID < 0) + break; + + bool good = true; + for (auto r : replicas) + { + if (p_opts.m_rngFactor * headIndex->ComputeDistance( + headIndex->GetSample(r), + headIndex->GetSample(item->VID)) < + item->Dist) + { + good = false; + break; + } + } + if (good) + replicas.insert(item->VID); + } + + found = false; + for (auto r : nearQuerySelectedHeads) + { + if (replicas.count(r)) + { + found = true; + break; + } + } + + if (found) + postingCut[i] += 1; + else + rngRule[i] += 1; } } - if (good) replicas.insert(item->VID); - } - - found = false; - for (auto r : nearQuerySelectedHeads) { - if (replicas.count(r)) { - found = true; - break; - } } - - if (found) postingCut[i] += 1; - else rngRule[i] += 1; + } + else + { + return; } } - } + }); + } + for (auto &t : mythreads) + { + t.join(); } + mythreads.clear(); + float headacc = 0, truthacc = 0, shorter = 0, longer = 0, lost = 0, buildNearQueryHeads = 0, buildAnnNotFound = 0, buildRNGRule = 0, buildPostingCut = 0; for (int i = 0; i < sampleSize; i++) { headacc += queryHeadRecalls[i]; @@ -491,31 +558,32 @@ namespace SPTAG { buildPostingCut += postingCut[i]; } - LOG(Helper::LogLevel::LL_Info, "Query head recall @%d:%f.\n", internalResultNum, headacc / sampleSize / internalResultNum); - LOG(Helper::LogLevel::LL_Info, 
"BF top %d postings truth recall @%d:%f.\n", sampleK, truthK, truthacc / sampleSize / truthK); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Query head recall @%d:%f.\n", internalResultNum, headacc / sampleSize / internalResultNum); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "BF top %d postings truth recall @%d:%f.\n", sampleK, truthK, truthacc / sampleSize / truthK); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Percent of truths in postings have shorter distance than query selected heads: %f percent\n", shorter / lost * 100); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Percent of truths in postings have longer distance than query selected heads: %f percent\n", longer / lost * 100); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\tPercent of truths no shorter distance in query selected heads: %f percent\n", (longer - buildNearQueryHeads) / lost * 100); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\tPercent of truths exists shorter distance in query selected heads: %f percent\n", buildNearQueryHeads / lost * 100); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\t\tRNG rule ANN search loss: %f percent\n", buildAnnNotFound / lost * 100); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\t\tPosting cut loss: %f percent\n", buildPostingCut / lost * 100); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\t\tRNG rule loss: %f percent\n", buildRNGRule / lost * 100); } + return ErrorCode::Success; } } } diff --git a/AnnService/inc/SSDServing/SelectHead.h b/AnnService/inc/SSDServing/SelectHead.h index f44dc4280..48d14d627 100644 --- a/AnnService/inc/SSDServing/SelectHead.h +++ b/AnnService/inc/SSDServing/SelectHead.h @@ -139,7 +139,7 @@ namespace SPTAG { return; } - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "CheckNode: %8d, Height: %3d, MinDepth: 
%3d, MaxDepth: %3d, Children: %3d, Single: %3d\n", p_nodeID, p_height, @@ -150,7 +150,7 @@ namespace SPTAG { for (int nodeId = node.childStart; nodeId < node.childEnd; ++nodeId) { - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, " ChildNode: %8d, MinDepth: %3d, MaxDepth: %3d, ChildrenCount: %3d, LeafCount: %3d\n", nodeId, p_nodeInfos[nodeId].minDepth, @@ -279,7 +279,7 @@ namespace SPTAG { { if (c.depth > nextDepth) { - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Iteration of depth %d, selected %u heads, about %.2lf%%\n", nextDepth, static_cast(overallSelected.size() - lastSelected), @@ -316,7 +316,7 @@ namespace SPTAG { } } - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Iteration of Revisit, selected %u heads, about %.2lf%%\n", static_cast(overallSelected.size() - lastSelected), overallSelected.size() * 100.0 / p_vectorCount); @@ -369,7 +369,7 @@ namespace SPTAG { overallSelected.erase(p_vectorCount); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Iteration of depth %d, selected %u heads, about %.2lf%%\n", nextDepth, static_cast(overallSelected.size() - lastSelected), @@ -379,22 +379,22 @@ namespace SPTAG { p_selected.reserve(overallSelected.size()); p_selected.assign(overallSelected.begin(), overallSelected.end()); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finally selected %u heads, about %.2lf%%\n", static_cast(p_selected.size()), p_selected.size() * 100.0 / p_vectorCount); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total leaf count %d, big cluster count %d\n", overallLeafCount, bigClusterCount); if (p_opts.m_printSizeCount) { - LOG(Helper::LogLevel::LL_Info, "Leaf size count:\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Leaf size count:\n"); for (const auto& ele : sizeCount) { - LOG(Helper::LogLevel::LL_Info, "%3d %10d\n", ele.first, ele.second); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%3d %10d\n", ele.first, ele.second); } } @@ -407,10 +407,10 @@ namespace SPTAG { } } - LOG(Helper::LogLevel::LL_Info, "Leaf sum count > set leaf size: %d\n", sumGreaterThanLeafSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Leaf sum count > set leaf size: %d\n", sumGreaterThanLeafSize); int covered = DfsCovered(rootIdOfFirstTree, p_tree, overallSelected, false); - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total covered leaf count %d, about %.2lf\n", covered, covered * 100.0 / p_vectorCount); @@ -510,7 +510,7 @@ namespace SPTAG { double diff = static_cast(p_selected.size()) / p_vectorCount - p_opts.m_ratio; - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Select Threshold: %d, Split Threshold: %d, diff: %.2lf%%.\n", opts.m_selectThreshold, opts.m_splitThreshold, @@ -538,7 +538,7 @@ namespace SPTAG { opts.m_selectThreshold = selectThreshold; opts.m_splitThreshold = splitThreshold; - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Final Select Threshold: %d, Split Threshold: %d.\n", opts.m_selectThreshold, opts.m_splitThreshold); @@ -587,7 +587,7 @@ namespace SPTAG { while (1 <= trySelect && trySelect <= p_vectorCount) { - LOG(Helper::LogLevel::LL_Info, "Start try Select Threshold: %d\n", trySelect); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start try Select Threshold: %d\n", trySelect); opts.m_selectThreshold = trySelect; @@ -607,7 +607,7 @@ namespace SPTAG { double diff = static_cast(p_selected.size()) / p_vectorCount - p_opts.m_ratio; - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, " Split Threshold: %d, diff: %.2lf%%.\n", opts.m_splitThreshold, diff * 100.0); @@ -639,7 +639,7 @@ namespace SPTAG { double ssratio = static_cast(opts.m_splitThreshold) / trySelect; double diffOffset = minDiff / p_opts.m_ratio; - LOG(Helper::LogLevel::LL_Info, "Best abs(diff): %.3lf%%, target offset %.3lf, SSRatio: %.3lf\n\n", 
minDiff * 100.0, diffOffset, ssratio); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Best abs(diff): %.3lf%%, target offset %.3lf, SSRatio: %.3lf\n\n", minDiff * 100.0, diffOffset, ssratio); if (randomTry == 0) { @@ -706,7 +706,7 @@ namespace SPTAG { opts.m_selectThreshold = selectThreshold; opts.m_splitThreshold = splitThreshold; - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Final Select Threshold: %d, Split Threshold: %d.\n", opts.m_selectThreshold, opts.m_splitThreshold); @@ -726,22 +726,22 @@ namespace SPTAG { } else { - LOG(Helper::LogLevel::LL_Info, "Start selecting nodes...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start selecting nodes...\n"); if (!opts.m_selectDynamically) { - LOG(Helper::LogLevel::LL_Info, "Select Head Statically...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Select Head Statically...\n"); SelectHeadStatically(bkt, vectorSet->Count(), opts, selected); } else { - LOG(Helper::LogLevel::LL_Info, "Select Head Dynamically...\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Select Head Dynamically...\n"); SelectHeadDynamically_Old(bkt, vectorSet->Count(), opts, selected); } } if (selected.empty()) { - LOG(Helper::LogLevel::LL_Error, "Can't select any vector as head with current settings\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Can't select any vector as head with current settings\n"); return ErrorCode::Fail; } @@ -751,7 +751,7 @@ namespace SPTAG { { if (counter.count(item) <= 0) { - LOG(Helper::LogLevel::LL_Error, "selected node: %d can't be used to calculate std.", item); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "selected node: %d can't be used to calculate std.", item); return ErrorCode::Fail; } leafSizes.push_back(counter[item]); @@ -769,16 +769,16 @@ namespace SPTAG { int sum = 0; for (auto& kv : leafDict) { - LOG(Helper::LogLevel::LL_Info, "leaf size: %d, nodes: %d. \n", kv.first, kv.second); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "leaf size: %d, nodes: %d. 
\n", kv.first, kv.second); sum += kv.second; } if (sum != selected.size()) { - LOG(Helper::LogLevel::LL_Error, "please recheck std computation. \n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "please recheck std computation. \n"); return ErrorCode::Fail; } float std = Std(leafSizes.data(), leafSizes.size()); - LOG(Helper::LogLevel::LL_Info, "standard deviation is %.3f.\n", std); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "standard deviation is %.3f.\n", std); } return ErrorCode::Success; @@ -792,8 +792,8 @@ namespace SPTAG { bkt->m_iSamples = opts.m_iSamples; bkt->m_iTreeNumber = opts.m_iTreeNumber; bkt->m_fBalanceFactor = opts.m_fBalanceFactor; - LOG(Helper::LogLevel::LL_Info, "Start invoking BuildTrees.\n"); - LOG(Helper::LogLevel::LL_Info, "BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d.\n", + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start invoking BuildTrees.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d.\n", bkt->m_iBKTKmeansK, bkt->m_iBKTLeafSize, bkt->m_iSamples, bkt->m_fBalanceFactor, bkt->m_iTreeNumber, opts.m_iNumberOfThreads); VectorSearch::TimeUtils::StopW sw; int dataRowsInBlock = opts.m_datasetRowsInBlock; @@ -802,8 +802,8 @@ namespace SPTAG { bkt->BuildTrees(data, opts.m_distCalcMethod, opts.m_iNumberOfThreads, nullptr, nullptr, true); double elapsedMinutes = sw.getElapsedMin(); - LOG(Helper::LogLevel::LL_Info, "End invoking BuildTrees.\n"); - LOG(Helper::LogLevel::LL_Info, "Invoking BuildTrees used time: %.2lf minutes (about %.2lf hours).\n", elapsedMinutes, elapsedMinutes / 60.0); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "End invoking BuildTrees.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Invoking BuildTrees used time: %.2lf minutes (about %.2lf hours).\n", elapsedMinutes, elapsedMinutes / 60.0); std::stringstream bktFileNameBuilder; bktFileNameBuilder << opts.m_vectorPath << ".bkt." 
@@ -833,31 +833,31 @@ namespace SPTAG { headCnt = static_cast(std::round(p_opts.m_ratio * p_vectorCount)); } - LOG(Helper::LogLevel::LL_Info, "Setting requires to select none vectors as head, adjusted it to %d vectors\n", headCnt); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting requires to select none vectors as head, adjusted it to %d vectors\n", headCnt); } if (p_opts.m_iBKTKmeansK > headCnt) { p_opts.m_iBKTKmeansK = headCnt; - LOG(Helper::LogLevel::LL_Info, "Setting of cluster number is less than head count, adjust it to %d\n", headCnt); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting of cluster number is less than head count, adjust it to %d\n", headCnt); } if (p_opts.m_selectThreshold == 0) { p_opts.m_selectThreshold = min(p_vectorCount - 1, static_cast(1 / p_opts.m_ratio)); - LOG(Helper::LogLevel::LL_Info, "Set SelectThreshold to %d\n", p_opts.m_selectThreshold); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set SelectThreshold to %d\n", p_opts.m_selectThreshold); } if (p_opts.m_splitThreshold == 0) { p_opts.m_splitThreshold = min(p_vectorCount - 1, static_cast(p_opts.m_selectThreshold * 2)); - LOG(Helper::LogLevel::LL_Info, "Set SplitThreshold to %d\n", p_opts.m_splitThreshold); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set SplitThreshold to %d\n", p_opts.m_splitThreshold); } if (p_opts.m_splitFactor == 0) { p_opts.m_splitFactor = min(p_vectorCount - 1, static_cast(std::round(1 / p_opts.m_ratio) + 0.5)); - LOG(Helper::LogLevel::LL_Info, "Set SplitFactor to %d\n", p_opts.m_splitFactor); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set SplitFactor to %d\n", p_opts.m_splitFactor); } } @@ -865,25 +865,26 @@ namespace SPTAG { Utils::StopW sw; - LOG(Helper::LogLevel::LL_Info, "Start loading vector file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading vector file.\n"); auto valueType = COMMON::DistanceUtils::Quantizer ? 
SPTAG::VectorValueType::UInt8 : opts.m_valueType; std::shared_ptr options(new Helper::ReaderOptions(valueType, opts.m_dim, opts.m_vectorType, opts.m_vectorDelimiter)); auto vectorReader = Helper::VectorSetReader::CreateInstance(options); - if (ErrorCode::Success != vectorReader->LoadFile(opts.m_vectorPath)) + ErrorCode ret; + if (ErrorCode::Success != (ret = vectorReader->LoadFile(opts.m_vectorPath))) { - LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + return ret; } auto vectorSet = vectorReader->GetVectorSet(); if (opts.m_distCalcMethod == DistCalcMethod::Cosine) vectorSet->Normalize(opts.m_iSelectHeadNumberOfThreads); - LOG(Helper::LogLevel::LL_Info, "Finish loading vector file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finish loading vector file.\n"); AdjustOptions(opts, vectorSet->Count()); std::vector selected; if (Helper::StrUtils::StrEqualIgnoreCase(opts.m_selectType.c_str(), "Random")) { - LOG(Helper::LogLevel::LL_Info, "Start generating Random head.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start generating Random head.\n"); selected.resize(vectorSet->Count()); for (int i = 0; i < vectorSet->Count(); i++) selected[i] = i; std::shuffle(selected.begin(), selected.end(), rg); @@ -893,7 +894,7 @@ namespace SPTAG { } else if (Helper::StrUtils::StrEqualIgnoreCase(opts.m_selectType.c_str(), "Clustering")) { - LOG(Helper::LogLevel::LL_Info, "Start generating Clustering head.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start generating Clustering head.\n"); int headCnt = static_cast(std::round(opts.m_ratio * vectorSet->Count())); switch (valueType) @@ -911,7 +912,7 @@ namespace SPTAG { } else if (Helper::StrUtils::StrEqualIgnoreCase(opts.m_selectType.c_str(), "BKT")) { - LOG(Helper::LogLevel::LL_Info, "Start generating BKT.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start generating BKT.\n"); std::shared_ptr bkt; switch (valueType) { @@ 
-925,7 +926,7 @@ namespace SPTAG { default: break; } - LOG(Helper::LogLevel::LL_Info, "Finish generating BKT.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finish generating BKT.\n"); std::unordered_map counter; @@ -936,25 +937,25 @@ namespace SPTAG { if (opts.m_analyzeOnly) { - LOG(Helper::LogLevel::LL_Info, "Analyze Only.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Analyze Only.\n"); std::vector bktNodeInfos(bkt->size()); // Always use the first tree DfsAnalyze(0, bkt, vectorSet, opts, 0, bktNodeInfos); - LOG(Helper::LogLevel::LL_Info, "Analyze Finish.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Analyze Finish.\n"); } else { if (SelectHead(vectorSet, bkt, opts, counter, selected) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to select head.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to select head.\n"); return ErrorCode::Fail; } } } - LOG(Helper::LogLevel::LL_Info, + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Seleted Nodes: %u, about %.2lf%% of total.\n", static_cast(selected.size()), selected.size() * 100.0 / vectorSet->Count()); @@ -967,36 +968,36 @@ namespace SPTAG { if (output == nullptr || outputIDs == nullptr || !output->Initialize(opts.m_headVectorFile.c_str(), std::ios::binary | std::ios::out) || !outputIDs->Initialize(opts.m_headIDFile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Failed to create output file:%s %s\n", opts.m_headVectorFile.c_str(), opts.m_headIDFile.c_str()); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to create output file:%s %s\n", opts.m_headVectorFile.c_str(), opts.m_headIDFile.c_str()); + return ErrorCode::FailedCreateFile; } SizeType val = static_cast(selected.size()); if (output->WriteBinary(sizeof(val), reinterpret_cast(&val)) != sizeof(val)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return ErrorCode::DiskIOFail; } 
DimensionType dt = vectorSet->Dimension(); if (output->WriteBinary(sizeof(dt), reinterpret_cast(&dt)) != sizeof(dt)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return ErrorCode::DiskIOFail; } for (auto& ele : selected) { uint64_t vid = static_cast(ele); if (outputIDs->WriteBinary(sizeof(vid), reinterpret_cast(&vid)) != sizeof(vid)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return ErrorCode::DiskIOFail; } if (output->WriteBinary(vectorSet->PerVectorDataSize(), (char*)(vectorSet->GetVector((SizeType)vid))) != vectorSet->PerVectorDataSize()) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - exit(1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return ErrorCode::DiskIOFail; } } } double elapsedMinutes = sw.getElapsedMin(); - LOG(Helper::LogLevel::LL_Info, "Total used time: %.2lf minutes (about %.2lf hours).\n", elapsedMinutes, elapsedMinutes / 60.0); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total used time: %.2lf minutes (about %.2lf hours).\n", elapsedMinutes, elapsedMinutes / 60.0); return ErrorCode::Success; } diff --git a/AnnService/packages.config b/AnnService/packages.config index 4ad04937b..9dd28f715 100644 --- a/AnnService/packages.config +++ b/AnnService/packages.config @@ -1,12 +1,11 @@  - - - - - - - - - + + + + + + + + \ No newline at end of file diff --git a/AnnService/src/Aggregator/AggregatorContext.cpp b/AnnService/src/Aggregator/AggregatorContext.cpp index 6cf839286..dae7efb43 100644 --- a/AnnService/src/Aggregator/AggregatorContext.cpp +++ b/AnnService/src/Aggregator/AggregatorContext.cpp @@ -10,14 +10,11 @@ using namespace SPTAG; using namespace SPTAG::Aggregator; RemoteMachine::RemoteMachine() - : m_connectionID(Socket::c_invalidConnectionID), - 
m_status(RemoteMachineStatus::Disconnected) + : m_connectionID(Socket::c_invalidConnectionID), m_status(RemoteMachineStatus::Disconnected) { } - -AggregatorContext::AggregatorContext(const std::string& p_filePath) - : m_initialized(false) +AggregatorContext::AggregatorContext(const std::string &p_filePath) : m_initialized(false) { Helper::IniReader iniReader; if (ErrorCode::Success != iniReader.LoadIniFile(p_filePath)) @@ -30,7 +27,8 @@ AggregatorContext::AggregatorContext(const std::string& p_filePath) m_settings->m_listenAddr = iniReader.GetParameter("Service", "ListenAddr", std::string("0.0.0.0")); m_settings->m_listenPort = iniReader.GetParameter("Service", "ListenPort", std::string("8100")); m_settings->m_threadNum = iniReader.GetParameter("Service", "ThreadNumber", static_cast(8)); - m_settings->m_socketThreadNum = iniReader.GetParameter("Service", "SocketThreadNumber", static_cast(8)); + m_settings->m_socketThreadNum = + iniReader.GetParameter("Service", "SocketThreadNumber", static_cast(8)); m_settings->m_centers = iniReader.GetParameter("Service", "Centers", std::string("centers")); m_settings->m_valueType = iniReader.GetParameter("Service", "ValueType", VectorValueType::Float); m_settings->m_topK = iniReader.GetParameter("Service", "TopK", static_cast(-1)); @@ -61,21 +59,24 @@ AggregatorContext::AggregatorContext(const std::string& p_filePath) m_remoteServers.push_back(std::move(remoteMachine)); } - if (m_settings->m_topK > 0) { + if (m_settings->m_topK > 0) + { std::ifstream inputStream(m_settings->m_centers, std::ifstream::binary); - if (!inputStream.is_open()) { - LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_settings->m_centers.c_str()); - exit(1); + if (!inputStream.is_open()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_settings->m_centers.c_str()); + return; } SizeType row; DimensionType col; - inputStream.read((char*)&row, sizeof(SizeType)); - inputStream.read((char*)&col, sizeof(DimensionType)); - 
if (row > serverNum) row = serverNum; + inputStream.read((char *)&row, sizeof(SizeType)); + inputStream.read((char *)&col, sizeof(DimensionType)); + if (row > serverNum) + row = serverNum; std::uint64_t totalRecordVectorBytes = ((std::uint64_t)GetValueTypeSize(m_settings->m_valueType)) * row * col; ByteArray vectorSet = ByteArray::Alloc(totalRecordVectorBytes); - char* vecBuf = reinterpret_cast(vectorSet.Data()); + char *vecBuf = reinterpret_cast(vectorSet.Data()); inputStream.read(vecBuf, totalRecordVectorBytes); inputStream.close(); @@ -84,34 +85,26 @@ AggregatorContext::AggregatorContext(const std::string& p_filePath) m_initialized = true; } - AggregatorContext::~AggregatorContext() { } - -bool -AggregatorContext::IsInitialized() const +bool AggregatorContext::IsInitialized() const { return m_initialized; } - -const std::vector>& -AggregatorContext::GetRemoteServers() const +const std::vector> &AggregatorContext::GetRemoteServers() const { return m_remoteServers; } - -const std::shared_ptr& -AggregatorContext::GetSettings() const +const std::shared_ptr &AggregatorContext::GetSettings() const { return m_settings; } -const std::shared_ptr& -AggregatorContext::GetCenters() const +const std::shared_ptr &AggregatorContext::GetCenters() const { return m_centers; } \ No newline at end of file diff --git a/AnnService/src/Aggregator/AggregatorExecutionContext.cpp b/AnnService/src/Aggregator/AggregatorExecutionContext.cpp index 8f7a28375..2c503b1d8 100644 --- a/AnnService/src/Aggregator/AggregatorExecutionContext.cpp +++ b/AnnService/src/Aggregator/AggregatorExecutionContext.cpp @@ -16,35 +16,26 @@ AggregatorExecutionContext::AggregatorExecutionContext(std::size_t p_totalServer m_unfinishedCount = static_cast(p_totalServerNumber); } - AggregatorExecutionContext::~AggregatorExecutionContext() { } - -std::size_t -AggregatorExecutionContext::GetServerNumber() const +std::size_t AggregatorExecutionContext::GetServerNumber() const { return m_results.size(); } - 
-AggregatorResult& -AggregatorExecutionContext::GetResult(std::size_t p_num) +AggregatorResult &AggregatorExecutionContext::GetResult(std::size_t p_num) { return m_results[p_num]; } - -const Socket::PacketHeader& -AggregatorExecutionContext::GetRequestHeader() const +const Socket::PacketHeader &AggregatorExecutionContext::GetRequestHeader() const { return m_requestHeader; } - -bool -AggregatorExecutionContext::IsCompletedAfterFinsh(std::uint32_t p_finishedCount) +bool AggregatorExecutionContext::IsCompletedAfterFinsh(std::uint32_t p_finishedCount) { auto lastCount = m_unfinishedCount.fetch_sub(p_finishedCount); return lastCount <= p_finishedCount; diff --git a/AnnService/src/Aggregator/AggregatorService.cpp b/AnnService/src/Aggregator/AggregatorService.cpp index 96a3ce726..dbe291924 100644 --- a/AnnService/src/Aggregator/AggregatorService.cpp +++ b/AnnService/src/Aggregator/AggregatorService.cpp @@ -2,27 +2,22 @@ // Licensed under the MIT License. #include "inc/Aggregator/AggregatorService.h" -#include "inc/Server/QueryParser.h" #include "inc/Core/Common/DistanceUtils.h" #include "inc/Helper/Base64Encode.h" +#include "inc/Server/QueryParser.h" using namespace SPTAG; using namespace SPTAG::Aggregator; -AggregatorService::AggregatorService() - : m_shutdownSignals(m_ioContext), - m_pendingConnectServersTimer(m_ioContext) +AggregatorService::AggregatorService() : m_shutdownSignals(m_ioContext), m_pendingConnectServersTimer(m_ioContext) { } - AggregatorService::~AggregatorService() { } - -bool -AggregatorService::Initialize() +bool AggregatorService::Initialize() { std::string configFilePath = "Aggregator.ini"; m_aggregatorContext.reset(new AggregatorContext(configFilePath)); @@ -32,9 +27,7 @@ AggregatorService::Initialize() return m_initalized; } - -void -AggregatorService::Run() +void AggregatorService::Run() { auto threadNum = max((SPTAG::SizeType)1, GetContext()->GetSettings()->m_threadNum); m_threadPool.reset(new boost::asio::thread_pool(threadNum)); @@ -44,39 
+37,29 @@ AggregatorService::Run() WaitForShutdown(); } - -void -AggregatorService::StartClient() +void AggregatorService::StartClient() { auto context = GetContext(); Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); - handlerMap->emplace(Socket::PacketType::SearchResponse, - [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) - { - boost::asio::post(*m_threadPool, - std::bind(&AggregatorService::SearchResponseHanlder, - this, - p_srcID, - std::move(p_packet))); - }); - - - m_socketClient.reset(new Socket::Client(handlerMap, - context->GetSettings()->m_socketThreadNum, - 30)); - - m_socketClient->SetEventOnConnectionClose([this](Socket::ConnectionID p_cid) - { - auto context = this->GetContext(); - for (const auto& server : context->GetRemoteServers()) - { - if (nullptr != server && p_cid == server->m_connectionID) - { - server->m_status = RemoteMachineStatus::Disconnected; - this->AddToPendingServers(server); - } - } - }); + handlerMap->emplace( + Socket::PacketType::SearchResponse, [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, + std::bind(&AggregatorService::SearchResponseHanlder, this, p_srcID, std::move(p_packet))); + }); + + m_socketClient.reset(new Socket::Client(handlerMap, context->GetSettings()->m_socketThreadNum, 30)); + + m_socketClient->SetEventOnConnectionClose([this](Socket::ConnectionID p_cid) { + auto context = this->GetContext(); + for (const auto &server : context->GetRemoteServers()) + { + if (nullptr != server && p_cid == server->m_connectionID) + { + server->m_status = RemoteMachineStatus::Disconnected; + this->AddToPendingServers(server); + } + } + }); { std::lock_guard guard(m_pendingConnectServersMutex); @@ -86,36 +69,24 @@ AggregatorService::StartClient() ConnectToPendingServers(); } - -void -AggregatorService::StartListen() +void AggregatorService::StartListen() { auto context = GetContext(); Socket::PacketHandlerMapPtr handlerMap(new 
Socket::PacketHandlerMap); - handlerMap->emplace(Socket::PacketType::SearchRequest, - [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) - { - boost::asio::post(*m_threadPool, - std::bind(&AggregatorService::SearchRequestHanlder, - this, - p_srcID, - std::move(p_packet))); - }); - - m_socketServer.reset(new Socket::Server(context->GetSettings()->m_listenAddr, - context->GetSettings()->m_listenPort, - handlerMap, - context->GetSettings()->m_socketThreadNum)); - - LOG(Helper::LogLevel::LL_Info, - "Start to listen %s:%s ...\n", - context->GetSettings()->m_listenAddr.c_str(), - context->GetSettings()->m_listenPort.c_str()); -} + handlerMap->emplace( + Socket::PacketType::SearchRequest, [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, + std::bind(&AggregatorService::SearchRequestHanlder, this, p_srcID, std::move(p_packet))); + }); + m_socketServer.reset(new Socket::Server(context->GetSettings()->m_listenAddr, context->GetSettings()->m_listenPort, + handlerMap, context->GetSettings()->m_socketThreadNum)); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to listen %s:%s ...\n", context->GetSettings()->m_listenAddr.c_str(), + context->GetSettings()->m_listenPort.c_str()); +} -void -AggregatorService::WaitForShutdown() +void AggregatorService::WaitForShutdown() { m_shutdownSignals.add(SIGINT); m_shutdownSignals.add(SIGTERM); @@ -123,23 +94,20 @@ AggregatorService::WaitForShutdown() m_shutdownSignals.add(SIGQUIT); #endif - m_shutdownSignals.async_wait([this](boost::system::error_code p_ec, int p_signal) - { - LOG(Helper::LogLevel::LL_Info, "Received shutdown signals.\n"); + m_shutdownSignals.async_wait([this](boost::system::error_code p_ec, int p_signal) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Received shutdown signals.\n"); m_pendingConnectServersTimer.cancel(); }); m_ioContext.run(); - LOG(Helper::LogLevel::LL_Info, "Start shutdown procedure.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start shutdown 
procedure.\n"); m_socketServer.reset(); m_threadPool->stop(); m_threadPool->join(); } - -void -AggregatorService::ConnectToPendingServers() +void AggregatorService::ConnectToPendingServers() { auto context = GetContext(); std::vector> pendingList; @@ -150,7 +118,7 @@ AggregatorService::ConnectToPendingServers() pendingList.swap(m_pendingConnectServers); } - for (auto& pendingServer : pendingList) + for (auto &pendingServer : pendingList) { if (pendingServer->m_status != RemoteMachineStatus::Disconnected) { @@ -159,113 +127,115 @@ AggregatorService::ConnectToPendingServers() pendingServer->m_status = RemoteMachineStatus::Connecting; std::shared_ptr server = pendingServer; - auto runner = [server, this]() - { - ErrorCode errCode; - auto cid = m_socketClient->ConnectToServer(server->m_address, server->m_port, errCode); - if (Socket::c_invalidConnectionID == cid) - { - if (ErrorCode::Socket_FailedResolveEndPoint == errCode) - { - LOG(Helper::LogLevel::LL_Error, - "[Error] Failed to resolve %s %s.\n", - server->m_address.c_str(), - server->m_port.c_str()); - } - else - { - this->AddToPendingServers(std::move(server)); - } - } - else - { - server->m_connectionID = cid; - server->m_status = RemoteMachineStatus::Connected; - } - }; + auto runner = [server, this]() { + ErrorCode errCode; + auto cid = m_socketClient->ConnectToServer(server->m_address, server->m_port, errCode); + if (Socket::c_invalidConnectionID == cid) + { + if (ErrorCode::Socket_FailedResolveEndPoint == errCode) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "[Error] Failed to resolve %s %s.\n", + server->m_address.c_str(), server->m_port.c_str()); + } + else + { + this->AddToPendingServers(std::move(server)); + } + } + else + { + server->m_connectionID = cid; + server->m_status = RemoteMachineStatus::Connected; + } + }; boost::asio::post(*m_threadPool, std::move(runner)); } m_pendingConnectServersTimer.expires_from_now(boost::posix_time::seconds(30)); - m_pendingConnectServersTimer.async_wait([this](const 
boost::system::error_code& p_ec) - { - if (boost::asio::error::operation_aborted != p_ec) - { - ConnectToPendingServers(); - } - }); + m_pendingConnectServersTimer.async_wait([this](const boost::system::error_code &p_ec) { + if (boost::asio::error::operation_aborted != p_ec) + { + ConnectToPendingServers(); + } + }); } - -void -AggregatorService::AddToPendingServers(std::shared_ptr p_remoteServer) +void AggregatorService::AddToPendingServers(std::shared_ptr p_remoteServer) { std::lock_guard guard(m_pendingConnectServersMutex); m_pendingConnectServers.emplace_back(std::move(p_remoteServer)); } - -void -AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +void AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) { auto context = GetContext(); std::vector remoteServers; remoteServers.reserve(context->GetRemoteServers().size()); - if (context->GetSettings()->m_topK > 0 && context->GetRemoteServers().size() == context->GetCenters()->Count()) { - Socket::RemoteQuery remoteQuery; - remoteQuery.Read(p_packet.Body()); - - Service::QueryParser queryParser; - queryParser.Parse(remoteQuery.m_queryString, "|"); - ByteArray vector; - size_t vectorSize; - SizeType vectorDimension = 0; - std::vector servers; - switch (context->GetSettings()->m_valueType) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - if (!queryParser.GetVectorElements().empty()) { \ - Service::ConvertVectorFromString(queryParser.GetVectorElements(), vector, vectorDimension); \ - } else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) { \ - vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length())); \ - Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), vectorSize); \ - vectorDimension = (SizeType)(vectorSize / 
GetValueTypeSize(context->GetSettings()->m_valueType)); \ - } \ - for (int i = 0; i < context->GetCenters()->Count(); i++) { \ - servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance((Type*)vector.Data(), \ - (Type*)context->GetCenters()->GetVector(i), vectorDimension, context->GetSettings()->m_distMethod))); \ - } \ - break; \ + if (context->GetSettings()->m_topK > 0 && context->GetRemoteServers().size() == context->GetCenters()->Count()) + { + Socket::RemoteQuery remoteQuery; + remoteQuery.Read(p_packet.Body()); + + Service::QueryParser queryParser; + queryParser.Parse(remoteQuery.m_queryString, "|"); + ByteArray vector; + size_t vectorSize; + SizeType vectorDimension = 0; + std::vector servers; + switch (context->GetSettings()->m_valueType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + if (!queryParser.GetVectorElements().empty()) \ + { \ + Service::ConvertVectorFromString(queryParser.GetVectorElements(), vector, vectorDimension); \ + } \ + else if (queryParser.GetVectorBase64() != nullptr && queryParser.GetVectorBase64Length() != 0) \ + { \ + vector = ByteArray::Alloc(Helper::Base64::CapacityForDecode(queryParser.GetVectorBase64Length())); \ + Helper::Base64::Decode(queryParser.GetVectorBase64(), queryParser.GetVectorBase64Length(), vector.Data(), \ + vectorSize); \ + vectorDimension = (SizeType)(vectorSize / GetValueTypeSize(context->GetSettings()->m_valueType)); \ + } \ + for (int i = 0; i < context->GetCenters()->Count(); i++) \ + { \ + servers.push_back(BasicResult(i, COMMON::DistanceUtils::ComputeDistance( \ + (Type *)vector.Data(), (Type *)context->GetCenters()->GetVector(i), \ + vectorDimension, context->GetSettings()->m_distMethod))); \ + } \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - default: - break; - } - std::sort(servers.begin(), servers.end(), [](const BasicResult& a, const BasicResult& b) { return a.Dist < b.Dist; }); - for (int i = 0; i < 
context->GetSettings()->m_topK; i++) { - auto& server = context->GetRemoteServers().at(servers[i].VID); - if (RemoteMachineStatus::Connected != server->m_status) - { - continue; - } - remoteServers.push_back(server->m_connectionID); - } - } - else { - for (const auto& server : context->GetRemoteServers()) - { - if (RemoteMachineStatus::Connected != server->m_status) - { - continue; - } - - remoteServers.push_back(server->m_connectionID); - } - } + default: + break; + } + std::sort(servers.begin(), servers.end(), + [](const BasicResult &a, const BasicResult &b) { return a.Dist < b.Dist; }); + for (int i = 0; i < context->GetSettings()->m_topK; i++) + { + auto &server = context->GetRemoteServers().at(servers[i].VID); + if (RemoteMachineStatus::Connected != server->m_status) + { + continue; + } + remoteServers.push_back(server->m_connectionID); + } + } + else + { + for (const auto &server : context->GetRemoteServers()) + { + if (RemoteMachineStatus::Connected != server->m_status) + { + continue; + } + + remoteServers.push_back(server->m_connectionID); + } + } Socket::PacketHeader requestHeader = p_packet.Header(); if (Socket::c_invalidConnectionID == requestHeader.m_connectionID) { @@ -277,8 +247,7 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID for (std::uint32_t i = 0; i < remoteServers.size(); ++i) { - AggregatorCallback callback = [this, executionContext, i](Socket::RemoteSearchResult p_result) - { + AggregatorCallback callback = [this, executionContext, i](Socket::RemoteSearchResult p_result) { executionContext->GetResult(i).reset(new Socket::RemoteSearchResult(std::move(p_result))); if (executionContext->IsCompletedAfterFinsh(1)) { @@ -286,8 +255,7 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID } }; - auto timeoutCallback = [](std::shared_ptr p_callback) - { + auto timeoutCallback = [](std::shared_ptr p_callback) { if (nullptr != p_callback) { Socket::RemoteSearchResult result; @@ -297,8 
+265,7 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID } }; - auto connectCallback = [callback](bool p_connectSucc) - { + auto connectCallback = [callback](bool p_connectSucc) { if (!p_connectSucc) { Socket::RemoteSearchResult result; @@ -313,9 +280,9 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID packet.Header().m_processStatus = Socket::PacketProcessStatus::Ok; packet.Header().m_bodyLength = p_packet.Header().m_bodyLength; packet.Header().m_connectionID = Socket::c_invalidConnectionID; - packet.Header().m_resourceID = m_aggregatorCallbackManager.Add(std::make_shared(std::move(callback)), - context->GetSettings()->m_searchTimeout, - std::move(timeoutCallback)); + packet.Header().m_resourceID = + m_aggregatorCallbackManager.Add(std::make_shared(std::move(callback)), + context->GetSettings()->m_searchTimeout, std::move(timeoutCallback)); packet.AllocateBuffer(packet.Header().m_bodyLength); packet.Header().WriteBuffer(packet.HeaderBuffer()); @@ -325,9 +292,7 @@ AggregatorService::SearchRequestHanlder(Socket::ConnectionID p_localConnectionID } } - -void -AggregatorService::SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +void AggregatorService::SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) { auto callback = m_aggregatorCallbackManager.GetAndRemove(p_packet.Header().m_resourceID); if (nullptr == callback) @@ -350,17 +315,13 @@ AggregatorService::SearchResponseHanlder(Socket::ConnectionID p_localConnectionI } } - -std::shared_ptr -AggregatorService::GetContext() +std::shared_ptr AggregatorService::GetContext() { // Add mutex if necessary. 
return m_aggregatorContext; } - -void -AggregatorService::AggregateResults(std::shared_ptr p_exectionContext) +void AggregatorService::AggregateResults(std::shared_ptr p_exectionContext) { if (nullptr == p_exectionContext) { @@ -378,7 +339,7 @@ AggregatorService::AggregateResults(std::shared_ptr std::size_t resultNum = 0; for (std::size_t i = 0; i < p_exectionContext->GetServerNumber(); ++i) { - const auto& result = p_exectionContext->GetResult(i); + const auto &result = p_exectionContext->GetResult(i); if (nullptr == result) { continue; @@ -390,13 +351,13 @@ AggregatorService::AggregateResults(std::shared_ptr remoteResult.m_allIndexResults.reserve(resultNum); for (std::size_t i = 0; i < p_exectionContext->GetServerNumber(); ++i) { - const auto& result = p_exectionContext->GetResult(i); + const auto &result = p_exectionContext->GetResult(i); if (nullptr == result) { continue; } - for (auto& indexRes : result->m_allIndexResults) + for (auto &indexRes : result->m_allIndexResults) { remoteResult.m_allIndexResults.emplace_back(std::move(indexRes)); } @@ -407,7 +368,5 @@ AggregatorService::AggregateResults(std::shared_ptr packet.Header().m_bodyLength = static_cast(remoteResult.Write(packet.Body()) - packet.Body()); packet.Header().WriteBuffer(packet.HeaderBuffer()); - m_socketServer->SendPacket(p_exectionContext->GetRequestHeader().m_connectionID, - std::move(packet), - nullptr); + m_socketServer->SendPacket(p_exectionContext->GetRequestHeader().m_connectionID, std::move(packet), nullptr); } diff --git a/AnnService/src/Aggregator/AggregatorSettings.cpp b/AnnService/src/Aggregator/AggregatorSettings.cpp index a3e2bc680..9bb93117b 100644 --- a/AnnService/src/Aggregator/AggregatorSettings.cpp +++ b/AnnService/src/Aggregator/AggregatorSettings.cpp @@ -6,9 +6,6 @@ using namespace SPTAG; using namespace SPTAG::Aggregator; -AggregatorSettings::AggregatorSettings() - : m_searchTimeout(100), - m_threadNum(8), - m_socketThreadNum(8) +AggregatorSettings::AggregatorSettings() : 
m_searchTimeout(100), m_threadNum(8), m_socketThreadNum(8) { } diff --git a/AnnService/src/Aggregator/main.cpp b/AnnService/src/Aggregator/main.cpp index 2a06025d5..588d7ceb3 100644 --- a/AnnService/src/Aggregator/main.cpp +++ b/AnnService/src/Aggregator/main.cpp @@ -5,7 +5,7 @@ SPTAG::Aggregator::AggregatorService g_service; -int main(int argc, char* argv[]) +int main(int argc, char *argv[]) { if (!g_service.Initialize()) { @@ -16,4 +16,3 @@ int main(int argc, char* argv[]) return 0; } - diff --git a/AnnService/src/BalancedDataPartition/main.cpp b/AnnService/src/BalancedDataPartition/main.cpp index 4834db5d8..886b10790 100644 --- a/AnnService/src/BalancedDataPartition/main.cpp +++ b/AnnService/src/BalancedDataPartition/main.cpp @@ -1,33 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include -#include "inc/Core/Common/DistanceUtils.h" -#include "inc/Core/Common/Dataset.h" #include "inc/Core/Common/BKTree.h" -#include "inc/Helper/VectorSetReader.h" +#include "inc/Core/Common/Dataset.h" +#include "inc/Core/Common/DistanceUtils.h" #include "inc/Helper/CommonHelper.h" +#include "inc/Helper/VectorSetReader.h" +#include +#include +#include +#include using namespace SPTAG; -#define CHECKIO(ptr, func, bytes, ...) if (ptr->func(bytes, __VA_ARGS__) != bytes) { \ - LOG(Helper::LogLevel::LL_Error, "DiskError: Cannot read or write %d bytes.\n", (int)(bytes)); \ - exit(1); \ -} +#define CHECKIO(ptr, func, bytes, ...) 
\ + if (ptr->func(bytes, __VA_ARGS__) != bytes) \ + { \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "DiskError: Cannot read or write %d bytes.\n", (int)(bytes)); \ + exit(1); \ + } typedef short LabelType; class PartitionOptions : public Helper::ReaderOptions { -public: - PartitionOptions():Helper::ReaderOptions(VectorValueType::Float, 0, VectorFileType::TXT, "|", 32) + public: + PartitionOptions() : Helper::ReaderOptions(VectorValueType::Float, 0, VectorFileType::TXT, "|", 32) { AddRequiredOption(m_inputFiles, "-i", "--input", "Input raw data."); AddRequiredOption(m_clusterNum, "-c", "--numclusters", "Number of clusters."); - AddOptionalOption(m_stopDifference, "-d", "--diff", "Clustering stop center difference."); + AddOptionalOption(m_stopDifference, "-df", "--diff", "Clustering stop center difference."); AddOptionalOption(m_maxIter, "-r", "--iters", "Max clustering iterations."); AddOptionalOption(m_localSamples, "-s", "--samples", "Number of samples for fast clustering."); AddOptionalOption(m_lambda, "-l", "--lambda", "lambda for balanced size level."); @@ -51,7 +53,9 @@ class PartitionOptions : public Helper::ReaderOptions AddOptionalOption(m_hardcut, "-hc", "--hard", "soft: 0, hard: 1"); } - ~PartitionOptions() {} + ~PartitionOptions() + { + } std::string m_inputFiles; int m_clusterNum; @@ -87,11 +91,15 @@ class PartitionOptions : public Helper::ReaderOptions EdgeCompare g_edgeComparer; template -bool LoadCenters(T* centers, SizeType row, DimensionType col, const std::string& centerpath, float* lambda = nullptr, float* diff = nullptr, float* mindist = nullptr, int* noimprovement = nullptr) { - if (fileexists(centerpath.c_str())) { +bool LoadCenters(T *centers, SizeType row, DimensionType col, const std::string ¢erpath, float *lambda = nullptr, + float *diff = nullptr, float *mindist = nullptr, int *noimprovement = nullptr) +{ + if (fileexists(centerpath.c_str())) + { auto ptr = f_createIO(); - if (ptr == nullptr || !ptr->Initialize(centerpath.c_str(), 
std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read center file %s.\n", centerpath.c_str()); + if (ptr == nullptr || !ptr->Initialize(centerpath.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read center file %s.\n", centerpath.c_str()); return false; } @@ -99,109 +107,143 @@ bool LoadCenters(T* centers, SizeType row, DimensionType col, const std::string& DimensionType c; float f; int i; - if (ptr->ReadBinary(sizeof(SizeType), (char*)&r) != sizeof(SizeType)) return false; - if (ptr->ReadBinary(sizeof(DimensionType), (char*)&c) != sizeof(DimensionType)) return false; + if (ptr->ReadBinary(sizeof(SizeType), (char *)&r) != sizeof(SizeType)) + return false; + if (ptr->ReadBinary(sizeof(DimensionType), (char *)&c) != sizeof(DimensionType)) + return false; - if (r != row || c != col) { - LOG(Helper::LogLevel::LL_Error, "Row(%d,%d) or Col(%d,%d) cannot match.\n", r, row, c, col); + if (r != row || c != col) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Row(%d,%d) or Col(%d,%d) cannot match.\n", r, row, c, col); return false; } - if (ptr->ReadBinary(sizeof(T) * row * col, (char*)centers) != sizeof(T) * row * col) return false; + if (ptr->ReadBinary(sizeof(T) * row * col, (char *)centers) != sizeof(T) * row * col) + return false; - if (lambda) { - if (ptr->ReadBinary(sizeof(float), (char*)&f) == sizeof(float)) *lambda = f; + if (lambda) + { + if (ptr->ReadBinary(sizeof(float), (char *)&f) == sizeof(float)) + *lambda = f; } - if (diff) { - if (ptr->ReadBinary(sizeof(float), (char*)&f) == sizeof(float)) *diff = f; + if (diff) + { + if (ptr->ReadBinary(sizeof(float), (char *)&f) == sizeof(float)) + *diff = f; } - if (mindist) { - if (ptr->ReadBinary(sizeof(float), (char*)&f) == sizeof(float)) *mindist = f; + if (mindist) + { + if (ptr->ReadBinary(sizeof(float), (char *)&f) == sizeof(float)) + *mindist = f; } - if (noimprovement) { - if (ptr->ReadBinary(sizeof(int), (char*)&i) == 
sizeof(int)) *noimprovement = i; + if (noimprovement) + { + if (ptr->ReadBinary(sizeof(int), (char *)&i) == sizeof(int)) + *noimprovement = i; } - LOG(Helper::LogLevel::LL_Info, "Load centers(%d,%d) from file %s.\n", row, col, centerpath.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load centers(%d,%d) from file %s.\n", row, col, centerpath.c_str()); return true; } return false; } template -void SaveCenters(T* centers, SizeType row, DimensionType col, const std::string& centerpath, float lambda = 0.0, float diff = 0.0, float mindist = 0.0, int noimprovement = 0) { +void SaveCenters(T *centers, SizeType row, DimensionType col, const std::string ¢erpath, float lambda = 0.0, + float diff = 0.0, float mindist = 0.0, int noimprovement = 0) +{ auto ptr = f_createIO(); - if (ptr == nullptr || !ptr->Initialize(centerpath.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Failed to open center file %s to write.\n", centerpath.c_str()); + if (ptr == nullptr || !ptr->Initialize(centerpath.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to open center file %s to write.\n", centerpath.c_str()); exit(1); } - CHECKIO(ptr, WriteBinary, sizeof(SizeType), (char*)&row); - CHECKIO(ptr, WriteBinary, sizeof(DimensionType), (char*)&col); - CHECKIO(ptr, WriteBinary, sizeof(T) * row * col, (char*)centers); - CHECKIO(ptr, WriteBinary, sizeof(float), (char*)&lambda); - CHECKIO(ptr, WriteBinary, sizeof(float), (char*)&diff); - CHECKIO(ptr, WriteBinary, sizeof(float), (char*)&mindist); - CHECKIO(ptr, WriteBinary, sizeof(int), (char*)&noimprovement); - LOG(Helper::LogLevel::LL_Info, "Save centers(%d,%d) to file %s.\n", row, col, centerpath.c_str()); + CHECKIO(ptr, WriteBinary, sizeof(SizeType), (char *)&row); + CHECKIO(ptr, WriteBinary, sizeof(DimensionType), (char *)&col); + CHECKIO(ptr, WriteBinary, sizeof(T) * row * col, (char *)centers); + CHECKIO(ptr, WriteBinary, sizeof(float), (char *)&lambda); + 
CHECKIO(ptr, WriteBinary, sizeof(float), (char *)&diff); + CHECKIO(ptr, WriteBinary, sizeof(float), (char *)&mindist); + CHECKIO(ptr, WriteBinary, sizeof(int), (char *)&noimprovement); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save centers(%d,%d) to file %s.\n", row, col, centerpath.c_str()); } template -inline float MultipleClustersAssign(const COMMON::Dataset& data, - std::vector& indices, - const SizeType first, const SizeType last, COMMON::KmeansArgs& args, COMMON::Dataset& label, bool updateCenters, float lambda, std::vector& weights, float wlambda) { +inline float MultipleClustersAssign(const COMMON::Dataset &data, std::vector &indices, + const SizeType first, const SizeType last, COMMON::KmeansArgs &args, + COMMON::Dataset &label, bool updateCenters, float lambda, + std::vector &weights, float wlambda) +{ float currDist = 0; - SizeType subsize = (last - first - 1) / args._T + 1; + SizeType subsize = (last - first - 1) / args._TH + 1; std::uint64_t avgCount = 0; - for (int k = 0; k < args._K; k++) avgCount += args.counts[k]; + for (int k = 0; k < args._K; k++) + avgCount += args.counts[k]; avgCount /= args._K; - auto func = [&](int tid) - { + std::vector dist_total(args._K * args._TH, 0); + + auto func = [&](int tid) { SizeType istart = first + tid * subsize; SizeType iend = min(first + (tid + 1) * subsize, last); - SizeType* inewCounts = args.newCounts + tid * args._K; - float* inewWeightedCounts = args.newWeightedCounts + tid * args._K; - float* inewCenters = args.newCenters + tid * args._K * args._D; - SizeType* iclusterIdx = args.clusterIdx + tid * args._K; - float* iclusterDist = args.clusterDist + tid * args._K; + SizeType *inewCounts = args.newCounts + tid * args._K; + float *inewWeightedCounts = args.newWeightedCounts + tid * args._K; + float *inewCenters = args.newCenters + tid * args._K * args._D; + SizeType *iclusterIdx = args.clusterIdx + tid * args._K; + float *iclusterDist = args.clusterDist + tid * args._K; + float *idist_total = 
dist_total.data() + tid * args._K; float idist = 0; std::vector centerDist(args._K, SPTAG::NodeDistPair()); - for (SizeType i = istart; i < iend; i++) { - for (int k = 0; k < args._K; k++) { - float penalty = lambda * (((options.m_newp == 1) && (args.counts[k] < avgCount)) ? avgCount : args.counts[k]) + wlambda * args.weightedCounts[k]; + for (SizeType i = istart; i < iend; i++) + { + for (int k = 0; k < args._K; k++) + { + float penalty = + lambda * (((options.m_newp == 1) && (args.counts[k] < avgCount)) ? avgCount : args.counts[k]) + + wlambda * args.weightedCounts[k]; float dist = args.fComputeDistance(data[indices[i]], args.centers + k * args._D, args._D) + penalty; centerDist[k].node = k; centerDist[k].distance = dist; } - std::sort(centerDist.begin(), centerDist.end(), [](const SPTAG::NodeDistPair& a, const SPTAG::NodeDistPair& b) { - return (a.distance < b.distance) || (a.distance == b.distance && a.node < b.node); - }); - - for (int k = 0; k < label.C(); k++) { - if (centerDist[k].distance <= centerDist[0].distance * options.m_closurefactor) { + std::sort(centerDist.begin(), centerDist.end(), + [](const SPTAG::NodeDistPair &a, const SPTAG::NodeDistPair &b) { + return (a.distance < b.distance) || (a.distance == b.distance && a.node < b.node); + }); + + for (int k = 0; k < label.C(); k++) + { + if (centerDist[k].distance <= centerDist[0].distance * options.m_closurefactor) + { label[i][k] = (LabelType)(centerDist[k].node); inewCounts[centerDist[k].node]++; inewWeightedCounts[centerDist[k].node] += weights[indices[i]]; idist += centerDist[k].distance; - - if (updateCenters) { - const T* v = (const T*)data[indices[i]]; - float* center = inewCenters + centerDist[k].node * args._D; - for (DimensionType j = 0; j < args._D; j++) center[j] += v[j]; - if (centerDist[k].distance > iclusterDist[centerDist[k].node]) { + idist_total[centerDist[k].node] += centerDist[k].distance; + + if (updateCenters) + { + const T *v = (const T *)data[indices[i]]; + float *center = 
inewCenters + centerDist[k].node * args._D; + for (DimensionType j = 0; j < args._D; j++) + center[j] += v[j]; + if (centerDist[k].distance > iclusterDist[centerDist[k].node]) + { iclusterDist[centerDist[k].node] = centerDist[k].distance; iclusterIdx[centerDist[k].node] = indices[i]; } } - else { - if (centerDist[k].distance <= iclusterDist[centerDist[k].node]) { + else + { + if (centerDist[k].distance <= iclusterDist[centerDist[k].node]) + { iclusterDist[centerDist[k].node] = centerDist[k].distance; iclusterIdx[centerDist[k].node] = indices[i]; } } } - else { + else + { label[i][k] = (std::numeric_limits::max)(); } } @@ -210,35 +252,59 @@ inline float MultipleClustersAssign(const COMMON::Dataset& data, }; std::vector threads; - for (int i = 0; i < args._T; i++) { threads.emplace_back(func, i); } - for (auto& thread : threads) { thread.join(); } + for (int i = 0; i < args._TH; i++) + { + threads.emplace_back(func, i); + } + for (auto &thread : threads) + { + thread.join(); + } - for (int i = 1; i < args._T; i++) { - for (int k = 0; k < args._K; k++) { - args.newCounts[k] += args.newCounts[i*args._K + k]; - args.newWeightedCounts[k] += args.newWeightedCounts[i*args._K + k]; + for (int i = 1; i < args._TH; i++) + { + for (int k = 0; k < args._K; k++) + { + args.newCounts[k] += args.newCounts[i * args._K + k]; + args.newWeightedCounts[k] += args.newWeightedCounts[i * args._K + k]; + dist_total[k] += dist_total[i * args._K + k]; } } - if (updateCenters) { - for (int i = 1; i < args._T; i++) { - float* currCenter = args.newCenters + i*args._K*args._D; - for (size_t j = 0; j < ((size_t)args._K) * args._D; j++) args.newCenters[j] += currCenter[j]; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "start printing dist_total\n"); + for (int k = 0; k < args._K; k++) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "cluster %d: dist_total:%f \n", k, dist_total[k]); + } - for (int k = 0; k < args._K; k++) { - if (args.clusterIdx[i*args._K + k] != -1 && args.clusterDist[i*args._K + k] > 
args.clusterDist[k]) { - args.clusterDist[k] = args.clusterDist[i*args._K + k]; - args.clusterIdx[k] = args.clusterIdx[i*args._K + k]; + if (updateCenters) + { + for (int i = 1; i < args._TH; i++) + { + float *currCenter = args.newCenters + i * args._K * args._D; + for (size_t j = 0; j < ((size_t)args._K) * args._D; j++) + args.newCenters[j] += currCenter[j]; + + for (int k = 0; k < args._K; k++) + { + if (args.clusterIdx[i * args._K + k] != -1 && args.clusterDist[i * args._K + k] > args.clusterDist[k]) + { + args.clusterDist[k] = args.clusterDist[i * args._K + k]; + args.clusterIdx[k] = args.clusterIdx[i * args._K + k]; } } } } - else { - for (int i = 1; i < args._T; i++) { - for (int k = 0; k < args._K; k++) { - if (args.clusterIdx[i*args._K + k] != -1 && args.clusterDist[i*args._K + k] <= args.clusterDist[k]) { - args.clusterDist[k] = args.clusterDist[i*args._K + k]; - args.clusterIdx[k] = args.clusterIdx[i*args._K + k]; + else + { + for (int i = 1; i < args._TH; i++) + { + for (int k = 0; k < args._K; k++) + { + if (args.clusterIdx[i * args._K + k] != -1 && args.clusterDist[i * args._K + k] <= args.clusterDist[k]) + { + args.clusterDist[k] = args.clusterDist[i * args._K + k]; + args.clusterIdx[k] = args.clusterIdx[i * args._K + k]; } } } @@ -247,92 +313,113 @@ inline float MultipleClustersAssign(const COMMON::Dataset& data, } template -inline float HardMultipleClustersAssign(const COMMON::Dataset& data, - std::vector& indices, - const SizeType first, const SizeType last, COMMON::KmeansArgs& args, COMMON::Dataset& label, SizeType* mylimit, std::vector& weights, - const int clusternum, const bool fill) { +inline float HardMultipleClustersAssign(const COMMON::Dataset &data, std::vector &indices, + const SizeType first, const SizeType last, COMMON::KmeansArgs &args, + COMMON::Dataset &label, SizeType *mylimit, + std::vector &weights, const int clusternum, const bool fill) +{ float currDist = 0; - SizeType subsize = (last - first - 1) / args._T + 1; + SizeType 
subsize = (last - first - 1) / args._TH + 1; - SPTAG::Edge* items = new SPTAG::Edge[last - first]; + SPTAG::Edge *items = new SPTAG::Edge[last - first]; - auto func1 = [&](int tid) - { + auto func1 = [&](int tid) { SizeType istart = first + tid * subsize; SizeType iend = min(first + (tid + 1) * subsize, last); - float* iclusterDist = args.clusterDist + tid * args._K; + float *iclusterDist = args.clusterDist + tid * args._K; std::vector centerDist(args._K, SPTAG::NodeDistPair()); - for (SizeType i = istart; i < iend; i++) { - for (int k = 0; k < args._K; k++) { + for (SizeType i = istart; i < iend; i++) + { + for (int k = 0; k < args._K; k++) + { float dist = args.fComputeDistance(data[indices[i]], args.centers + k * args._D, args._D); centerDist[k].node = k; centerDist[k].distance = dist; } - std::sort(centerDist.begin(), centerDist.end(), [](const SPTAG::NodeDistPair& a, const SPTAG::NodeDistPair& b) { - return (a.distance < b.distance) || (a.distance == b.distance && a.node < b.node); - }); + std::sort(centerDist.begin(), centerDist.end(), + [](const SPTAG::NodeDistPair &a, const SPTAG::NodeDistPair &b) { + return (a.distance < b.distance) || (a.distance == b.distance && a.node < b.node); + }); - if (centerDist[clusternum].distance <= centerDist[0].distance * options.m_closurefactor) { + if (centerDist[clusternum].distance <= centerDist[0].distance * options.m_closurefactor) + { items[i - first].node = centerDist[clusternum].node; items[i - first].distance = centerDist[clusternum].distance; items[i - first].tonode = i; iclusterDist[centerDist[clusternum].node] += centerDist[clusternum].distance; } - else { + else + { items[i - first].node = MaxSize; items[i - first].distance = MaxDist; - items[i - first].tonode = -i-1; + items[i - first].tonode = -i - 1; } } }; { std::vector threads; - for (int i = 0; i < args._T; i++) { threads.emplace_back(func1, i); } - for (auto& thread : threads) { thread.join(); } + for (int i = 0; i < args._TH; i++) + { + 
threads.emplace_back(func1, i); + } + for (auto &thread : threads) + { + thread.join(); + } } std::sort(items, items + last - first, g_edgeComparer); - for (int i = 0; i < args._T; i++) { - for (int k = 0; k < args._K; k++) { + for (int i = 0; i < args._TH; i++) + { + for (int k = 0; k < args._K; k++) + { mylimit[k] -= args.newCounts[i * args._K + k]; - if (i > 0) args.clusterDist[k] += args.clusterDist[i * args._K + k]; + if (i > 0) + args.clusterDist[k] += args.clusterDist[i * args._K + k]; } } std::size_t startIdx = 0; for (int i = 0; i < args._K; ++i) { std::size_t endIdx = std::lower_bound(items, items + last - first, i + 1, g_edgeComparer) - items; - LOG(Helper::LogLevel::LL_Info, "cluster %d: avgdist:%f limit:%d, drop:%zu - %zu\n", items[startIdx].node, args.clusterDist[i] / (endIdx - startIdx), mylimit[i], startIdx + mylimit[i], endIdx); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "cluster %d: avgdist:%f limit:%d, drop:%zu - %zu\n", + items[startIdx].node, args.clusterDist[i] / (endIdx - startIdx), mylimit[i], startIdx + mylimit[i], + endIdx); for (size_t dropID = startIdx + mylimit[i]; dropID < endIdx; ++dropID) { - if (items[dropID].tonode >= 0) items[dropID].tonode = -items[dropID].tonode - 1; + if (items[dropID].tonode >= 0) + items[dropID].tonode = -items[dropID].tonode - 1; } startIdx = endIdx; } - auto func2 = [&, subsize](int tid) - { + auto func2 = [&, subsize](int tid) { SizeType istart = tid * subsize; SizeType iend = min((tid + 1) * subsize, last - first); - SizeType* inewCounts = args.newCounts + tid * args._K; - float* inewWeightedCounts = args.newWeightedCounts + tid * args._K; + SizeType *inewCounts = args.newCounts + tid * args._K; + float *inewWeightedCounts = args.newWeightedCounts + tid * args._K; float idist = 0; - for (SizeType i = istart; i < iend; i++) { - if (items[i].tonode >= 0) { + for (SizeType i = istart; i < iend; i++) + { + if (items[i].tonode >= 0) + { label[items[i].tonode][clusternum] = (LabelType)(items[i].node); 
inewCounts[items[i].node]++; inewWeightedCounts[items[i].node] += weights[indices[items[i].tonode]]; idist += items[i].distance; } - else { + else + { items[i].tonode = -items[i].tonode - 1; label[items[i].tonode][clusternum] = (std::numeric_limits::max)(); } - if (fill) { - for (int k = clusternum + 1; k < label.C(); k++) { + if (fill) + { + for (int k = clusternum + 1; k < label.C(); k++) + { label[items[i].tonode][k] = (std::numeric_limits::max)(); } } @@ -342,24 +429,32 @@ inline float HardMultipleClustersAssign(const COMMON::Dataset& data, { std::vector threads2; - for (int i = 0; i < args._T; i++) { threads2.emplace_back(func2, i); } - for (auto& thread : threads2) { thread.join(); } + for (int i = 0; i < args._TH; i++) + { + threads2.emplace_back(func2, i); + } + for (auto &thread : threads2) + { + thread.join(); + } } delete[] items; std::memset(args.counts, 0, sizeof(SizeType) * args._K); std::memset(args.weightedCounts, 0, sizeof(float) * args._K); - for (int i = 0; i < args._T; i++) { - for (int k = 0; k < args._K; k++) { - args.counts[k] += args.newCounts[i*args._K + k]; - args.weightedCounts[k] += args.newWeightedCounts[i*args._K + k]; + for (int i = 0; i < args._TH; i++) + { + for (int k = 0; k < args._K; k++) + { + args.counts[k] += args.newCounts[i * args._K + k]; + args.weightedCounts[k] += args.newWeightedCounts[i * args._K + k]; } } return currDist; } -template -void Process(MPI_Datatype type) { +template void Process(MPI_Datatype type) +{ int rank, size; MPI_Init(NULL, NULL); MPI_Comm_rank(MPI_COMM_WORLD, &rank); @@ -369,241 +464,312 @@ void Process(MPI_Datatype type) { options.m_inputFiles = Helper::StrUtils::ReplaceAll(options.m_inputFiles, "*", std::to_string(rank)); if (ErrorCode::Success != vectorReader->LoadFile(options.m_inputFiles)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); exit(1); } std::shared_ptr vectors = 
vectorReader->GetVectorSet(); std::shared_ptr metas = vectorReader->GetMetadataSet(); - if (options.m_distMethod == DistCalcMethod::Cosine) vectors->Normalize(options.m_threadNum); + if (options.m_distMethod == DistCalcMethod::Cosine) + vectors->Normalize(options.m_threadNum); std::vector weights(vectors->Count(), 0.0f); - if (options.m_weightfile.compare("-") != 0) { + if (options.m_weightfile.compare("-") != 0) + { options.m_weightfile = Helper::StrUtils::ReplaceAll(options.m_weightfile, "*", std::to_string(rank)); std::ifstream win(options.m_weightfile, std::ifstream::binary); - if (!win.is_open()) { - LOG(Helper::LogLevel::LL_Error, "Rank %d failed to read weight file %s.\n", rank, options.m_weightfile.c_str()); + if (!win.is_open()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Rank %d failed to read weight file %s.\n", rank, + options.m_weightfile.c_str()); exit(1); } SizeType rows; - win.read((char*)&rows, sizeof(SizeType)); - if (rows != vectors->Count()) { + win.read((char *)&rows, sizeof(SizeType)); + if (rows != vectors->Count()) + { win.close(); - LOG(Helper::LogLevel::LL_Error, "Number of weights (%d) is not equal to number of vectors (%d).\n", rows, vectors->Count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Number of weights (%d) is not equal to number of vectors (%d).\n", + rows, vectors->Count()); exit(1); } - win.read((char*)weights.data(), sizeof(float)*rows); + win.read((char *)weights.data(), sizeof(float) * rows); win.close(); } - COMMON::Dataset data(vectors->Count(), vectors->Dimension(), 1024*1024, vectors->Count() + 1, (T*)vectors->GetData()); - COMMON::KmeansArgs args(options.m_clusterNum, vectors->Dimension(), vectors->Count(), options.m_threadNum, options.m_distMethod); + COMMON::Dataset data(vectors->Count(), vectors->Dimension(), 1024 * 1024, vectors->Count() + 1, + (T *)vectors->GetData()); + COMMON::KmeansArgs args(options.m_clusterNum, vectors->Dimension(), vectors->Count(), options.m_threadNum, + options.m_distMethod); 
COMMON::Dataset label(vectors->Count(), options.m_clusterassign, vectors->Count(), vectors->Count()); std::vector localindices(data.R(), 0); - for (SizeType i = 0; i < data.R(); i++) localindices[i] = i; + for (SizeType i = 0; i < data.R(); i++) + localindices[i] = i; unsigned long long localCount = data.R(), totalCount; MPI_Allreduce(&localCount, &totalCount, 1, MPI_UNSIGNED_LONG_LONG, MPI_SUM, MPI_COMM_WORLD); totalCount = static_cast(totalCount * 1.0 / args._K * options.m_vectorfactor); - if (rank == 0 && options.m_maxIter > 0 && options.m_lambda < -1e-6f) { + if (rank == 0 && options.m_maxIter > 0 && options.m_lambda < -1e-6f) + { float fBalanceFactor = COMMON::DynamicFactorSelect(data, localindices, 0, data.R(), args, data.R()); options.m_lambda = COMMON::Utils::GetBase() * COMMON::Utils::GetBase() / fBalanceFactor / data.R(); } MPI_Bcast(&(options.m_lambda), 1, MPI_FLOAT, 0, MPI_COMM_WORLD); - LOG(Helper::LogLevel::LL_Info, "rank %d data:(%d,%d) machines:%d clusters:%d type:%d threads:%d lambda:%f samples:%d maxcountperpartition:%d\n", - rank, data.R(), data.C(), size, options.m_clusterNum, ((int)options.m_inputValueType), options.m_threadNum, options.m_lambda, options.m_localSamples, totalCount); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "rank %d data:(%d,%d) machines:%d clusters:%d type:%d threads:%d lambda:%f samples:%d " + "maxcountperpartition:%d\n", + rank, data.R(), data.C(), size, options.m_clusterNum, ((int)options.m_inputValueType), + options.m_threadNum, options.m_lambda, options.m_localSamples, totalCount); - if (rank == 0) { - LOG(Helper::LogLevel::LL_Info, "rank 0 init centers\n"); - if (!LoadCenters(args.newTCenters, args._K, args._D, options.m_centers, &(options.m_lambda))) { - if (options.m_seed >= 0) std::srand(options.m_seed); - COMMON::InitCenters(data, localindices, 0, data.R(), args, options.m_localSamples, options.m_initIter); + if (rank == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "rank 0 init centers\n"); + if 
(!LoadCenters(args.newTCenters, args._K, args._D, options.m_centers, &(options.m_lambda))) + { + if (options.m_seed >= 0) + std::srand(options.m_seed); + COMMON::InitCenters(data, localindices, 0, data.R(), args, options.m_localSamples, + options.m_initIter); } } float currDiff = 1.0, d, currDist, minClusterDist = MaxDist; int iteration = 0; int noImprovement = 0; - while (currDiff > options.m_stopDifference && iteration < options.m_maxIter) { - if (rank == 0) { - std::memcpy(args.centers, args.newTCenters, sizeof(T)*args._K*args._D); + while (currDiff > options.m_stopDifference && iteration < options.m_maxIter) + { + if (rank == 0) + { + std::memcpy(args.centers, args.newTCenters, sizeof(T) * args._K * args._D); } - MPI_Bcast(args.centers, args._K*args._D, type, 0, MPI_COMM_WORLD); + MPI_Bcast(args.centers, args._K * args._D, type, 0, MPI_COMM_WORLD); args.ClearCenters(); args.ClearCounts(); args.ClearDists(-MaxDist); - d = MultipleClustersAssign(data, localindices, 0, data.R(), args, label, true, (iteration == 0) ? 0.0f : options.m_lambda, weights, (iteration == 0) ? 0.0f : options.m_wlambda); + d = MultipleClustersAssign(data, localindices, 0, data.R(), args, label, true, + (iteration == 0) ? 0.0f : options.m_lambda, weights, + (iteration == 0) ? 
0.0f : options.m_wlambda); MPI_Allreduce(args.newCounts, args.counts, args._K, MPI_INT, MPI_SUM, MPI_COMM_WORLD); MPI_Allreduce(args.newWeightedCounts, args.weightedCounts, args._K, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD); MPI_Allreduce(&d, &currDist, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD); - if (currDist < minClusterDist) { + if (currDist < minClusterDist) + { noImprovement = 0; minClusterDist = currDist; } - else { + else + { noImprovement++; } - if (noImprovement >= 10) break; + if (noImprovement >= 10) + break; - if (rank == 0) { + if (rank == 0) + { MPI_Reduce(MPI_IN_PLACE, args.newCenters, args._K * args._D, MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD); currDiff = COMMON::RefineCenters(data, args); - LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f diff:%f\n", iteration, currDist, currDiff); - } else + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f diff:%f\n", iteration, currDist, currDiff); + } + else MPI_Reduce(args.newCenters, args.newCenters, args._K * args._D, MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD); iteration++; MPI_Bcast(&currDiff, 1, MPI_FLOAT, 0, MPI_COMM_WORLD); } - if (options.m_maxIter == 0) { - if (rank == 0) { - std::memcpy(args.centers, args.newTCenters, sizeof(T)*args._K*args._D); + if (options.m_maxIter == 0) + { + if (rank == 0) + { + std::memcpy(args.centers, args.newTCenters, sizeof(T) * args._K * args._D); } - MPI_Bcast(args.centers, args._K*args._D, type, 0, MPI_COMM_WORLD); + MPI_Bcast(args.centers, args._K * args._D, type, 0, MPI_COMM_WORLD); } - else { - if (rank == 0) { + else + { + if (rank == 0) + { for (int i = 0; i < args._K; i++) - LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], args.weightedCounts[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, + args.counts[i], args.weightedCounts[i]); } } d = 0; - for (SizeType i = 0; i < data.R(); i++) localindices[i] = i; - std::vector myLimit(args._K, (options.m_hardcut == 0) ? 
data.R() : (SizeType)(options.m_hardcut * totalCount / size)); - std::memset(args.counts, 0, sizeof(SizeType)*args._K); + for (SizeType i = 0; i < data.R(); i++) + localindices[i] = i; + std::vector myLimit( + args._K, (options.m_hardcut == 0) ? data.R() : (SizeType)(options.m_hardcut * totalCount / size)); + std::memset(args.counts, 0, sizeof(SizeType) * args._K); args.ClearCounts(); args.ClearDists(0); - for (int i = 0; i < options.m_clusterassign - 1; i++) { - d += HardMultipleClustersAssign(data, localindices, 0, data.R(), args, label, myLimit.data(), weights, i, false); - std::memcpy(myLimit.data(), args.counts, sizeof(SizeType)*args._K); + for (int i = 0; i < options.m_clusterassign - 1; i++) + { + d += HardMultipleClustersAssign(data, localindices, 0, data.R(), args, label, myLimit.data(), weights, i, + false); + std::memcpy(myLimit.data(), args.counts, sizeof(SizeType) * args._K); MPI_Allreduce(MPI_IN_PLACE, args.counts, args._K, MPI_INT, MPI_SUM, MPI_COMM_WORLD); MPI_Allreduce(MPI_IN_PLACE, args.weightedCounts, args._K, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD); - if (rank == 0) { - LOG(Helper::LogLevel::LL_Info, "assign %d....................d:%f\n", i, d); + if (rank == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "assign %d....................d:%f\n", i, d); for (int i = 0; i < args._K; i++) - LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], args.weightedCounts[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, + args.counts[i], args.weightedCounts[i]); } for (int k = 0; k < args._K; k++) if (totalCount > args.counts[k]) myLimit[k] += (SizeType)((totalCount - args.counts[k]) / size); } - d += HardMultipleClustersAssign(data, localindices, 0, data.R(), args, label, myLimit.data(), weights, options.m_clusterassign - 1, true); - std::memcpy(args.newCounts, args.counts, sizeof(SizeType)*args._K); + d += HardMultipleClustersAssign(data, localindices, 0, data.R(), 
args, label, myLimit.data(), weights, + options.m_clusterassign - 1, true); + std::memcpy(args.newCounts, args.counts, sizeof(SizeType) * args._K); MPI_Allreduce(args.newCounts, args.counts, args._K, MPI_INT, MPI_SUM, MPI_COMM_WORLD); MPI_Allreduce(MPI_IN_PLACE, args.weightedCounts, args._K, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD); MPI_Allreduce(&d, &currDist, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD); - if (label.Save(options.m_labels + "." + std::to_string(rank)) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to save labels.\n"); + if (label.Save(options.m_labels + "." + std::to_string(rank)) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to save labels.\n"); exit(1); } - if (rank == 0) { + if (rank == 0) + { SaveCenters(args.centers, args._K, args._D, options.m_centers, options.m_lambda); - LOG(Helper::LogLevel::LL_Info, "final dist:%f\n", currDist); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "final dist:%f\n", currDist); for (int i = 0; i < args._K; i++) - LOG(Helper::LogLevel::LL_Status, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], args.weightedCounts[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Status, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], + args.weightedCounts[i]); } MPI_Barrier(MPI_COMM_WORLD); - if (options.m_outdir.compare("-") != 0) { - for (int i = 0; i < args._K; i++) { - if (i % size == rank) { - LOG(Helper::LogLevel::LL_Info, "Cluster %d start ......\n", i); + if (options.m_outdir.compare("-") != 0) + { + for (int i = 0; i < args._K; i++) + { + if (i % size == rank) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Cluster %d start ......\n", i); } noImprovement = 0; std::string vecfile = options.m_outdir + "/" + options.m_outfile + "." 
+ std::to_string(i + 1); - if (fileexists(vecfile.c_str())) noImprovement = 1; + if (fileexists(vecfile.c_str())) + noImprovement = 1; MPI_Allreduce(MPI_IN_PLACE, &noImprovement, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD); - if (noImprovement) continue; + if (noImprovement) + continue; - if (i % size == rank) { + if (i % size == rank) + { std::string vecfile = options.m_outdir + "/" + options.m_outfile + "." + std::to_string(i); std::string metafile = options.m_outdir + "/" + options.m_outmetafile + "." + std::to_string(i); - std::string metaindexfile = options.m_outdir + "/" + options.m_outmetaindexfile + "." + std::to_string(i); + std::string metaindexfile = + options.m_outdir + "/" + options.m_outmetaindexfile + "." + std::to_string(i); std::shared_ptr out = f_createIO(), metaout = f_createIO(), metaindexout = f_createIO(); - if (out == nullptr || !out->Initialize(vecfile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", vecfile.c_str()); + if (out == nullptr || !out->Initialize(vecfile.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", vecfile.c_str()); exit(1); } - if (metaout == nullptr || !metaout->Initialize(metafile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", metafile.c_str()); + if (metaout == nullptr || !metaout->Initialize(metafile.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", metafile.c_str()); exit(1); } - if (metaindexout == nullptr || !metaindexout->Initialize(metaindexfile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", metaindexfile.c_str()); + if (metaindexout == nullptr || + !metaindexout->Initialize(metaindexfile.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to 
write.\n", metaindexfile.c_str()); exit(1); } - CHECKIO(out, WriteBinary, sizeof(int), (char*)(&args.counts[i])); - CHECKIO(out, WriteBinary, sizeof(int), (char*)(&args._D)); - if (metas != nullptr) CHECKIO(metaindexout, WriteBinary, sizeof(int), (char*)(&args.counts[i])); + CHECKIO(out, WriteBinary, sizeof(int), (char *)(&args.counts[i])); + CHECKIO(out, WriteBinary, sizeof(int), (char *)(&args._D)); + if (metas != nullptr) + CHECKIO(metaindexout, WriteBinary, sizeof(int), (char *)(&args.counts[i])); std::uint64_t offset = 0; - T* recvbuf = args.newTCenters; + T *recvbuf = args.newTCenters; int recvmetabuflen = 200; - char* recvmetabuf = new char [recvmetabuflen]; - for (int j = 0; j < size; j++) { + char *recvmetabuf = new char[recvmetabuflen]; + for (int j = 0; j < size; j++) + { uint64_t offset_before = offset; - if (j != rank) { + if (j != rank) + { int recv = 0; MPI_Recv(&recv, 1, MPI_INT, j, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); - for (int k = 0; k < recv; k++) { + for (int k = 0; k < recv; k++) + { MPI_Recv(recvbuf, args._D, type, j, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE); - CHECKIO(out, WriteBinary, sizeof(T)* args._D, (char*)recvbuf); + CHECKIO(out, WriteBinary, sizeof(T) * args._D, (char *)recvbuf); - if (metas != nullptr) { + if (metas != nullptr) + { int len; MPI_Recv(&len, 1, MPI_INT, j, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE); - if (len > recvmetabuflen) { - LOG(Helper::LogLevel::LL_Info, "enlarge recv meta buf to %d\n", len); + if (len > recvmetabuflen) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "enlarge recv meta buf to %d\n", len); delete[] recvmetabuf; recvmetabuflen = len; recvmetabuf = new char[recvmetabuflen]; } MPI_Recv(recvmetabuf, len, MPI_CHAR, j, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE); CHECKIO(metaout, WriteBinary, len, recvmetabuf); - CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char*)(&offset)); + CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char *)(&offset)); offset += len; } } - 
LOG(Helper::LogLevel::LL_Info, "rank %d <- rank %d: %d vectors, %llu bytes meta\n", rank, j, recv, (offset - offset_before)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "rank %d <- rank %d: %d vectors, %llu bytes meta\n", + rank, j, recv, (offset - offset_before)); } - else { + else + { size_t total_rec = 0; - for (int k = 0; k < data.R(); k++) { - for (int kk = 0; kk < label.C(); kk++) { - if (label[k][kk] == (LabelType)i) { - CHECKIO(out, WriteBinary, sizeof(T) * args._D, (char*)(data[localindices[k]])); - if (metas != nullptr) { + for (int k = 0; k < data.R(); k++) + { + for (int kk = 0; kk < label.C(); kk++) + { + if (label[k][kk] == (LabelType)i) + { + CHECKIO(out, WriteBinary, sizeof(T) * args._D, (char *)(data[localindices[k]])); + if (metas != nullptr) + { ByteArray meta = metas->GetMetadata(localindices[k]); - CHECKIO(metaout, WriteBinary, meta.Length(), (const char*)meta.Data()); - CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char*)(&offset)); + CHECKIO(metaout, WriteBinary, meta.Length(), (const char *)meta.Data()); + CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char *)(&offset)); offset += meta.Length(); } total_rec++; } } } - LOG(Helper::LogLevel::LL_Info, "rank %d <- rank %d: %d(%d) vectors, %llu bytes meta\n", rank, j, args.newCounts[i], total_rec, (offset - offset_before)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "rank %d <- rank %d: %d(%d) vectors, %llu bytes meta\n", + rank, j, args.newCounts[i], total_rec, (offset - offset_before)); } } delete[] recvmetabuf; - if (metas != nullptr) CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char*)(&offset)); + if (metas != nullptr) + CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char *)(&offset)); out->ShutDown(); metaout->ShutDown(); metaindexout->ShutDown(); } - else { + else + { int dest = i % size; MPI_Send(&args.newCounts[i], 1, MPI_INT, dest, 0, MPI_COMM_WORLD); size_t total_len = 0; size_t total_rec = 0; - for (int j = 0; j < data.R(); j++) { - for 
(int kk = 0; kk < label.C(); kk++) { - if (label[j][kk] == (LabelType)i) { + for (int j = 0; j < data.R(); j++) + { + for (int kk = 0; kk < label.C(); kk++) + { + if (label[j][kk] == (LabelType)i) + { MPI_Send(data[localindices[j]], args._D, type, dest, 1, MPI_COMM_WORLD); - if (metas != nullptr) { + if (metas != nullptr) + { ByteArray meta = metas->GetMetadata(localindices[j]); int len = (int)meta.Length(); MPI_Send(&len, 1, MPI_INT, dest, 2, MPI_COMM_WORLD); @@ -614,7 +780,8 @@ void Process(MPI_Datatype type) { } } } - LOG(Helper::LogLevel::LL_Info, "rank %d -> rank %d: %d(%d) vectors, %llu bytes meta\n", rank, dest, args.newCounts[i], total_rec, total_len); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "rank %d -> rank %d: %d(%d) vectors, %llu bytes meta\n", rank, + dest, args.newCounts[i], total_rec, total_len); } MPI_Barrier(MPI_COMM_WORLD); } @@ -623,66 +790,109 @@ void Process(MPI_Datatype type) { } template -ErrorCode SyncSaveCenter(COMMON::KmeansArgs &args, int rank, int iteration, unsigned long long localCount, float localDist, float lambda, float diff, float mindist, int noimprovement, int savecenters, bool assign = false) +ErrorCode SyncSaveCenter(COMMON::KmeansArgs &args, int rank, int iteration, unsigned long long localCount, + float localDist, float lambda, float diff, float mindist, int noimprovement, int savecenters, + bool assign = false) { - if (!direxists(options.m_status.c_str())) mkdir(options.m_status.c_str()); + if (!direxists(options.m_status.c_str())) + mkdir(options.m_status.c_str()); std::string folder = options.m_status + FolderSep + std::to_string(iteration); - if (!direxists(folder.c_str())) mkdir(folder.c_str()); + if (!direxists(folder.c_str())) + mkdir(folder.c_str()); - if (!direxists(folder.c_str())) { - LOG(Helper::LogLevel::LL_Error, "Cannot create the folder %s.\n", folder.c_str()); + if (!direxists(folder.c_str())) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot create the folder %s.\n", folder.c_str()); exit(1); } - if 
(rank == 0 && savecenters > 0) { - SaveCenters(args.newTCenters, args._K, args._D, folder + FolderSep + "centers.bin", lambda, diff, mindist, noimprovement); + if (rank == 0 && savecenters > 0) + { + SaveCenters(args.newTCenters, args._K, args._D, folder + FolderSep + "centers.bin", lambda, diff, mindist, + noimprovement); } std::string savePath = folder + FolderSep + "status." + std::to_string(iteration) + "." + std::to_string(rank); auto out = f_createIO(); - if (out == nullptr || !out->Initialize(savePath.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write status.\n", savePath.c_str()); + if (out == nullptr || !out->Initialize(savePath.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write status.\n", savePath.c_str()); exit(1); } - CHECKIO(out, WriteBinary, sizeof(unsigned long long), (const char*)&localCount); - CHECKIO(out, WriteBinary, sizeof(float), (const char*)&localDist); - CHECKIO(out, WriteBinary, sizeof(float) * args._K * args._D, (const char*)args.newCenters); - if (assign) { - CHECKIO(out, WriteBinary, sizeof(int) * args._K, (const char*)args.counts); - CHECKIO(out, WriteBinary, sizeof(float) * args._K, (const char*)args.weightedCounts); + CHECKIO(out, WriteBinary, sizeof(unsigned long long), (const char *)&localCount); + CHECKIO(out, WriteBinary, sizeof(float), (const char *)&localDist); + CHECKIO(out, WriteBinary, sizeof(float) * args._K * args._D, (const char *)args.newCenters); + if (assign) + { + CHECKIO(out, WriteBinary, sizeof(int) * args._K, (const char *)args.counts); + CHECKIO(out, WriteBinary, sizeof(float) * args._K, (const char *)args.weightedCounts); } - else { - CHECKIO(out, WriteBinary, sizeof(int) * args._K, (const char*)args.newCounts); - CHECKIO(out, WriteBinary, sizeof(float) * args._K, (const char*)args.newWeightedCounts); + else + { + CHECKIO(out, WriteBinary, sizeof(int) * args._K, (const char *)args.newCounts); + 
CHECKIO(out, WriteBinary, sizeof(float) * args._K, (const char *)args.newWeightedCounts); } out->ShutDown(); - if (!options.m_syncscript.empty()) { - system((options.m_syncscript + " upload " + folder + " " + std::to_string(options.m_totalparts) + " " + std::to_string(savecenters)).c_str()); + if (!options.m_syncscript.empty()) + { + try + { + int return_value = system((options.m_syncscript + " upload " + folder + " " + + std::to_string(options.m_totalparts) + " " + std::to_string(savecenters)) + .c_str()); + if (return_value != 0) + throw std::system_error(errno, std::generic_category(), "error executing command"); + } + catch (const std::system_error &e) + { + std::cerr << "error executing command: " << options.m_syncscript << e.what() << '\n'; + return ErrorCode::Fail; + } } - else { - LOG(Helper::LogLevel::LL_Error, "Error: Sync script is empty.\n"); + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Sync script is empty.\n"); } return ErrorCode::Success; } template -ErrorCode SyncLoadCenter(COMMON::KmeansArgs& args, int rank, int iteration, unsigned long long &totalCount, float &currDist, float &lambda, float &diff, float &mindist, int &noimprovement, bool loadcenters) +ErrorCode SyncLoadCenter(COMMON::KmeansArgs &args, int rank, int iteration, unsigned long long &totalCount, + float &currDist, float &lambda, float &diff, float &mindist, int &noimprovement, + bool loadcenters) { std::string folder = options.m_status + FolderSep + std::to_string(iteration); - //TODO download - if (!options.m_syncscript.empty()) { - system((options.m_syncscript + " download " + folder + " " + std::to_string(options.m_totalparts) + " " + std::to_string(loadcenters)).c_str()); + // TODO download + if (!options.m_syncscript.empty()) + { + try + { + int return_value = system((options.m_syncscript + " download " + folder + " " + + std::to_string(options.m_totalparts) + " " + std::to_string(loadcenters)) + .c_str()); + if (return_value != 0) + throw std::system_error(errno, 
std::generic_category(), "error executing command"); + } + catch (const std::system_error &e) + { + std::cerr << "error executing command: " << options.m_syncscript << e.what() << '\n'; + return ErrorCode::Fail; + } } - else { - LOG(Helper::LogLevel::LL_Error, "Error: Sync script is empty.\n"); + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Sync script is empty.\n"); } - if (loadcenters) { - if (!LoadCenters(args.newTCenters, args._K, args._D, folder + FolderSep + "centers.bin", &lambda, &diff, &mindist, &noimprovement)) { - LOG(Helper::LogLevel::LL_Error, "Cannot load centers.\n"); + if (loadcenters) + { + if (!LoadCenters(args.newTCenters, args._K, args._D, folder + FolderSep + "centers.bin", &lambda, &diff, + &mindist, &noimprovement)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot load centers.\n"); exit(1); } } @@ -696,79 +906,105 @@ ErrorCode SyncLoadCenter(COMMON::KmeansArgs& args, int rank, int iteration, u totalCount = 0; currDist = 0; - for (int part = 0; part < options.m_totalparts; part++) { + for (int part = 0; part < options.m_totalparts; part++) + { std::string loadPath = folder + FolderSep + "status." + std::to_string(iteration) + "." 
+ std::to_string(part); auto input = f_createIO(); - if (input == nullptr || !input->Initialize(loadPath.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to read status.", loadPath.c_str()); + if (input == nullptr || !input->Initialize(loadPath.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to read status.", loadPath.c_str()); exit(1); } - CHECKIO(input, ReadBinary, sizeof(unsigned long long), (char*)&localCount); + CHECKIO(input, ReadBinary, sizeof(unsigned long long), (char *)&localCount); totalCount += localCount; - CHECKIO(input, ReadBinary, sizeof(float), (char*)&localDist); + CHECKIO(input, ReadBinary, sizeof(float), (char *)&localDist); currDist += localDist; CHECKIO(input, ReadBinary, sizeof(float) * args._K * args._D, buf.get()); - for (int i = 0; i < args._K * args._D; i++) args.newCenters[i] += *((float*)(buf.get()) + i); + for (int i = 0; i < args._K * args._D; i++) + args.newCenters[i] += *((float *)(buf.get()) + i); CHECKIO(input, ReadBinary, sizeof(int) * args._K, buf.get()); - for (int i = 0; i < args._K; i++) { - int partsize = *((int*)(buf.get()) + i); - if (partsize >= 0 && args.counts[i] <= MaxSize - partsize) args.counts[i] += partsize; - else { - LOG(Helper::LogLevel::LL_Error, "Cluster %d counts overflow:%d + %d(%d)! Set it to MaxSize.\n", i, args.counts[i], partsize, part); + for (int i = 0; i < args._K; i++) + { + int partsize = *((int *)(buf.get()) + i); + if (partsize >= 0 && args.counts[i] <= MaxSize - partsize) + args.counts[i] += partsize; + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cluster %d counts overflow:%d + %d(%d)! 
Set it to MaxSize.\n", + i, args.counts[i], partsize, part); args.counts[i] = MaxSize; } } CHECKIO(input, ReadBinary, sizeof(float) * args._K, buf.get()); - for (int i = 0; i < args._K; i++) args.weightedCounts[i] += *((float*)(buf.get()) + i); + for (int i = 0; i < args._K; i++) + args.weightedCounts[i] += *((float *)(buf.get()) + i); } return ErrorCode::Success; } -template -void ProcessWithoutMPI() { +template void ProcessWithoutMPI() +{ std::string rankstr = options.m_labels.substr(options.m_labels.rfind(".") + 1); int rank = std::stoi(rankstr); - LOG(Helper::LogLevel::LL_Info, "DEBUG:rank--%d labels--%s\n", rank, options.m_labels.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "DEBUG:rank--%d labels--%s\n", rank, options.m_labels.c_str()); auto vectorReader = Helper::VectorSetReader::CreateInstance(std::make_shared(options)); options.m_inputFiles = Helper::StrUtils::ReplaceAll(options.m_inputFiles, "*", std::to_string(rank)); if (ErrorCode::Success != vectorReader->LoadFile(options.m_inputFiles)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); exit(1); } std::shared_ptr vectors = vectorReader->GetVectorSet(); std::shared_ptr metas = vectorReader->GetMetadataSet(); - if (options.m_distMethod == DistCalcMethod::Cosine) vectors->Normalize(options.m_threadNum); + if (vectors->Dimension() != options.m_dimension) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "vector dimension %d is not equal to the dimension %d of the option.\n", vectors->Dimension(), + options.m_dimension); + exit(1); + } + if (options.m_distMethod == DistCalcMethod::Cosine) + vectors->Normalize(options.m_threadNum); std::vector weights(vectors->Count(), 0.0f); - if (options.m_weightfile.compare("-") != 0) { + if (options.m_weightfile.compare("-") != 0) + { options.m_weightfile = Helper::StrUtils::ReplaceAll(options.m_weightfile, "*", std::to_string(rank)); std::ifstream win(options.m_weightfile, 
std::ifstream::binary); - if (!win.is_open()) { - LOG(Helper::LogLevel::LL_Error, "Rank %d failed to read weight file %s.\n", rank, options.m_weightfile.c_str()); + if (!win.is_open()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Rank %d failed to read weight file %s.\n", rank, + options.m_weightfile.c_str()); exit(1); } SizeType rows; - win.read((char*)&rows, sizeof(SizeType)); - if (rows != vectors->Count()) { + win.read((char *)&rows, sizeof(SizeType)); + if (rows != vectors->Count()) + { win.close(); - LOG(Helper::LogLevel::LL_Error, "Number of weights (%d) is not equal to number of vectors (%d).\n", rows, vectors->Count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Number of weights (%d) is not equal to number of vectors (%d).\n", + rows, vectors->Count()); exit(1); } - win.read((char*)weights.data(), sizeof(float) * rows); + win.read((char *)weights.data(), sizeof(float) * rows); win.close(); } - COMMON::Dataset data(vectors->Count(), vectors->Dimension(), 1024*1024, vectors->Count() + 1, (T*)vectors->GetData()); - COMMON::KmeansArgs args(options.m_clusterNum, vectors->Dimension(), vectors->Count(), options.m_threadNum, options.m_distMethod); + COMMON::Dataset data(vectors->Count(), vectors->Dimension(), 1024 * 1024, vectors->Count() + 1, + (T *)vectors->GetData()); + COMMON::KmeansArgs args(options.m_clusterNum, vectors->Dimension(), vectors->Count(), options.m_threadNum, + options.m_distMethod); COMMON::Dataset label(vectors->Count(), options.m_clusterassign, vectors->Count(), vectors->Count()); std::vector localindices(data.R(), 0); - for (SizeType i = 0; i < data.R(); i++) localindices[i] = i; + for (SizeType i = 0; i < data.R(); i++) + { + localindices[i] = i; + } args.ClearCounts(); unsigned long long totalCount; @@ -776,173 +1012,234 @@ void ProcessWithoutMPI() { int iteration = options.m_recoveriter; int noImprovement = 0; - if (rank == 0 && iteration < 0) { - LOG(Helper::LogLevel::LL_Info, "rank 0 init centers\n"); - if 
(!LoadCenters(args.newTCenters, args._K, args._D, options.m_centers, &(options.m_lambda))) { - if (options.m_seed >= 0) std::srand(options.m_seed); - if (options.m_maxIter > 0 && options.m_lambda < -1e-6f) { + if (rank == 0 && iteration < 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "rank 0 init centers\n"); + if (!LoadCenters(args.newTCenters, args._K, args._D, options.m_centers, &(options.m_lambda))) + { + if (options.m_seed >= 0) + std::srand(options.m_seed); + if (options.m_maxIter > 0 && options.m_lambda < -1e-6f) + { float fBalanceFactor = COMMON::DynamicFactorSelect(data, localindices, 0, data.R(), args, data.R()); - options.m_lambda = COMMON::Utils::GetBase() * COMMON::Utils::GetBase() / fBalanceFactor / data.R(); + options.m_lambda = + COMMON::Utils::GetBase() * COMMON::Utils::GetBase() / fBalanceFactor / data.R(); } - COMMON::InitCenters(data, localindices, 0, data.R(), args, options.m_localSamples, options.m_initIter); + COMMON::InitCenters(data, localindices, 0, data.R(), args, options.m_localSamples, + options.m_initIter); } } - if (iteration < 0) { + if (iteration < 0) + { iteration = 0; - SyncSaveCenter(args, rank, iteration, data.R(), d, options.m_lambda, currDiff, minClusterDist, noImprovement, 2); + SyncSaveCenter(args, rank, iteration, data.R(), d, options.m_lambda, currDiff, minClusterDist, noImprovement, + 2); } - else { - LOG(Helper::LogLevel::LL_Info, "recover from iteration:%d\n", iteration); + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "recover from iteration:%d\n", iteration); } - SyncLoadCenter(args, rank, iteration, totalCount, currDist, options.m_lambda, currDiff, minClusterDist, noImprovement, true); + SyncLoadCenter(args, rank, iteration, totalCount, currDist, options.m_lambda, currDiff, minClusterDist, + noImprovement, true); - LOG(Helper::LogLevel::LL_Info, "rank %d data:(%d,%d) machines:%d clusters:%d type:%d threads:%d lambda:%f samples:%d maxcountperpartition:%d\n", - rank, data.R(), data.C(), options.m_totalparts, 
options.m_clusterNum, ((int)options.m_inputValueType), options.m_threadNum, options.m_lambda, options.m_localSamples, static_cast(totalCount * 1.0 / args._K * options.m_vectorfactor)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "rank %d data:(%d,%d) machines:%d clusters:%d type:%d threads:%d lambda:%f samples:%d " + "maxcountperpartition:%d\n", + rank, data.R(), data.C(), options.m_totalparts, options.m_clusterNum, ((int)options.m_inputValueType), + options.m_threadNum, options.m_lambda, options.m_localSamples, + static_cast(totalCount * 1.0 / args._K * options.m_vectorfactor)); - while (noImprovement < 10 && currDiff > options.m_stopDifference && iteration < options.m_maxIter) { + while (noImprovement < 10 && currDiff > options.m_stopDifference && iteration < options.m_maxIter) + { std::memcpy(args.centers, args.newTCenters, sizeof(T) * args._K * args._D); args.ClearCenters(); args.ClearCounts(); args.ClearDists(-MaxDist); - d = MultipleClustersAssign(data, localindices, 0, data.R(), args, label, true, (iteration == 0) ? 0.0f : options.m_lambda, weights, (iteration == 0) ? 0.0f : options.m_wlambda); + d = MultipleClustersAssign(data, localindices, 0, data.R(), args, label, true, + (iteration == 0) ? 0.0f : options.m_lambda, weights, + (iteration == 0) ? 
0.0f : options.m_wlambda); - SyncSaveCenter(args, rank, iteration + 1, data.R(), d, options.m_lambda, currDiff, minClusterDist, noImprovement, 0); - if (rank == 0) { - SyncLoadCenter(args, rank, iteration + 1, totalCount, currDist, options.m_lambda, currDiff, minClusterDist, noImprovement, false); + SyncSaveCenter(args, rank, iteration + 1, data.R(), d, options.m_lambda, currDiff, minClusterDist, + noImprovement, 0); + if (rank == 0) + { + SyncLoadCenter(args, rank, iteration + 1, totalCount, currDist, options.m_lambda, currDiff, minClusterDist, + noImprovement, false); currDiff = COMMON::RefineCenters(data, args); - if (currDist < minClusterDist) { + if (currDist < minClusterDist) + { noImprovement = 0; minClusterDist = currDist; } - else { + else + { noImprovement++; } - SyncSaveCenter(args, rank, iteration + 1, data.R(), d, options.m_lambda, currDiff, minClusterDist, noImprovement, 1); + SyncSaveCenter(args, rank, iteration + 1, data.R(), d, options.m_lambda, currDiff, minClusterDist, + noImprovement, 1); } - else { - SyncLoadCenter(args, rank, iteration + 1, totalCount, currDist, options.m_lambda, currDiff, minClusterDist, noImprovement, true); + else + { + SyncLoadCenter(args, rank, iteration + 1, totalCount, currDist, options.m_lambda, currDiff, minClusterDist, + noImprovement, true); } iteration++; - LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f diff:%f\n", iteration, currDist, currDiff); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "iter %d dist:%f diff:%f\n", iteration, currDist, currDiff); } - if (options.m_maxIter == 0) { + if (options.m_maxIter == 0) + { std::memcpy(args.centers, args.newTCenters, sizeof(T) * args._K * args._D); } - else { - if (rank == 0) { + else + { + if (rank == 0) + { for (int i = 0; i < args._K; i++) - LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], args.weightedCounts[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, + args.counts[i], 
args.weightedCounts[i]); } } d = 0; totalCount = static_cast(totalCount * 1.0 / args._K * options.m_vectorfactor); unsigned long long tmpTotalCount; - for (SizeType i = 0; i < data.R(); i++) localindices[i] = i; - std::vector myLimit(args._K, (options.m_hardcut == 0)? data.R() : (SizeType)(options.m_hardcut * totalCount / options.m_totalparts)); + for (SizeType i = 0; i < data.R(); i++) + localindices[i] = i; + std::vector myLimit(args._K, (options.m_hardcut == 0) + ? data.R() + : (SizeType)(options.m_hardcut * totalCount / options.m_totalparts)); std::memset(args.counts, 0, sizeof(SizeType) * args._K); args.ClearCounts(); args.ClearDists(0); - for (int i = 0; i < options.m_clusterassign - 1; i++) { - d += HardMultipleClustersAssign(data, localindices, 0, data.R(), args, label, myLimit.data(), weights, i, false); + for (int i = 0; i < options.m_clusterassign - 1; i++) + { + d += HardMultipleClustersAssign(data, localindices, 0, data.R(), args, label, myLimit.data(), weights, i, + false); std::memcpy(myLimit.data(), args.counts, sizeof(SizeType) * args._K); - SyncSaveCenter(args, rank, 10000 + iteration + 1 + i, data.R(), d, options.m_lambda, currDiff, minClusterDist, noImprovement, 0, true); - SyncLoadCenter(args, rank, 10000 + iteration + 1 + i, tmpTotalCount, currDist, options.m_lambda, currDiff, minClusterDist, noImprovement, false); - if (rank == 0) { - LOG(Helper::LogLevel::LL_Info, "assign %d....................d:%f\n", i, d); + SyncSaveCenter(args, rank, 10000 + iteration + 1 + i, data.R(), d, options.m_lambda, currDiff, minClusterDist, + noImprovement, 0, true); + SyncLoadCenter(args, rank, 10000 + iteration + 1 + i, tmpTotalCount, currDist, options.m_lambda, currDiff, + minClusterDist, noImprovement, false); + if (rank == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "assign %d....................d:%f\n", i, d); for (int i = 0; i < args._K; i++) - LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], 
args.weightedCounts[i]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "cluster %d contains vectors:%d weights:%f\n", i, + args.counts[i], args.weightedCounts[i]); } for (int k = 0; k < args._K; k++) if (totalCount > args.counts[k]) myLimit[k] += (SizeType)((totalCount - args.counts[k]) / options.m_totalparts); } - d += HardMultipleClustersAssign(data, localindices, 0, data.R(), args, label, myLimit.data(), weights, options.m_clusterassign - 1, true); + d += HardMultipleClustersAssign(data, localindices, 0, data.R(), args, label, myLimit.data(), weights, + options.m_clusterassign - 1, true); std::memcpy(args.newCounts, args.counts, sizeof(SizeType) * args._K); - SyncSaveCenter(args, rank, 10000 + iteration + options.m_clusterassign, data.R(), d, options.m_lambda, currDiff, minClusterDist, noImprovement, 0, true); - SyncLoadCenter(args, rank, 10000 + iteration + options.m_clusterassign, tmpTotalCount, currDist, options.m_lambda, currDiff, minClusterDist, noImprovement, false); + SyncSaveCenter(args, rank, 10000 + iteration + options.m_clusterassign, data.R(), d, options.m_lambda, currDiff, + minClusterDist, noImprovement, 0, true); + SyncLoadCenter(args, rank, 10000 + iteration + options.m_clusterassign, tmpTotalCount, currDist, options.m_lambda, + currDiff, minClusterDist, noImprovement, false); - if (label.Save(options.m_labels) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to save labels.\n"); + if (label.Save(options.m_labels) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to save labels.\n"); exit(1); } - if (rank == 0) { + if (rank == 0) + { SaveCenters(args.centers, args._K, args._D, options.m_centers, options.m_lambda); - LOG(Helper::LogLevel::LL_Info, "final dist:%f\n", currDist); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "final dist:%f\n", currDist); for (int i = 0; i < args._K; i++) - LOG(Helper::LogLevel::LL_Status, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], args.weightedCounts[i]); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Status, "cluster %d contains vectors:%d weights:%f\n", i, args.counts[i], + args.weightedCounts[i]); } } -template -void Partition() { - if (options.m_outdir.compare("-") == 0) return; +template void Partition() +{ + if (options.m_outdir.compare("-") == 0) + return; auto vectorReader = Helper::VectorSetReader::CreateInstance(std::make_shared(options)); if (ErrorCode::Success != vectorReader->LoadFile(options.m_inputFiles)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); exit(1); } std::shared_ptr vectors = vectorReader->GetVectorSet(); std::shared_ptr metas = vectorReader->GetMetadataSet(); - if (options.m_distMethod == DistCalcMethod::Cosine) vectors->Normalize(options.m_threadNum); + if (options.m_distMethod == DistCalcMethod::Cosine) + vectors->Normalize(options.m_threadNum); - COMMON::Dataset data(vectors->Count(), vectors->Dimension(), 1024*1024, vectors->Count() + 1, (T*)vectors->GetData()); + COMMON::Dataset data(vectors->Count(), vectors->Dimension(), 1024 * 1024, vectors->Count() + 1, + (T *)vectors->GetData()); COMMON::Dataset label; - if (label.Load(options.m_labels, vectors->Count(), vectors->Count()) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to read labels.\n"); + if (label.Load(options.m_labels, vectors->Count(), vectors->Count()) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read labels.\n"); exit(1); } std::string taskId = options.m_labels.substr(options.m_labels.rfind(".") + 1); - for (int i = 0; i < options.m_clusterNum; i++) { + for (int i = 0; i < options.m_clusterNum; i++) + { std::string vecfile = options.m_outdir + "/" + options.m_outfile + "." + taskId + "." + std::to_string(i); std::string metafile = options.m_outdir + "/" + options.m_outmetafile + "." + taskId + "." 
+ std::to_string(i); - std::string metaindexfile = options.m_outdir + "/" + options.m_outmetaindexfile + "." + taskId + "." + std::to_string(i); + std::string metaindexfile = + options.m_outdir + "/" + options.m_outmetaindexfile + "." + taskId + "." + std::to_string(i); std::shared_ptr out = f_createIO(), metaout = f_createIO(), metaindexout = f_createIO(); - if (out == nullptr || !out->Initialize(vecfile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", vecfile.c_str()); + if (out == nullptr || !out->Initialize(vecfile.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", vecfile.c_str()); exit(1); } - if (metaout == nullptr || !metaout->Initialize(metafile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", metafile.c_str()); + if (metaout == nullptr || !metaout->Initialize(metafile.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", metafile.c_str()); exit(1); } - if (metaindexout == nullptr || !metaindexout->Initialize(metaindexfile.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", metaindexfile.c_str()); + if (metaindexout == nullptr || + !metaindexout->Initialize(metaindexfile.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open %s to write.\n", metaindexfile.c_str()); exit(1); } int rows = data.R(), cols = data.C(); - CHECKIO(out, WriteBinary, sizeof(int), (char*)(&rows)); - CHECKIO(out, WriteBinary, sizeof(int), (char*)(&cols)); - if (metas != nullptr) CHECKIO(metaindexout, WriteBinary, sizeof(int), (char*)(&rows)); + CHECKIO(out, WriteBinary, sizeof(int), (char *)(&rows)); + CHECKIO(out, WriteBinary, sizeof(int), (char *)(&cols)); + if (metas != nullptr) + CHECKIO(metaindexout, WriteBinary, 
sizeof(int), (char *)(&rows)); std::uint64_t offset = 0; int records = 0; - for (int k = 0; k < data.R(); k++) { - for (int kk = 0; kk < label.C(); kk++) { - if (label[k][kk] == (LabelType)i) { - CHECKIO(out, WriteBinary, sizeof(T) * cols, (char*)(data[k])); - if (metas != nullptr) { + for (int k = 0; k < data.R(); k++) + { + for (int kk = 0; kk < label.C(); kk++) + { + if (label[k][kk] == (LabelType)i) + { + CHECKIO(out, WriteBinary, sizeof(T) * cols, (char *)(data[k])); + if (metas != nullptr) + { ByteArray meta = metas->GetMetadata(k); - CHECKIO(metaout, WriteBinary, meta.Length(), (const char*)meta.Data()); - CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char*)(&offset)); + CHECKIO(metaout, WriteBinary, meta.Length(), (const char *)meta.Data()); + CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char *)(&offset)); offset += meta.Length(); } records++; } } } - LOG(Helper::LogLevel::LL_Info, "part %s cluster %d: %d vectors, %llu bytes meta.\n", taskId.c_str(), i, records, offset); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "part %s cluster %d: %d vectors, %llu bytes meta.\n", taskId.c_str(), i, + records, offset); - if (metas != nullptr) CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char*)(&offset)); - CHECKIO(out, WriteBinary, sizeof(int), (char*)(&records), 0); - CHECKIO(metaindexout, WriteBinary, sizeof(int), (char*)(&records), 0); + if (metas != nullptr) + CHECKIO(metaindexout, WriteBinary, sizeof(std::uint64_t), (char *)(&offset)); + CHECKIO(out, WriteBinary, sizeof(int), (char *)(&records), 0); + CHECKIO(metaindexout, WriteBinary, sizeof(int), (char *)(&records), 0); out->ShutDown(); metaout->ShutDown(); @@ -950,14 +1247,17 @@ void Partition() { } } -int main(int argc, char* argv[]) { +int main(int argc, char *argv[]) +{ if (!options.Parse(argc - 1, argv + 1)) { exit(1); } - if (options.m_stage.compare("Clustering") == 0) { - switch (options.m_inputValueType) { + if (options.m_stage.compare("Clustering") == 0) + { + switch 
(options.m_inputValueType) + { case SPTAG::VectorValueType::Float: Process(MPI_FLOAT); break; @@ -971,11 +1271,13 @@ int main(int argc, char* argv[]) { Process(MPI_CHAR); break; default: - LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); } } - else if (options.m_stage.compare("ClusteringWithoutMPI") == 0) { - switch (options.m_inputValueType) { + else if (options.m_stage.compare("ClusteringWithoutMPI") == 0) + { + switch (options.m_inputValueType) + { case SPTAG::VectorValueType::Float: ProcessWithoutMPI(); break; @@ -989,11 +1291,13 @@ int main(int argc, char* argv[]) { ProcessWithoutMPI(); break; default: - LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); } } - else if (options.m_stage.compare("LocalPartition") == 0) { - switch (options.m_inputValueType) { + else if (options.m_stage.compare("LocalPartition") == 0) + { + switch (options.m_inputValueType) + { case SPTAG::VectorValueType::Float: Partition(); break; @@ -1007,7 +1311,7 @@ int main(int argc, char* argv[]) { Partition(); break; default: - LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); } } return 0; diff --git a/AnnService/src/Client/ClientWrapper.cpp b/AnnService/src/Client/ClientWrapper.cpp index c868b9efd..6fe27c40a 100644 --- a/AnnService/src/Client/ClientWrapper.cpp +++ b/AnnService/src/Client/ClientWrapper.cpp @@ -7,15 +7,11 @@ using namespace SPTAG; using namespace SPTAG::Socket; using namespace SPTAG::Client; -ClientWrapper::ClientWrapper(const ClientOptions& p_options) - : m_options(p_options), - m_unfinishedJobCount(0), - m_isWaitingFinish(false) +ClientWrapper::ClientWrapper(const ClientOptions &p_options) + : m_options(p_options), m_unfinishedJobCount(0), m_isWaitingFinish(false) { m_client.reset(new SPTAG::Socket::Client(GetHandlerMap(), p_options.m_socketThreadNum, 30)); - 
m_client->SetEventOnConnectionClose(std::bind(&ClientWrapper::HandleDeadConnection, - this, - std::placeholders::_1)); + m_client->SetEventOnConnectionClose(std::bind(&ClientWrapper::HandleDeadConnection, this, std::placeholders::_1)); m_connections.reserve(m_options.m_threadNum); for (std::uint32_t i = 0; i < m_options.m_threadNum; ++i) @@ -25,7 +21,7 @@ ClientWrapper::ClientWrapper(const ClientOptions& p_options) conn.first = m_client->ConnectToServer(p_options.m_serverAddr, p_options.m_serverPort, errCode); if (SPTAG::ErrorCode::Socket_FailedResolveEndPoint == errCode) { - LOG(Helper::LogLevel::LL_Error, "Unable to resolve remote address.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to resolve remote address.\n"); return; } @@ -36,16 +32,12 @@ ClientWrapper::ClientWrapper(const ClientOptions& p_options) } } - ClientWrapper::~ClientWrapper() { } - -void -ClientWrapper::SendQueryAsync(const Socket::RemoteQuery& p_query, - Callback p_callback, - const ClientOptions& p_options) +void ClientWrapper::SendQueryAsync(const Socket::RemoteQuery &p_query, Callback p_callback, + const ClientOptions &p_options) { if (!bool(p_callback)) { @@ -54,8 +46,7 @@ ClientWrapper::SendQueryAsync(const Socket::RemoteQuery& p_query, auto conn = GetConnection(); - auto timeoutCallback = [this](std::shared_ptr p_callback) - { + auto timeoutCallback = [this](std::shared_ptr p_callback) { DecreaseUnfnishedJobCount(); if (nullptr != p_callback) { @@ -66,9 +57,7 @@ ClientWrapper::SendQueryAsync(const Socket::RemoteQuery& p_query, } }; - - auto connectCallback = [p_callback, this](bool p_connectSucc) - { + auto connectCallback = [p_callback, this](bool p_connectSucc) { if (!p_connectSucc) { Socket::RemoteSearchResult result; @@ -84,8 +73,7 @@ ClientWrapper::SendQueryAsync(const Socket::RemoteQuery& p_query, packet.Header().m_packetType = PacketType::SearchRequest; packet.Header().m_processStatus = PacketProcessStatus::Ok; packet.Header().m_resourceID = 
m_callbackManager.Add(std::make_shared(std::move(p_callback)), - p_options.m_searchTimeout, - std::move(timeoutCallback)); + p_options.m_searchTimeout, std::move(timeoutCallback)); packet.Header().m_bodyLength = static_cast(p_query.EstimateBufferSize()); packet.AllocateBuffer(packet.Header().m_bodyLength); @@ -96,9 +84,7 @@ ClientWrapper::SendQueryAsync(const Socket::RemoteQuery& p_query, m_client->SendPacket(conn.first, std::move(packet), connectCallback); } - -void -ClientWrapper::WaitAllFinished() +void ClientWrapper::WaitAllFinished() { if (m_unfinishedJobCount > 0) { @@ -111,36 +97,28 @@ ClientWrapper::WaitAllFinished() } } - -PacketHandlerMapPtr -ClientWrapper::GetHandlerMap() +PacketHandlerMapPtr ClientWrapper::GetHandlerMap() { PacketHandlerMapPtr handlerMap(new PacketHandlerMap); handlerMap->emplace(PacketType::RegisterResponse, - [this](ConnectionID p_localConnectionID, Packet p_packet) -> void - { - for (auto& conn : m_connections) - { - if (conn.first == p_localConnectionID) - { - conn.second = p_packet.Header().m_connectionID; - return; - } - } - }); - - handlerMap->emplace(PacketType::SearchResponse, - std::bind(&ClientWrapper::SearchResponseHanlder, - this, - std::placeholders::_1, - std::placeholders::_2)); + [this](ConnectionID p_localConnectionID, Packet p_packet) -> void { + for (auto &conn : m_connections) + { + if (conn.first == p_localConnectionID) + { + conn.second = p_packet.Header().m_connectionID; + return; + } + } + }); + + handlerMap->emplace(PacketType::SearchResponse, std::bind(&ClientWrapper::SearchResponseHanlder, this, + std::placeholders::_1, std::placeholders::_2)); return handlerMap; } - -void -ClientWrapper::DecreaseUnfnishedJobCount() +void ClientWrapper::DecreaseUnfnishedJobCount() { --m_unfinishedJobCount; if (0 == m_unfinishedJobCount) @@ -154,9 +132,7 @@ ClientWrapper::DecreaseUnfnishedJobCount() } } - -const ClientWrapper::ConnectionPair& -ClientWrapper::GetConnection() +const ClientWrapper::ConnectionPair 
&ClientWrapper::GetConnection() { if (m_connections.size() == 1) { @@ -174,9 +150,7 @@ ClientWrapper::GetConnection() return m_connections[pos]; } - -void -ClientWrapper::SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +void ClientWrapper::SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) { std::shared_ptr callback = m_callbackManager.GetAndRemove(p_packet.Header().m_resourceID); if (nullptr == callback) @@ -201,11 +175,9 @@ ClientWrapper::SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, S DecreaseUnfnishedJobCount(); } - -void -ClientWrapper::HandleDeadConnection(Socket::ConnectionID p_cid) +void ClientWrapper::HandleDeadConnection(Socket::ConnectionID p_cid) { - for (auto& conn : m_connections) + for (auto &conn : m_connections) { if (conn.first == p_cid) { @@ -227,9 +199,7 @@ ClientWrapper::HandleDeadConnection(Socket::ConnectionID p_cid) } } - -bool -ClientWrapper::IsAvailable() const +bool ClientWrapper::IsAvailable() const { return !m_connections.empty(); } diff --git a/AnnService/src/Client/Options.cpp b/AnnService/src/Client/Options.cpp index bb067d3d5..c1697b21d 100644 --- a/AnnService/src/Client/Options.cpp +++ b/AnnService/src/Client/Options.cpp @@ -9,10 +9,7 @@ using namespace SPTAG; using namespace SPTAG::Client; -ClientOptions::ClientOptions() - : m_searchTimeout(9000), - m_threadNum(1), - m_socketThreadNum(2) +ClientOptions::ClientOptions() : m_searchTimeout(9000), m_threadNum(1), m_socketThreadNum(2) { AddRequiredOption(m_serverAddr, "-s", "--server", "Server address."); AddRequiredOption(m_serverPort, "-p", "--port", "Server port."); @@ -21,7 +18,6 @@ ClientOptions::ClientOptions() AddOptionalOption(m_socketThreadNum, "-sth", "", "Socket Thread Number."); } - ClientOptions::~ClientOptions() { } diff --git a/AnnService/src/Client/main.cpp b/AnnService/src/Client/main.cpp index 088716d4b..4eb8477a4 100644 --- a/AnnService/src/Client/main.cpp +++ 
b/AnnService/src/Client/main.cpp @@ -1,18 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "inc/Client/Options.h" #include "inc/Client/ClientWrapper.h" +#include "inc/Client/Options.h" -#include #include +#include #include using namespace SPTAG; std::unique_ptr g_client; -int main(int argc, char** argv) +int main(int argc, char **argv) { SPTAG::Client::ClientOptions options; if (!options.Parse(argc - 1, argv + 1)) @@ -27,7 +27,7 @@ int main(int argc, char** argv) } g_client->WaitAllFinished(); - LOG(Helper::LogLevel::LL_Info, "connection done\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "connection done\n"); std::string line; std::cout << "Query: " << std::flush; @@ -43,29 +43,26 @@ int main(int argc, char** argv) query.m_queryString = std::move(line); SPTAG::Socket::RemoteSearchResult result; - auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) - { - result = std::move(p_result); - }; + auto callback = [&result](SPTAG::Socket::RemoteSearchResult p_result) { result = std::move(p_result); }; g_client->SendQueryAsync(query, callback, options); g_client->WaitAllFinished(); std::cout << "Status: " << static_cast(result.m_status) << std::endl; - for (const auto& indexRes : result.m_allIndexResults) + for (const auto &indexRes : result.m_allIndexResults) { std::cout << "Index: " << indexRes.m_indexName << std::endl; int idx = 0; - for (const auto& res : indexRes.m_results) + for (const auto &res : indexRes.m_results) { std::cout << "------------------" << std::endl; std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist; if (indexRes.m_results.WithMeta()) { - const auto& metadata = indexRes.m_results.GetMetadata(idx); - std::cout << " MetaData: " << std::string((char*)metadata.Data(), metadata.Length()); + const auto &metadata = indexRes.m_results.GetMetadata(idx); + std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length()); } std::cout << 
std::endl; ++idx; @@ -77,4 +74,3 @@ int main(int argc, char** argv) return 0; } - diff --git a/AnnService/src/Core/BKT/BKTIndex.cpp b/AnnService/src/Core/BKT/BKTIndex.cpp index f9ef95f6f..603272a30 100644 --- a/AnnService/src/Core/BKT/BKTIndex.cpp +++ b/AnnService/src/Core/BKT/BKTIndex.cpp @@ -2,133 +2,172 @@ // Licensed under the MIT License. #include "inc/Core/BKT/Index.h" +#include "inc/Core/ResultIterator.h" #include -#pragma warning(disable:4242) // '=' : conversion from 'int' to 'short', possible loss of data -#pragma warning(disable:4244) // '=' : conversion from 'int' to 'short', possible loss of data -#pragma warning(disable:4127) // conditional expression is constant +#pragma warning(disable : 4242) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable : 4244) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable : 4127) // conditional expression is constant namespace SPTAG { - namespace BKT - { - template - ErrorCode Index::LoadConfig(Helper::IniReader& p_reader) - { -#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - SetParameter(RepresentStr, \ - p_reader.GetParameter("Index", \ - RepresentStr, \ - std::string(#DefaultValue)).c_str()); \ +template thread_local std::unique_ptr COMMON::ThreadLocalWorkSpaceFactory::m_workspace; + +namespace BKT +{ + +template ErrorCode Index::LoadConfig(Helper::IniReader &p_reader) +{ +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, p_reader.GetParameter("Index", RepresentStr, std::string(#DefaultValue)).c_str()); #include "inc/Core/BKT/ParameterDefinitionList.h" #undef DefineBKTParameter - return ErrorCode::Success; - } + return ErrorCode::Success; +} - template <> - void Index::SetQuantizer(std::shared_ptr quantizer) - { - m_pQuantizer = quantizer; - m_pTrees.m_pQuantizer = quantizer; - if (m_pQuantizer) - { - m_fComputeDistance = 
m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod); - m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() : 1; - } - else - { - m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); - m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() : 1; - } +template <> void Index::SetQuantizer(std::shared_ptr quantizer) +{ + m_pQuantizer = quantizer; + m_pTrees.m_pQuantizer = quantizer; + if (m_pQuantizer) + { + m_fComputeDistance = m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod); + m_iBaseSquare = + (m_iDistCalcMethod == DistCalcMethod::Cosine) ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() : 1; + } + else + { + m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); + m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) + ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() + : 1; + } +} - } +template void Index::SetQuantizer(std::shared_ptr quantizer) +{ + m_pQuantizer = quantizer; + m_pTrees.m_pQuantizer = quantizer; + if (quantizer) + { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, + "Set non-null quantizer for index with data type other than BYTE"); + } +} - template - void Index::SetQuantizer(std::shared_ptr quantizer) - { - m_pQuantizer = quantizer; - m_pTrees.m_pQuantizer = quantizer; - if (quantizer) - { - LOG(SPTAG::Helper::LogLevel::LL_Error, "Set non-null quantizer for index with data type other than BYTE"); - } - } +template ErrorCode Index::LoadIndexDataFromMemory(const std::vector &p_indexBlobs) +{ + if (p_indexBlobs.size() < 3) + return ErrorCode::LackOfInputs; + + if (m_pSamples.Load((char *)p_indexBlobs[0].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) + return ErrorCode::FailedParseValue; + if (m_pTrees.LoadTrees((char *)p_indexBlobs[1].Data()) != ErrorCode::Success) + return ErrorCode::FailedParseValue; + if (m_pGraph.LoadGraph((char 
*)p_indexBlobs[2].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) + return ErrorCode::FailedParseValue; + if (p_indexBlobs.size() <= 3) + m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + else if (m_deletedID.Load((char *)p_indexBlobs[3].Data(), m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains) != ErrorCode::Success) + return ErrorCode::FailedParseValue; + + if (m_pSamples.R() != m_pGraph.R() || m_pSamples.R() != m_deletedID.R()) + { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, + "Index data is corrupted, please rebuild the index. Samples: %i, Graph: %i, DeletedID: %i.", + m_pSamples.R(), m_pGraph.R(), m_deletedID.R()); + return ErrorCode::FailedParseValue; + } - template - ErrorCode Index::LoadIndexDataFromMemory(const std::vector& p_indexBlobs) - { - if (p_indexBlobs.size() < 3) return ErrorCode::LackOfInputs; - - if (m_pSamples.Load((char*)p_indexBlobs[0].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) return ErrorCode::FailedParseValue; - if (m_pTrees.LoadTrees((char*)p_indexBlobs[1].Data()) != ErrorCode::Success) return ErrorCode::FailedParseValue; - if (m_pGraph.LoadGraph((char*)p_indexBlobs[2].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) return ErrorCode::FailedParseValue; - if (p_indexBlobs.size() <= 3) m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity); - else if (m_deletedID.Load((char*)p_indexBlobs[3].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) return ErrorCode::FailedParseValue; - - omp_set_num_threads(m_iNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - m_threadPool.init(); - return ErrorCode::Success; - } + m_pGraph.m_iThreadNum = m_iNumberOfThreads; + m_threadPool.init(); + return 
ErrorCode::Success; +} - template - ErrorCode Index::LoadIndexData(const std::vector>& p_indexStreams) - { - if (p_indexStreams.size() < 4) return ErrorCode::LackOfInputs; - - ErrorCode ret = ErrorCode::Success; - if (p_indexStreams[0] == nullptr || (ret = m_pSamples.Load(p_indexStreams[0], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) return ret; - if (p_indexStreams[1] == nullptr || (ret = m_pTrees.LoadTrees(p_indexStreams[1])) != ErrorCode::Success) return ret; - if (p_indexStreams[2] == nullptr || (ret = m_pGraph.LoadGraph(p_indexStreams[2], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) return ret; - if (p_indexStreams[3] == nullptr) m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity); - else if ((ret = m_deletedID.Load(p_indexStreams[3], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) return ret; - - omp_set_num_threads(m_iNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - m_threadPool.init(); - return ret; - } +template +ErrorCode Index::LoadIndexData(const std::vector> &p_indexStreams) +{ + if (p_indexStreams.size() < 4) + return ErrorCode::LackOfInputs; + + ErrorCode ret = ErrorCode::Success; + if (p_indexStreams[0] == nullptr || + (ret = m_pSamples.Load(p_indexStreams[0], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) + return ret; + if (p_indexStreams[1] == nullptr || (ret = m_pTrees.LoadTrees(p_indexStreams[1])) != ErrorCode::Success) + return ret; + if (p_indexStreams[2] == nullptr || + (ret = m_pGraph.LoadGraph(p_indexStreams[2], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) + return ret; + if (p_indexStreams[3] == nullptr) + m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + else if ((ret = m_deletedID.Load(p_indexStreams[3], m_iDataBlockSize, 
m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains)) != ErrorCode::Success) + return ret; + + if (m_pSamples.R() != m_pGraph.R() || m_pSamples.R() != m_deletedID.R()) + { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, + "Index data is corrupted, please rebuild the index. Samples: %i, Graph: %i, DeletedID: %i.", + m_pSamples.R(), m_pGraph.R(), m_deletedID.R()); + return ErrorCode::FailedParseValue; + } - template - ErrorCode Index::SaveConfig(std::shared_ptr p_configOut) - { - auto workSpace = m_workSpacePool->Rent(); - m_iHashTableExp = workSpace->HashTableExponent(); - m_workSpacePool->Return(workSpace); + m_pGraph.m_iThreadNum = m_iNumberOfThreads; + m_threadPool.init(); + return ret; +} -#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - IOSTRING(p_configOut, WriteString, (RepresentStr + std::string("=") + GetParameter(RepresentStr) + std::string("\n")).c_str()); +template ErrorCode Index::SaveConfig(std::shared_ptr p_configOut) +{ + auto workspace = m_workSpaceFactory->GetWorkSpace(); + if (workspace) + { + m_iHashTableExp = workspace->HashTableExponent(); + } + m_workSpaceFactory->ReturnWorkSpace(std::move(workspace)); + +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + IOSTRING(p_configOut, WriteString, \ + (RepresentStr + std::string("=") + GetParameter(RepresentStr) + std::string("\n")).c_str()); #include "inc/Core/BKT/ParameterDefinitionList.h" #undef DefineBKTParameter - IOSTRING(p_configOut, WriteString, "\n"); - return ErrorCode::Success; - } + IOSTRING(p_configOut, WriteString, "\n"); + return ErrorCode::Success; +} - template - ErrorCode Index::SaveIndexData(const std::vector>& p_indexStreams) - { - if (p_indexStreams.size() < 4) return ErrorCode::LackOfInputs; - - std::lock_guard lock(m_dataAddLock); - std::unique_lock uniquelock(m_dataDeleteLock); - - ErrorCode ret = ErrorCode::Success; - if ((ret = m_pSamples.Save(p_indexStreams[0])) != ErrorCode::Success) return ret; - if 
((ret = m_pTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) return ret; - if ((ret = m_pGraph.SaveGraph(p_indexStreams[2])) != ErrorCode::Success) return ret; - if ((ret = m_deletedID.Save(p_indexStreams[3])) != ErrorCode::Success) return ret; - return ret; - } +template +ErrorCode Index::SaveIndexData(const std::vector> &p_indexStreams) +{ + if (p_indexStreams.size() < 4) + return ErrorCode::LackOfInputs; + + std::lock_guard lock(m_dataAddLock); + std::unique_lock uniquelock(m_dataDeleteLock); + + ErrorCode ret = ErrorCode::Success; + if ((ret = m_pSamples.Save(p_indexStreams[0])) != ErrorCode::Success) + return ret; + if ((ret = m_pTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) + return ret; + if ((ret = m_pGraph.SaveGraph(p_indexStreams[2])) != ErrorCode::Success) + return ret; + if ((ret = m_deletedID.Save(p_indexStreams[3])) != ErrorCode::Success) + return ret; + return ret; +} -#pragma region K-NN search +#pragma region K - NN search /* -#define Search(CheckDeleted, CheckDuplicated) \ - std::shared_lock lock(*(m_pTrees.m_lock)); \ + +#define Search(CheckDeleted, CheckDuplicated, CheckFilter) \ + std::shared_lock lock(*(m_pTrees.m_lock)); \ m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space); \ m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); \ const DimensionType checkPos = m_pGraph.m_iNeighborhoodSize - 1; \ @@ -138,6 +177,7 @@ namespace SPTAG const SizeType *node = m_pGraph[tmpNode]; \ _mm_prefetch((const char *)node, _MM_HINT_T0); \ for (DimensionType i = 0; i <= checkPos; i++) { \ + if (node[i] < 0 || node[i] >= m_pSamples.R()) break; \ _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ } \ if (gnode.distance <= p_query.worstDist()) { \ @@ -148,44 +188,58 @@ namespace SPTAG do { \ CheckDeleted \ { \ - p_space.m_iNumOfContinuousNoBetterPropagation = 0; \ - CheckDuplicated \ - break; \ + CheckFilter \ + { \ + CheckDuplicated \ + break; \ + } \ 
} \ tmpNode = m_pTrees[i].centerid; \ } while (i++ < tnode.childEnd); \ - } else { \ - CheckDeleted \ - { \ - p_space.m_iNumOfContinuousNoBetterPropagation = 0; \ - p_query.AddPoint(tmpNode, gnode.distance); \ - } \ - } \ + } else { \ + CheckDeleted \ + { \ + CheckFilter \ + { \ + p_query.AddPoint(tmpNode, gnode.distance); \ + } \ + } \ + } \ } else { \ - p_space.m_iNumOfContinuousNoBetterPropagation++; \ - if (p_space.m_iNumOfContinuousNoBetterPropagation > p_space.m_iContinuousLimit || p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { \ - p_query.SortResult(); return; \ + CheckDeleted \ + { \ + if (gnode.distance > p_space.m_Results.worst() || p_space.m_iNumberOfCheckedLeaves > +p_space.m_iMaxCheck) { \ + p_query.SortResult(); return; \ + } \ } \ } \ for (DimensionType i = 0; i <= checkPos; i++) { \ SizeType nn_index = node[i]; \ if (nn_index < 0) break; \ + if (nn_index >= m_pSamples.R()) continue; \ if (p_space.CheckAndSet(nn_index)) continue; \ - float distance2leaf = m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], GetFeatureDim()); \ + float distance2leaf = m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], +GetFeatureDim()); \ p_space.m_iNumberOfCheckedLeaves++; \ - p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); \ + if (p_space.m_Results.insert(distance2leaf)) { \ + p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); \ + } \ } \ if (p_space.m_NGQueue.Top().distance > p_space.m_SPTQueue.Top().distance) { \ - m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); \ + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfOtherDynamicPivots + +p_space.m_iNumberOfCheckedLeaves); \ } \ } \ p_query.SortResult(); \ */ -#define Search(CheckDeleted, CheckDuplicated) \ - std::shared_lock lock(*(m_pTrees.m_lock)); \ - m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, 
p_query, p_space); \ - m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); \ +/* +#define SearchIterative(CheckDeleted, p_isFirst, batch) \ + if (p_isFirst) { \ + m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space); \ + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); \ + } \ const DimensionType checkPos = m_pGraph.m_iNeighborhoodSize - 1; \ while (!p_space.m_NGQueue.empty()) { \ NodeDistPair gnode = p_space.m_NGQueue.pop(); \ @@ -195,438 +249,1050 @@ namespace SPTAG for (DimensionType i = 0; i <= checkPos; i++) { \ _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ } \ - if (gnode.distance <= p_query.worstDist()) { \ - SizeType checkNode = node[checkPos]; \ - if (checkNode < -1) { \ - const COMMON::BKTNode& tnode = m_pTrees[-2 - checkNode]; \ - SizeType i = -tnode.childStart; \ - do { \ - CheckDeleted \ - { \ - CheckDuplicated \ - break; \ - } \ - tmpNode = m_pTrees[i].centerid; \ - } while (i++ < tnode.childEnd); \ - } else { \ - CheckDeleted \ - { \ - p_query.AddPoint(tmpNode, gnode.distance); \ - } \ - } \ - } else { \ - CheckDeleted \ + CheckDeleted \ { \ - if (gnode.distance > p_space.m_Results.worst() || p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { \ - p_query.SortResult(); return; \ + p_query.AddPoint(tmpNode, gnode.distance); \ + count++; \ + if (gnode.distance > p_space.m_Results.worst() || p_space.m_iNumberOfCheckedLeaves > +p_space.m_iMaxCheck) { \ + p_space.m_relaxedMono = true; \ } \ } \ + SizeType checkNode = node[checkPos]; \ + if (checkNode < -1) { \ + const COMMON::BKTNode& tnode = m_pTrees[-2 - checkNode]; \ + SizeType i = -tnode.childStart; \ + while (i < tnode.childEnd) { \ + tmpNode = m_pTrees[i].centerid; \ + CheckDeleted \ + { \ + float distance2leaf = m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[tmpNode], +GetFeatureDim()); \ + if (!p_space.CheckAndSet(tmpNode)) { \ 
+ p_space.m_NGQueue.insert(NodeDistPair(tmpNode, distance2leaf)); \ + } \ + } \ + i++; \ + }\ } \ - for (DimensionType i = 0; i <= checkPos; i++) { \ - SizeType nn_index = node[i]; \ - if (nn_index < 0) break; \ - if (p_space.CheckAndSet(nn_index)) continue; \ - float distance2leaf = m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], GetFeatureDim()); \ - p_space.m_iNumberOfCheckedLeaves++; \ - if (p_space.m_Results.insert(distance2leaf)) { \ + for (DimensionType i = 0; i <= checkPos; i++) { \ + SizeType nn_index = node[i]; \ + if (nn_index < 0) break; \ + if (p_space.CheckAndSet(nn_index)) continue; \ + float distance2leaf = m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], +GetFeatureDim()); \ + p_space.m_iNumberOfCheckedLeaves++; \ p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); \ + p_space.m_Results.insert(distance2leaf); \ } \ - } \ if (p_space.m_NGQueue.Top().distance > p_space.m_SPTQueue.Top().distance) { \ - m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); \ + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfOtherDynamicPivots + +p_space.m_iNumberOfCheckedLeaves); \ + } \ + if (count >= batch) {\ + break; \ } \ } \ p_query.SortResult(); \ +*/ +template +template &, SizeType, float), + bool (*checkFilter)(const std::shared_ptr &, SizeType, std::function)> +void Index::Search(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, + std::function filterFunc) const +{ + std::shared_lock lock(*(m_pTrees.m_lock)); + m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space); + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); + const DimensionType checkPos = m_pGraph.m_iNeighborhoodSize - 1; - template - void Index::SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, bool 
p_searchDeleted, bool p_searchDuplicated) const + while (!p_space.m_NGQueue.empty()) + { + NodeDistPair gnode = p_space.m_NGQueue.pop(); + SizeType tmpNode = gnode.node; + const SizeType *node = m_pGraph[tmpNode]; + _mm_prefetch((const char *)node, _MM_HINT_T0); + for (DimensionType i = 0; i <= checkPos; i++) { - if (m_pQuantizer && !p_query.HasQuantizedTarget()) - { - p_query.SetTarget(p_query.GetTarget(), m_pQuantizer); - } + auto futureNode = node[i]; + if (futureNode < 0) + break; + _mm_prefetch((const char *)(m_pSamples)[futureNode], _MM_HINT_T0); + } - if (m_deletedID.Count() == 0 || p_searchDeleted) + if (gnode.distance <= p_query.worstDist()) + { + SizeType checkNode = node[checkPos]; + if (checkNode < -1) { - if (p_searchDuplicated) - { - Search(;, if (!p_query.AddPoint(tmpNode, gnode.distance))) - } - else + const COMMON::BKTNode &tnode = m_pTrees[-2 - checkNode]; + SizeType i = -tnode.childStart; + do { - Search(;, p_query.AddPoint(tmpNode, gnode.distance);) - } + if (notDeleted(m_deletedID, tmpNode)) + { + if (checkFilter(m_pMetadata, tmpNode, filterFunc)) + { + if (isDup(p_query, tmpNode, gnode.distance)) + break; + } + } + if (i <= 0) + break; + tmpNode = m_pTrees[i].centerid; + } while (i++ < tnode.childEnd); } else { - if (p_searchDuplicated) + + if (notDeleted(m_deletedID, tmpNode)) { - Search(if (!m_deletedID.Contains(tmpNode)), if (!p_query.AddPoint(tmpNode, gnode.distance))) + if (checkFilter(m_pMetadata, tmpNode, filterFunc)) + { + p_query.AddPoint(tmpNode, gnode.distance); + } } - else + } + } + else + { + if (notDeleted(m_deletedID, tmpNode)) + { + if (gnode.distance > p_space.m_Results.worst() || + p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { - Search(if (!m_deletedID.Contains(tmpNode)), p_query.AddPoint(tmpNode, gnode.distance);) + p_query.SortResult(); + return; } } } - - template - ErrorCode Index::SearchIndex(QueryResult &p_query, bool p_searchDeleted) const + for (DimensionType i = 0; i <= checkPos; i++) { - if (!m_bReady) 
return ErrorCode::EmptyIndex; - - auto workSpace = m_workSpacePool->Rent(); - workSpace->Reset(m_iMaxCheck, p_query.GetResultNum()); - - SearchIndex(*((COMMON::QueryResultSet*)&p_query), *workSpace, p_searchDeleted, true); - - m_workSpacePool->Return(workSpace); + SizeType nn_index = node[i]; + if (nn_index < 0) + break; + + if (p_space.CheckAndSet(nn_index)) + continue; + float distance2leaf = + m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], GetFeatureDim()); + p_space.m_iNumberOfCheckedLeaves++; + if (p_space.m_Results.insert(distance2leaf)) + { + p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); + } + } + if (p_space.m_NGQueue.Top().distance > p_space.m_SPTQueue.Top().distance) + { + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, + m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); + } + } + p_query.SortResult(); +} - if (p_query.WithMeta() && nullptr != m_pMetadata) +template +template &, SizeType, float), + bool (*checkFilter)(const std::shared_ptr &, SizeType, std::function)> +int Index::SearchIterative(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, bool p_isFirst, + int batch) const +{ + std::shared_lock lock(*(m_pTrees.m_lock)); + if (p_isFirst) + { + m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space); + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); + } + int count = 0; + const DimensionType checkPos = m_pGraph.m_iNeighborhoodSize - 1; + while (!p_space.m_NGQueue.empty()) + { + NodeDistPair gnode = p_space.m_NGQueue.pop(); + SizeType tmpNode = gnode.node; + const SizeType *node = m_pGraph[tmpNode]; + _mm_prefetch((const char *)node, _MM_HINT_T0); + for (DimensionType i = 0; i <= checkPos; i++) + { + auto futureNode = node[i]; + if (futureNode < 0) + break; + _mm_prefetch((const char *)(m_pSamples)[futureNode], _MM_HINT_T0); + } + if (notDeleted(m_deletedID, tmpNode)) + { + if 
(checkFilter(m_pMetadata, tmpNode, p_space.m_filterFunc)) { - for (int i = 0; i < p_query.GetResultNum(); ++i) + p_query.AddPoint(tmpNode, gnode.distance); + count++; + } + if (gnode.distance > p_space.m_Results.worst() || p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) + { + p_space.m_relaxedMono = true; + } + } + SizeType checkNode = node[checkPos]; + if (checkNode < -1) + { + const COMMON::BKTNode &tnode = m_pTrees[-2 - checkNode]; + SizeType i = -tnode.childStart; + while (i < tnode.childEnd) + { + tmpNode = m_pTrees[i].centerid; + if (notDeleted(m_deletedID, tmpNode)) { - SizeType result = p_query.GetResult(i)->VID; - p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); + if (!p_space.CheckAndSet(tmpNode)) + { + p_space.m_NGQueue.insert(NodeDistPair(tmpNode, gnode.distance)); + } } + i++; } - return ErrorCode::Success; } - - template - ErrorCode Index::RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted) const + for (DimensionType i = 0; i <= checkPos; i++) + { + SizeType nn_index = node[i]; + if (nn_index < 0) + break; + if (p_space.CheckAndSet(nn_index)) + continue; + float distance2leaf = + m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], GetFeatureDim()); + p_space.m_iNumberOfCheckedLeaves++; + p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); + p_space.m_Results.insert(distance2leaf); + } + if (p_space.m_NGQueue.Top().distance > p_space.m_SPTQueue.Top().distance) + { + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, + m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); + } + if (count >= batch) { - auto workSpace = m_workSpacePool->Rent(); - workSpace->Reset(m_pGraph.m_iMaxCheckForRefineGraph, p_query.GetResultNum()); + break; + } + } + p_query.SortResult(); + return count; +} - SearchIndex(*((COMMON::QueryResultSet*)&p_query), *workSpace, p_searchDeleted, false); +namespace StaticDispatch +{ +template bool 
AlwaysTrue(Args...) +{ + return true; +} - m_workSpacePool->Return(workSpace); - return ErrorCode::Success; - } +bool CheckIfNotDeleted(const COMMON::Labelset &deletedIDs, SizeType node) +{ + return !deletedIDs.Contains(node); +} + +template bool CheckDup(COMMON::QueryResultSet &query, SizeType node, float score) +{ + return !query.AddPoint(node, score); +} + +template bool NeverDup(COMMON::QueryResultSet &query, SizeType node, float score) +{ + query.AddPoint(node, score); + return true; +} + +bool CheckFilter(const std::shared_ptr &metadata, SizeType node, + std::function filterFunc) +{ + return filterFunc(metadata->GetMetadata(node)); +} + +}; // namespace StaticDispatch + +template +void Index::SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, bool p_searchDeleted, + bool p_searchDuplicated, std::function filterFunc) const +{ + if (m_pQuantizer && !p_query.HasQuantizedTarget()) + { + p_query.SetTarget(p_query.GetTarget(), m_pQuantizer); + } + + // bitflags for which dispatch to take + uint8_t flags = 0; + flags += (m_deletedID.Count() == 0 || p_searchDeleted) << 2; + flags += p_searchDuplicated << 1; + flags += (filterFunc == nullptr); + + switch (flags) + { + case 0b000: + Search( + p_query, p_space, filterFunc); + break; + case 0b001: + Search( + p_query, p_space, filterFunc); + break; + case 0b010: + Search( + p_query, p_space, filterFunc); + break; + case 0b011: + Search( + p_query, p_space, filterFunc); + break; + case 0b100: + Search(p_query, p_space, + filterFunc); + break; + case 0b101: + Search(p_query, p_space, + filterFunc); + break; + case 0b110: + Search(p_query, p_space, + filterFunc); + break; + case 0b111: + Search(p_query, p_space, + filterFunc); + break; + default: + std::ostringstream oss; + oss << "Invalid flags in BKT SearchIndex dispatch: " << flags; + throw std::logic_error(oss.str()); + } + + p_query.SetScanned(p_space.m_iNumberOfCheckedLeaves); +} + +template +int Index::SearchIndexIterative(COMMON::QueryResultSet 
&p_query, COMMON::WorkSpace &p_space, bool p_isFirst, + int batch, bool p_searchDeleted, bool p_searchDuplicated) const +{ + int count = 0; + // bitflags for which dispatch to take + uint8_t flags = 0; + flags += (m_deletedID.Count() == 0 || p_searchDeleted) << 2; + flags += p_searchDuplicated << 1; + flags += (p_space.m_filterFunc == nullptr); + + switch (flags) + { + case 0b000: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + case 0b001: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + case 0b010: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + case 0b011: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + case 0b100: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + case 0b101: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + case 0b110: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + case 0b111: + count = SearchIterative( + p_query, p_space, p_isFirst, batch); + break; + default: + std::ostringstream oss; + oss << "Invalid flags in BKT SearchIndex dispatch: " << flags; + throw std::logic_error(oss.str()); + } + + p_query.SetScanned(p_space.m_iNumberOfCheckedLeaves); + return count; +} - template - ErrorCode Index::SearchTree(QueryResult& p_query) const +template +bool Index::SearchIndexIterativeFromNeareast(QueryResult &p_query, COMMON::WorkSpace *p_space, bool p_isFirst, + bool p_searchDeleted) const +{ + if (p_isFirst) + { + p_space->ResetResult(m_iMaxCheck, p_query.GetResultNum()); + SearchIndex(*((COMMON::QueryResultSet *)&p_query), *p_space, p_searchDeleted, true, p_space->m_filterFunc); + // make sure other node can be traversed after topk found + p_space->nodeCheckStatus.clear(); + for (int i = 0; i < p_query.GetResultNum(); ++i) + { + SizeType result = p_query.GetResult(i)->VID; + if (result < 0) + continue; + p_space->nodeCheckStatus.CheckAndSet(result); + } + 
+ for (int i = 0; i < p_query.GetResultNum(); ++i) { - auto workSpace = m_workSpacePool->Rent(); - workSpace->Reset(m_pGraph.m_iMaxCheckForRefineGraph, p_query.GetResultNum()); - - COMMON::QueryResultSet* p_results = (COMMON::QueryResultSet*)&p_query; - m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace); - m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace, m_iNumberOfInitialDynamicPivots); - BasicResult * res = p_query.GetResults(); - for (int i = 0; i < p_query.GetResultNum(); i++) + SizeType result = p_query.GetResult(i)->VID; + if (result < 0) + continue; + const DimensionType checkPos = m_pGraph.m_iNeighborhoodSize - 1; + const SizeType *node = m_pGraph[result]; + _mm_prefetch((const char *)node, _MM_HINT_T0); + for (DimensionType i = 0; i <= checkPos; i++) { - auto& cell = workSpace->m_NGQueue.pop(); - res[i].VID = cell.node; - res[i].Dist = cell.distance; + auto futureNode = node[i]; + if (futureNode < 0) + break; + _mm_prefetch((const char *)(m_pSamples)[futureNode], _MM_HINT_T0); + } + for (DimensionType i = 0; i <= checkPos; i++) + { + SizeType nn_index = node[i]; + if (nn_index < 0) + break; + if (p_space->CheckAndSet(nn_index)) + continue; + float distance2leaf = m_fComputeDistance((const T *)p_query.GetQuantizedTarget(), + (m_pSamples)[nn_index], GetFeatureDim()); + p_space->m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); } - m_workSpacePool->Return(workSpace); - return ErrorCode::Success; } -#pragma endregion - - template - ErrorCode Index::BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized, bool p_shareOwnership) + } + else + { + //p_space->ResetResult(m_iMaxCheck, p_query.GetResultNum()); + SearchIndexIterative(*((COMMON::QueryResultSet *)&p_query), *p_space, p_isFirst, p_query.GetResultNum(), + p_searchDeleted, true); + } + if (p_query.GetResult(0) == nullptr || p_query.GetResult(0)->VID < 0) + { + return false; + } + if 
(p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < p_query.GetResultNum(); ++i) { - if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) return ErrorCode::EmptyData; + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); + } + } + return true; +} - omp_set_num_threads(m_iNumberOfThreads); +template ErrorCode Index::SearchIndex(QueryResult &p_query, bool p_searchDeleted) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; - m_pSamples.Initialize(p_vectorNum, p_dimension, m_iDataBlockSize, m_iDataCapacity, (T*)p_data, p_shareOwnership); - m_deletedID.Initialize(p_vectorNum, m_iDataBlockSize, m_iDataCapacity); + auto workSpace = RentWorkSpace(p_query.GetResultNum(), nullptr, m_iMaxCheck); - if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) - { - int base = COMMON::Utils::GetBase(); -#pragma omp parallel for - for (SizeType i = 0; i < GetNumSamples(); i++) { - COMMON::Utils::Normalize(m_pSamples[i], GetFeatureDim(), base); - } - } + SearchIndex(*((COMMON::QueryResultSet *)&p_query), *workSpace, p_searchDeleted, true); + + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < p_query.GetResultNum(); ++i) + { + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? 
ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); + } + } + return ErrorCode::Success; +} - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - m_threadPool.init(); +template +ErrorCode Index::SearchIndexWithFilter(QueryResult &p_query, std::function filterFunc, + int maxCheck, bool p_searchDeleted) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; - auto t1 = std::chrono::high_resolution_clock::now(); - m_pTrees.BuildTrees(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads); - auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Build Tree time (s): %lld\n", std::chrono::duration_cast(t2 - t1).count()); - - m_pGraph.BuildGraph(this, &(m_pTrees.GetSampleMap())); + auto workSpace = RentWorkSpace(p_query.GetResultNum(), filterFunc, maxCheck); - auto t3 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Build Graph time (s): %lld\n", std::chrono::duration_cast(t3 - t2).count()); + SearchIndex(*((COMMON::QueryResultSet *)&p_query), *workSpace, p_searchDeleted, true, filterFunc); - m_bReady = true; - return ErrorCode::Success; + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < p_query.GetResultNum(); ++i) + { + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? 
ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); } + } + return ErrorCode::Success; +} + +template +std::shared_ptr Index::GetIterator(const void *p_target, bool p_searchDeleted, std::function p_filterFunc, int p_maxCheck) const +{ + if (!m_bReady) + return nullptr; + + std::shared_ptr resultIterator = + std::make_shared((const void *)this, p_target, p_searchDeleted, 1, p_filterFunc, p_maxCheck); + return resultIterator; +} + +template +ErrorCode Index::SearchIndexIterativeNext(QueryResult &p_query, COMMON::WorkSpace *workSpace, int p_batch, + int &resultCount, bool p_isFirst, bool p_searchDeleted) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; + if (p_isFirst) workSpace->ResetResult(m_iMaxCheck, p_batch); + resultCount = SearchIndexIterative(*((COMMON::QueryResultSet *)&p_query), *workSpace, p_isFirst, p_batch, + p_searchDeleted, true); - template - ErrorCode Index::RefineIndex(std::shared_ptr& p_newIndex) + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < resultCount; ++i) { - p_newIndex.reset(new Index()); - Index* ptr = (Index*)p_newIndex.get(); + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? 
ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); + } + } + return ErrorCode::Success; +} -#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - ptr->VarName = VarName; \ +template ErrorCode Index::SearchIndexIterativeEnd(std::unique_ptr space) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; + if (nullptr != space) + m_workSpaceFactory->ReturnWorkSpace(std::move(space)); + return ErrorCode::Success; +} -#include "inc/Core/BKT/ParameterDefinitionList.h" -#undef DefineBKTParameter +template std::unique_ptr Index::RentWorkSpace(int batch, std::function p_filterFunc, int p_maxCheck) const +{ + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new COMMON::WorkSpace()); + workSpace->Initialize(max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); + } + workSpace->Reset(p_maxCheck > 0? p_maxCheck: m_iMaxCheck, batch); + workSpace->m_filterFunc = p_filterFunc; + return std::move(workSpace); +} - std::lock_guard lock(m_dataAddLock); - std::unique_lock uniquelock(m_dataDeleteLock); +template ErrorCode Index::RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted) const +{ + auto workSpace = RentWorkSpace(p_query.GetResultNum(), nullptr, m_pGraph.m_iMaxCheckForRefineGraph); + SearchIndex(*((COMMON::QueryResultSet *)&p_query), *workSpace, p_searchDeleted, false); + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + return ErrorCode::Success; +} - SizeType newR = GetNumSamples(); +template ErrorCode Index::SearchTree(QueryResult &p_query) const +{ + auto workSpace = RentWorkSpace(p_query.GetResultNum(), nullptr, m_pGraph.m_iMaxCheckForRefineGraph); - std::vector indices; - std::vector reverseIndices(newR); - for (SizeType i = 0; i < newR; i++) { - if (!m_deletedID.Contains(i)) { - indices.push_back(i); - reverseIndices[i] = i; - } - else { - while (m_deletedID.Contains(newR - 1) && newR > i) newR--; - if (newR == i) break; - indices.push_back(newR - 1); - 
reverseIndices[newR - 1] = i; - newR--; - } - } + COMMON::QueryResultSet *p_results = (COMMON::QueryResultSet *)&p_query; + m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace); + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace, m_iNumberOfInitialDynamicPivots); + BasicResult *res = p_query.GetResults(); + for (int i = 0; i < p_query.GetResultNum(); i++) + { + auto &cell = workSpace->m_NGQueue.pop(); + res[i].VID = cell.node; + res[i].Dist = cell.distance; + } + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + + return ErrorCode::Success; +} +#pragma endregion - LOG(Helper::LogLevel::LL_Info, "Refine... from %d -> %d\n", GetNumSamples(), newR); - if (newR == 0) return ErrorCode::EmptyIndex; +template +ErrorCode Index::BuildIndex(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized, + bool p_shareOwnership) +{ + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; - ptr->m_workSpacePool.reset(new COMMON::WorkSpacePool()); - ptr->m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - ptr->m_threadPool.init(); + m_pGraph.m_iThreadNum = m_iNumberOfThreads; - ErrorCode ret = ErrorCode::Success; - if ((ret = m_pSamples.Refine(indices, ptr->m_pSamples)) != ErrorCode::Success) return ret; - if (nullptr != m_pMetadata && (ret = m_pMetadata->RefineMetadata(indices, ptr->m_pMetadata, m_iDataBlockSize, m_iDataCapacity, m_iMetaRecordSize)) != ErrorCode::Success) return ret; + m_pSamples.Initialize(p_vectorNum, p_dimension, m_iDataBlockSize, m_iDataCapacity, (T *)p_data, p_shareOwnership); + m_deletedID.Initialize(p_vectorNum, m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); - ptr->m_deletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity); - COMMON::BKTree* newtree = &(ptr->m_pTrees); - (*newtree).BuildTrees(ptr->m_pSamples, 
ptr->m_iDistCalcMethod, omp_get_num_threads()); - m_pGraph.RefineGraph(this, indices, reverseIndices, nullptr, &(ptr->m_pGraph), &(ptr->m_pTrees.GetSampleMap())); - if (HasMetaMapping()) ptr->BuildMetaMapping(false); - ptr->m_bReady = true; - return ret; - } + if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) + { + int base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); + COMMON::Utils::BatchNormalize(m_pSamples[0], p_vectorNum, p_dimension, base, + m_iNumberOfThreads); + } - template - ErrorCode Index::RefineIndex(const std::vector>& p_indexStreams, IAbortOperation* p_abort) - { - std::lock_guard lock(m_dataAddLock); - std::unique_lock uniquelock(m_dataDeleteLock); + m_threadPool.init(); - SizeType newR = GetNumSamples(); + auto t1 = std::chrono::high_resolution_clock::now(); + m_pTrees.BuildTrees(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads); + auto t2 = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build Tree time (s): %lld\n", + std::chrono::duration_cast(t2 - t1).count()); - std::vector indices; - std::vector reverseIndices(newR); - for (SizeType i = 0; i < newR; i++) { - if (!m_deletedID.Contains(i)) { - indices.push_back(i); - reverseIndices[i] = i; - } - else { - while (m_deletedID.Contains(newR - 1) && newR > i) newR--; - if (newR == i) break; - indices.push_back(newR - 1); - reverseIndices[newR - 1] = i; - newR--; - } - } + m_pGraph.BuildGraph(this, &(m_pTrees.GetSampleMap())); + + auto t3 = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build Graph time (s): %lld\n", + std::chrono::duration_cast(t3 - t2).count()); - LOG(Helper::LogLevel::LL_Info, "Refine... 
from %d -> %d\n", GetNumSamples(), newR); - if (newR == 0) return ErrorCode::EmptyIndex; + m_bReady = true; + return ErrorCode::Success; +} - ErrorCode ret = ErrorCode::Success; - if ((ret = m_pSamples.Refine(indices, p_indexStreams[0])) != ErrorCode::Success) return ret; +template ErrorCode Index::RefineIndex(std::shared_ptr &p_newIndex) +{ + p_newIndex.reset(new Index()); + Index *ptr = (Index *)p_newIndex.get(); - if (p_abort != nullptr && p_abort->ShouldAbort()) return ErrorCode::ExternalAbort; +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) ptr->VarName = VarName; - COMMON::BKTree newTrees(m_pTrees); - newTrees.BuildTrees(m_pSamples, m_iDistCalcMethod, omp_get_num_threads(), &indices, &reverseIndices); - if ((ret = newTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) return ret; +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter - if (p_abort != nullptr && p_abort->ShouldAbort()) return ErrorCode::ExternalAbort; + std::lock_guard lock(m_dataAddLock); + std::unique_lock uniquelock(m_dataDeleteLock); - if ((ret = m_pGraph.RefineGraph(this, indices, reverseIndices, p_indexStreams[2], nullptr, &(newTrees.GetSampleMap()))) != ErrorCode::Success) return ret; + SizeType newR = GetNumSamples(); - COMMON::Labelset newDeletedID; - newDeletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity); - if ((ret = newDeletedID.Save(p_indexStreams[3])) != ErrorCode::Success) return ret; - if (nullptr != m_pMetadata) { - if (p_indexStreams.size() < 6) return ErrorCode::LackOfInputs; - if ((ret = m_pMetadata->RefineMetadata(indices, p_indexStreams[4], p_indexStreams[5])) != ErrorCode::Success) return ret; - } - return ret; + std::vector indices; + std::vector reverseIndices(newR); + for (SizeType i = 0; i < newR; i++) + { + if (!m_deletedID.Contains(i)) + { + indices.push_back(i); + reverseIndices[i] = i; } - - template - ErrorCode Index::DeleteIndex(const void* p_vectors, SizeType p_vectorNum) { - const T* ptr_v = 
(const T*)p_vectors; -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < p_vectorNum; i++) { - COMMON::QueryResultSet query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF); - SearchIndex(query); - - for (int i = 0; i < m_pGraph.m_iCEF; i++) { - if (query.GetResult(i)->Dist < 1e-6) { - DeleteIndex(query.GetResult(i)->VID); - } - } - } - return ErrorCode::Success; + else + { + while (m_deletedID.Contains(newR - 1) && newR > i) + newR--; + if (newR == i) + break; + indices.push_back(newR - 1); + reverseIndices[newR - 1] = i; + newR--; } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine... from %d -> %d\n", GetNumSamples(), newR); + if (newR == 0) + return ErrorCode::EmptyIndex; + + ptr->m_threadPool.init(); + + ErrorCode ret = ErrorCode::Success; + if ((ret = m_pSamples.Refine(indices, ptr->m_pSamples)) != ErrorCode::Success) + return ret; + if (nullptr != m_pMetadata && + (ret = m_pMetadata->RefineMetadata(indices, ptr->m_pMetadata, m_iDataBlockSize, m_iDataCapacity, + m_iMetaRecordSize)) != ErrorCode::Success) + return ret; + + ptr->m_deletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + COMMON::BKTree *newtree = &(ptr->m_pTrees); + (*newtree).BuildTrees(ptr->m_pSamples, ptr->m_iDistCalcMethod, m_iNumberOfThreads); + m_pGraph.RefineGraph(this, indices, reverseIndices, nullptr, &(ptr->m_pGraph), &(ptr->m_pTrees.GetSampleMap())); + if (HasMetaMapping()) + ptr->BuildMetaMapping(false); + ptr->m_bReady = true; + return ret; +} - template - ErrorCode Index::DeleteIndex(const SizeType& p_id) { - if (!m_bReady) return ErrorCode::EmptyIndex; +template +ErrorCode Index::RefineIndex(const std::vector> &p_indexStreams, + IAbortOperation *p_abort, std::vector *p_mapping) +{ + std::lock_guard lock(m_dataAddLock); + std::unique_lock uniquelock(m_dataDeleteLock); - std::shared_lock sharedlock(m_dataDeleteLock); - if (m_deletedID.Insert(p_id)) return ErrorCode::Success; - return 
ErrorCode::VectorNotFound; - } + SizeType newR = GetNumSamples(); - template - ErrorCode Index::AddIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, std::shared_ptr p_metadataSet, bool p_withMetaIndex, bool p_normalized) + std::vector indices; + std::vector reverseIndices; + if (p_mapping == nullptr) + { + p_mapping = &reverseIndices; + } + p_mapping->resize(newR); + for (SizeType i = 0; i < newR; i++) + { + if (!m_deletedID.Contains(i)) { - if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) return ErrorCode::EmptyData; + indices.push_back(i); + (*p_mapping)[i] = i; + } + else + { + while (m_deletedID.Contains(newR - 1) && newR > i) + newR--; + if (newR == i) + break; + indices.push_back(newR - 1); + (*p_mapping)[newR - 1] = i; + newR--; + } + } - SizeType begin, end; - ErrorCode ret; - { - std::lock_guard lock(m_dataAddLock); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine... from %d -> %d\n", GetNumSamples(), newR); + if (newR == 0) + return ErrorCode::EmptyIndex; - begin = GetNumSamples(); - end = begin + p_vectorNum; + ErrorCode ret = ErrorCode::Success; + if ((ret = m_pSamples.Refine(indices, p_indexStreams[0])) != ErrorCode::Success) + return ret; - if (begin == 0) { - if (p_metadataSet != nullptr) { - m_pMetadata.reset(new MemMetadataSet(m_iDataBlockSize, m_iDataCapacity, m_iMetaRecordSize)); - m_pMetadata->AddBatch(*p_metadataSet); - if (p_withMetaIndex) BuildMetaMapping(false); - } - if ((ret = BuildIndex(p_data, p_vectorNum, p_dimension, p_normalized)) != ErrorCode::Success) return ret; - return ErrorCode::Success; - } + if (p_abort != nullptr && p_abort->ShouldAbort()) + return ErrorCode::ExternalAbort; - if (p_dimension != GetFeatureDim()) return ErrorCode::DimensionSizeMismatch; + COMMON::BKTree newTrees(m_pTrees); + newTrees.BuildTrees(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads, &indices, p_mapping); + if ((ret = newTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) + { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to Save Refine Tree!!\n"); + return ret; + } - if (m_pSamples.AddBatch((const T*)p_data, p_vectorNum) != ErrorCode::Success || - m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success || - m_deletedID.AddBatch(p_vectorNum) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Memory Error: Cannot alloc space for vectors!\n"); - m_pSamples.SetR(begin); - m_pGraph.SetR(begin); - m_deletedID.SetR(begin); - return ErrorCode::MemoryOverFlow; - } + if (p_abort != nullptr && p_abort->ShouldAbort()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Abort!!\n"); + return ErrorCode::ExternalAbort; + } + + + if ((ret = m_pGraph.RefineGraph(this, indices, (*p_mapping), p_indexStreams[2], nullptr, + &(newTrees.GetSampleMap()))) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to Save Refine Graph!!\n"); + return ret; + } + + + COMMON::Labelset newDeletedID; + newDeletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + if ((ret = newDeletedID.Save(p_indexStreams[3])) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to Save Refine DeletedID!!\n"); + return ret; + } + + if (nullptr != m_pMetadata) + { + if (p_indexStreams.size() < 6) + return ErrorCode::LackOfInputs; + if ((ret = m_pMetadata->RefineMetadata(indices, p_indexStreams[4], p_indexStreams[5])) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to Save Refine Metadata!!\n"); + return ret; + } + } + return ret; +} - if (m_pMetadata != nullptr) { - if (p_metadataSet != nullptr) { - m_pMetadata->AddBatch(*p_metadataSet); - if (HasMetaMapping()) { - for (SizeType i = begin; i < end; i++) { - ByteArray meta = m_pMetadata->GetMetadata(i); - std::string metastr((char*)meta.Data(), meta.Length()); - UpdateMetaMapping(metastr, i); - } +template ErrorCode Index::DeleteIndex(const void *p_vectors, SizeType p_vectorNum) +{ + const T *ptr_v = 
(const T *)p_vectors; + std::vector mythreads; + mythreads.reserve(m_iNumberOfThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iNumberOfThreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < p_vectorNum) + { + COMMON::QueryResultSet query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF); + SearchIndex(query); + + for (int j = 0; j < m_pGraph.m_iCEF; j++) + { + if (query.GetResult(j)->Dist < 1e-6) + { + DeleteIndex(query.GetResult(j)->VID); } } - else { - for (SizeType i = begin; i < end; i++) m_pMetadata->Add(ByteArray::c_empty); - } } - } - - if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) - { - int base = COMMON::Utils::GetBase(); - for (SizeType i = begin; i < end; i++) { - COMMON::Utils::Normalize((T*)m_pSamples[i], GetFeatureDim(), base); + else + { + return; } } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + return ErrorCode::Success; +} - if (end - m_pTrees.sizePerTree() >= m_addCountForRebuild && m_threadPool.jobsize() == 0) { - m_threadPool.add(new RebuildJob(&m_pSamples, &m_pTrees, &m_pGraph, m_iDistCalcMethod)); - } +template ErrorCode Index::DeleteIndex(const SizeType &p_id) +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; + + std::shared_lock sharedlock(m_dataDeleteLock); + if (m_deletedID.Insert(p_id)) + return ErrorCode::Success; + return ErrorCode::VectorNotFound; +} - for (SizeType node = begin; node < end; node++) +template +ErrorCode Index::AddIndex(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, + std::shared_ptr p_metadataSet, bool p_withMetaIndex, bool p_normalized) +{ + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; + + SizeType begin, end; + ErrorCode ret; + { + std::lock_guard lock(m_dataAddLock); + + begin = GetNumSamples(); + end = begin + p_vectorNum; + + if (begin == 0) + { + if (p_metadataSet != nullptr) { - 
m_pGraph.RefineNode(this, node, true, true, m_pGraph.m_iAddCEF); + m_pMetadata.reset(new MemMetadataSet(m_iDataBlockSize, m_iDataCapacity, m_iMetaRecordSize)); + m_pMetadata->AddBatch(*p_metadataSet); + if (p_withMetaIndex) + BuildMetaMapping(false); } + if ((ret = BuildIndex(p_data, p_vectorNum, p_dimension, p_normalized)) != ErrorCode::Success) + return ret; return ErrorCode::Success; } - template - ErrorCode - Index::UpdateIndex() + if (p_dimension != GetFeatureDim()) + return ErrorCode::DimensionSizeMismatch; + + if (m_pSamples.AddBatch(p_vectorNum, (const T *)p_data) != ErrorCode::Success || + m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success || + m_deletedID.AddBatch(p_vectorNum) != ErrorCode::Success) { - omp_set_num_threads(m_iNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - return ErrorCode::Success; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Memory Error: Cannot alloc space for vectors!\n"); + m_pSamples.SetR(begin); + m_pGraph.SetR(begin); + m_deletedID.SetR(begin); + return ErrorCode::MemoryOverFlow; } - template - ErrorCode - Index::SetParameter(const char* p_param, const char* p_value, const char* p_section) + if (m_pMetadata != nullptr) { - if (nullptr == p_param || nullptr == p_value) return ErrorCode::Fail; - -#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ - { \ - LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ - VarType tmp; \ - if (SPTAG::Helper::Convert::ConvertStringTo(p_value, tmp)) \ - { \ - VarName = tmp; \ - } \ - } \ + if (p_metadataSet != nullptr) + { + m_pMetadata->AddBatch(*p_metadataSet); + if (HasMetaMapping()) + { + for (SizeType i = begin; i < end; i++) + { + ByteArray meta = m_pMetadata->GetMetadata(i); + std::string metastr((char 
*)meta.Data(), meta.Length()); + UpdateMetaMapping(metastr, i); + } + } + } + else + { + for (SizeType i = begin; i < end; i++) + m_pMetadata->Add(ByteArray::c_empty); + } + } + } -#include "inc/Core/BKT/ParameterDefinitionList.h" -#undef DefineBKTParameter + if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) + { + int base = COMMON::Utils::GetBase(); + for (SizeType i = begin; i < end; i++) + { + COMMON::Utils::Normalize((T *)m_pSamples[i], GetFeatureDim(), base); + } + } - if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "DistCalcMethod")) { - m_fComputeDistance = m_pQuantizer ? m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod) : COMMON::DistanceCalcSelector(m_iDistCalcMethod); - auto base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); - m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? base * base : 1; - } - return ErrorCode::Success; + if (end - m_pTrees.sizePerTree() >= m_addCountForRebuild && m_threadPool.jobsize() == 0) + { + m_threadPool.add(new RebuildJob(&m_pSamples, &m_pTrees, &m_pGraph, m_iDistCalcMethod)); + } + + for (SizeType node = begin; node < end; node++) + { + m_pGraph.RefineNode(this, node, true, true, m_pGraph.m_iAddCEF); + } + return ErrorCode::Success; +} + +template +ErrorCode Index::AddIndexId(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, int &beginHead, + int &endHead) +{ + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; + + SizeType begin, end; + { + std::lock_guard lock(m_dataAddLock); + + begin = GetNumSamples(); + end = begin + p_vectorNum; + + if (begin == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Index Error: No vector in Index!\n"); + return ErrorCode::EmptyIndex; } + if (p_dimension != GetFeatureDim()) + return ErrorCode::DimensionSizeMismatch; - template - std::string - Index::GetParameter(const char* p_param, const char* p_section) const + if (m_pSamples.AddBatch(p_vectorNum, (const T 
*)p_data) != ErrorCode::Success || + m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success || + m_deletedID.AddBatch(p_vectorNum) != ErrorCode::Success) { - if (nullptr == p_param) return std::string(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Memory Error: Cannot alloc space for vectors!\n"); + m_pSamples.SetR(begin); + m_pGraph.SetR(begin); + m_deletedID.SetR(begin); + return ErrorCode::MemoryOverFlow; + } + } + beginHead = begin; + endHead = end; + return ErrorCode::Success; +} + +template ErrorCode Index::AddIndexIdx(SizeType begin, SizeType end) +{ + if (end - m_pTrees.sizePerTree() >= m_addCountForRebuild && m_threadPool.jobsize() == 0) { + m_threadPool.add(new RebuildJob(&m_pSamples, &m_pTrees, &m_pGraph, m_iDistCalcMethod)); + } + + for (SizeType node = begin; node < end; node++) + { + m_pGraph.RefineNode(this, node, true, true, m_pGraph.m_iAddCEF); + } + return ErrorCode::Success; +} + +template ErrorCode Index::UpdateIndex() +{ + m_pGraph.m_iThreadNum = m_iNumberOfThreads; + return ErrorCode::Success; +} + +template ErrorCode Index::Check() +{ + std::vector mythreads; + mythreads.reserve(m_iNumberOfThreads); + std::atomic_size_t sent(0); + ErrorCode ret = ErrorCode::Success; + for (int tid = 0; tid < m_iNumberOfThreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < m_pSamples.R()) + { + if (!m_deletedID.Contains(i)) + { + COMMON::QueryResultSet result(m_pSamples[i], 1); + if (SearchIndex(result) != ErrorCode::Success || result.GetResult(0)->VID != i) + { + ret = ErrorCode::Fail; + return; + } + } + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + return ret; +} -#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ - { \ - return SPTAG::Helper::Convert::ConvertToString(VarName); \ - } \ +template ErrorCode 
Index::SetParameter(const char *p_param, const char *p_value, const char *p_section) +{ + if (nullptr == p_param || nullptr == p_value) + return ErrorCode::Fail; + +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ + VarType tmp; \ + if (SPTAG::Helper::Convert::ConvertStringTo(p_value, tmp)) \ + { \ + VarName = tmp; \ + } \ + } #include "inc/Core/BKT/ParameterDefinitionList.h" #undef DefineBKTParameter - return std::string(); - } + if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "DistCalcMethod")) + { + m_fComputeDistance = m_pQuantizer ? m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod) + : COMMON::DistanceCalcSelector(m_iDistCalcMethod); + auto base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); + m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? 
base * base : 1; } + return ErrorCode::Success; } -#define DefineVectorValueType(Name, Type) \ -template class SPTAG::BKT::Index; \ +template std::string Index::GetParameter(const char *p_param, const char *p_section) const +{ + if (nullptr == p_param) + return std::string(); -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + return SPTAG::Helper::Convert::ConvertToString(VarName); \ + } +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + + return std::string(); +} +} // namespace BKT +} // namespace SPTAG +#define DefineVectorValueType(Name, Type) template class SPTAG::BKT::Index; + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType diff --git a/AnnService/src/Core/Common/CommonUtils.cpp b/AnnService/src/Core/Common/CommonUtils.cpp index 0ba5d262d..80451696b 100644 --- a/AnnService/src/Core/Common/CommonUtils.cpp +++ b/AnnService/src/Core/Common/CommonUtils.cpp @@ -4,21 +4,41 @@ using namespace SPTAG; using namespace SPTAG::COMMON; - #define DefineVectorValueType(Name, Type) template int Utils::GetBase(); #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType -template -void Utils::BatchNormalize(T* data, SizeType row, DimensionType col, int base, int threads) +template void Utils::BatchNormalize(T *data, SizeType row, DimensionType col, int base, int threads) { -#pragma omp parallel for num_threads(threads) - for (SizeType i = 0; i < row; i++) - { - SPTAG::COMMON::Utils::Normalize(data + i * (size_t)col, col, base); - } + std::vector mythreads; + mythreads.reserve(threads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < threads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < row) + { + SPTAG::COMMON::Utils::Normalize(data + i * (size_t)col, col, 
base); + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); } -#define DefineVectorValueType(Name, Type) template void Utils::BatchNormalize(Type* data, SizeType row, DimensionType col, int base, int threads); +#define DefineVectorValueType(Name, Type) \ + template void Utils::BatchNormalize(Type * data, SizeType row, DimensionType col, int base, int threads); #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType \ No newline at end of file diff --git a/AnnService/src/Core/Common/DistanceUtils.cpp b/AnnService/src/Core/Common/DistanceUtils.cpp index e98f38121..8ed1bc9ef 100644 --- a/AnnService/src/Core/Common/DistanceUtils.cpp +++ b/AnnService/src/Core/Common/DistanceUtils.cpp @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "inc/Core/Common/DistanceUtils.h" -#include using namespace SPTAG; using namespace SPTAG::COMMON; @@ -199,8 +198,8 @@ inline __m512 _mm512_mul_epi8(__m512i X, __m512i Y) { __m512i zero = _mm512_setzero_si512(); - __mmask64 sign_x_mask = _mm512_cmpgt_epi8_mask (zero, X); - __mmask64 sign_y_mask = _mm512_cmpgt_epi8_mask (zero, Y); + __mmask64 sign_x_mask = _mm512_cmpgt_epi8_mask(zero, X); + __mmask64 sign_y_mask = _mm512_cmpgt_epi8_mask(zero, Y); __m512i sign_x = _mm512_movm_epi8(sign_x_mask); __m512i sign_y = _mm512_movm_epi8(sign_y_mask); @@ -217,8 +216,8 @@ inline __m512 _mm512_sqdf_epi8(__m512i X, __m512i Y) { __m512i zero = _mm512_setzero_si512(); - __mmask64 sign_x_mask = _mm512_cmpgt_epi8_mask (zero, X); - __mmask64 sign_y_mask = _mm512_cmpgt_epi8_mask (zero, Y); + __mmask64 sign_x_mask = _mm512_cmpgt_epi8_mask(zero, X); + __mmask64 sign_y_mask = _mm512_cmpgt_epi8_mask(zero, Y); __m512i sign_x = _mm512_movm_epi8(sign_x_mask); __m512i sign_y = _mm512_movm_epi8(sign_y_mask); @@ -270,8 +269,8 @@ inline __m512 _mm512_sqdf_epi16(__m512i X, __m512i Y) { __m512i zero = _mm512_setzero_si512(); - __mmask32 sign_x_mask = _mm512_cmpgt_epi16_mask (zero, X); - __mmask32 
sign_y_mask = _mm512_cmpgt_epi16_mask (zero, Y); + __mmask32 sign_x_mask = _mm512_cmpgt_epi16_mask(zero, X); + __mmask32 sign_y_mask = _mm512_cmpgt_epi16_mask(zero, Y); __m512i sign_x = _mm512_movm_epi16(sign_x_mask); __m512i sign_y = _mm512_movm_epi16(sign_y_mask); @@ -294,84 +293,103 @@ inline __m512 _mm512_sqdf_ps(__m512 X, __m512 Y) } #endif +#define REPEAT(type, ctype, delta, load, exec, acc, result) \ + { \ + type c1 = load((ctype *)(pX)); \ + type c2 = load((ctype *)(pY)); \ + pX += delta; \ + pY += delta; \ + result = acc(result, exec(c1, c2)); \ + } -#define REPEAT(type, ctype, delta, load, exec, acc, result) \ - { \ - type c1 = load((ctype *)(pX)); \ - type c2 = load((ctype *)(pY)); \ - pX += delta; pY += delta; \ - result = acc(result, exec(c1, c2)); \ - } \ - -float DistanceUtils::ComputeL2Distance_SSE(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_SSE(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int8_t *pEnd1 = pX + length; __m128 diff128 = _mm_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) } - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float 
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + } + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int8_t *pEnd1 = pX + length; __m256 diff256 = _mm256_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epi8, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = 
((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + } + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX512(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX512(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int8_t *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const std::int8_t* pEnd64 = pX + ((length >> 6) << 6); + const std::int8_t *pEnd64 = pX + ((length >> 6) << 6); __m512 diff512 = _mm512_setzero_ps(); - while (pX < pEnd64) { + while (pX < pEnd64) + { REPEAT(__m512i, __m512i, 64, _mm512_loadu_si512, _mm512_sqdf_epi8, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -379,96 +397,124 @@ float DistanceUtils::ComputeL2Distance_AVX512(const std::int8_t* pX, const std:: __m256 diff256 = _mm256_setzero_ps(); #endif - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epi8, _mm256_add_ps, diff256) } __m128 diff128 = 
_mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + } + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_SSE(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_SSE(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t *pEnd1 = pX + length; __m128 diff128 = _mm_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, 
_mm_add_ps, diff128) } - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + } + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t *pEnd1 = pX + length; __m256 diff256 = _mm256_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epu8, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, 
_mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + } + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX512(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX512(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const std::uint8_t* pEnd64 = pX + ((length >> 6) << 6); + const std::uint8_t *pEnd64 = pX + ((length >> 6) << 6); __m512 diff512 = _mm512_setzero_ps(); - while (pX < pEnd64) { + while (pX < pEnd64) + { REPEAT(__m512i, __m512i, 64, _mm512_loadu_si512, _mm512_sqdf_epu8, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -476,98 +522,126 @@ 
float DistanceUtils::ComputeL2Distance_AVX512(const std::uint8_t* pX, const std: __m256 diff256 = _mm256_setzero_ps(); #endif - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epu8, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + } + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_SSE(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_SSE(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t *pEnd1 = pX + length; __m128 diff128 = 
_mm_setzero_ps(); - while (pX < pEnd16) { + while (pX < pEnd16) + { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) } - while (pX < pEnd8) { + while (pX < pEnd8) + { REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t *pEnd1 = pX + length; __m256 diff256 = _mm256_setzero_ps(); - while (pX < pEnd16) { + while (pX < pEnd16) 
+ { REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_sqdf_epi16, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd8) { + while (pX < pEnd8) + { REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX512(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX512(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const std::int16_t* pEnd32 = pX + ((length >> 5) << 5); + const std::int16_t *pEnd32 = pX + ((length >> 5) << 5); __m512 diff512 = _mm512_setzero_ps(); - 
while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m512i, __m512i, 32, _mm512_loadu_si512, _mm512_sqdf_epi16, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -575,41 +649,50 @@ float DistanceUtils::ComputeL2Distance_AVX512(const std::int16_t* pX, const std: __m256 diff256 = _mm256_setzero_ps(); #endif - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_sqdf_epi16, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd8) { + while (pX < pEnd8) + { REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd4) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; - c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd4) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; + c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_SSE(const float* pX, const float* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_SSE(const float *pX, const float *pY, DimensionType length) { - const float* pEnd16 = pX + ((length >> 4) << 4); - const float* pEnd4 = pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd16 = pX + ((length >> 4) << 4); + const float *pEnd4 = pX + ((length >> 2) 
<< 2); + const float *pEnd1 = pX + length; __m128 diff128 = _mm_setzero_ps(); while (pX < pEnd16) { REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) } while (pX < pEnd4) { @@ -617,23 +700,25 @@ float DistanceUtils::ComputeL2Distance_SSE(const float* pX, const float* pY, Dim } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd1) { - float c1 = (*pX++) - (*pY++); diff += c1 * c1; + while (pX < pEnd1) + { + float c1 = (*pX++) - (*pY++); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX(const float* pX, const float* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX(const float *pX, const float *pY, DimensionType length) { - const float* pEnd16 = pX + ((length >> 4) << 4); - const float* pEnd4 = pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd16 = pX + ((length >> 4) << 4); + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + length; __m256 diff256 = _mm256_setzero_ps(); while (pX < pEnd16) { REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256) - REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256) + REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); while (pX < pEnd4) @@ -642,22 +727,25 @@ float 
DistanceUtils::ComputeL2Distance_AVX(const float* pX, const float* pY, Dim } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd1) { - float c1 = (*pX++) - (*pY++); diff += c1 * c1; + while (pX < pEnd1) + { + float c1 = (*pX++) - (*pY++); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeL2Distance_AVX512(const float* pX, const float* pY, DimensionType length) +float DistanceUtils::ComputeL2Distance_AVX512(const float *pX, const float *pY, DimensionType length) { - const float* pEnd8 = pX + ((length >> 3) << 3); - const float* pEnd4 = pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd8 = pX + ((length >> 3) << 3); + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const float* pEnd16 = pX + ((length >> 4) << 4); + const float *pEnd16 = pX + ((length >> 4) << 4); __m512 diff512 = _mm512_setzero_ps(); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m512, const float, 16, _mm512_loadu_ps, _mm512_sqdf_ps, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -676,79 +764,96 @@ float DistanceUtils::ComputeL2Distance_AVX512(const float* pX, const float* pY, } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd1) { - float c1 = (*pX++) - (*pY++); diff += c1 * c1; + while (pX < pEnd1) + { + float c1 = (*pX++) - (*pY++); + diff += c1 * c1; } return diff; } -float DistanceUtils::ComputeCosineDistance_SSE(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_SSE(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int8_t* pEnd1 = pX + length; + 
const std::int8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int8_t *pEnd1 = pX + length; __m128 diff128 = _mm_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) } - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); return 16129 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd4 = pX + ((length >> 2) << 2); + const 
std::int8_t *pEnd1 = pX + length; __m256 diff256 = _mm256_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epi8, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); return 16129 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX512(const std::int8_t* pX, const std::int8_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX512(const std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int8_t *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const std::int8_t* pEnd64 = pX + ((length >> 6) << 6); + const std::int8_t *pEnd64 = pX + ((length >> 6) << 6); __m512 diff512 = 
_mm512_setzero_ps(); - while (pX < pEnd64) { + while (pX < pEnd64) + { REPEAT(__m512i, __m512i, 64, _mm512_loadu_si512, _mm512_mul_epi8, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -756,93 +861,115 @@ float DistanceUtils::ComputeCosineDistance_AVX512(const std::int8_t* pX, const s __m256 diff256 = _mm256_setzero_ps(); #endif - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epi8, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); return 16129 - diff; } -float DistanceUtils::ComputeCosineDistance_SSE(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_SSE(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd32 = pX + ((length >> 5) << 5); + 
const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t *pEnd1 = pX + length; __m128 diff128 = _mm_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { + REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) } - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); return 65025 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t *pEnd1 = pX + length; __m256 
diff256 = _mm256_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epu8, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); return 65025 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX512(const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX512(const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5); - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd32 = pX + ((length >> 5) << 5); + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd4 = pX + ((length >> 2) << 2); + const std::uint8_t *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const std::uint8_t* pEnd64 = pX + ((length >> 6) << 6); + const std::uint8_t *pEnd64 = pX + ((length >> 6) << 6); __m512 diff512 = _mm512_setzero_ps(); - while (pX < pEnd64) 
{ + while (pX < pEnd64) + { REPEAT(__m512i, __m512i, 64, _mm512_loadu_si512, _mm512_mul_epu8, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -850,95 +977,117 @@ float DistanceUtils::ComputeCosineDistance_AVX512(const std::uint8_t* pX, const __m256 diff256 = _mm256_setzero_ps(); #endif - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epu8, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); return 65025 - diff; } -float DistanceUtils::ComputeCosineDistance_SSE(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_SSE(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t *pEnd8 = pX + ((length >> 3) 
<< 3); + const std::int16_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t *pEnd1 = pX + length; __m128 diff128 = _mm_setzero_ps(); - while (pX < pEnd16) { + while (pX < pEnd16) + { + REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) } - while (pX < pEnd8) { + while (pX < pEnd8) + { REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); - return 1073676289 - diff; + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); + return 1073676289 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t *pEnd1 = pX + length; __m256 diff256 = 
_mm256_setzero_ps(); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_mul_epi16, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd8) { + while (pX < pEnd8) + { REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); - return 1073676289 - diff; + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); + return 1073676289 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX512(const std::int16_t* pX, const std::int16_t* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX512(const std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd4 = pX + ((length >> 2) << 2); + const std::int16_t *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const std::int16_t* pEnd32 = pX + ((length >> 5) << 5); + const std::int16_t *pEnd32 = pX + ((length >> 5) << 5); __m512 diff512 = 
_mm512_setzero_ps(); - while (pX < pEnd32) { + while (pX < pEnd32) + { REPEAT(__m512i, __m512i, 32, _mm512_loadu_si512, _mm512_mul_epi16, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -946,40 +1095,47 @@ float DistanceUtils::ComputeCosineDistance_AVX512(const std::int16_t* pX, const __m256 diff256 = _mm256_setzero_ps(); #endif - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_mul_epi16, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); - while (pX < pEnd8) { + while (pX < pEnd8) + { REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; while (pX < pEnd4) { - float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1; - } - - while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++)); - return 1073676289 - diff; + float c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + c1 = ((float)(*pX++) * (float)(*pY++)); + diff += c1; + } + + while (pX < pEnd1) + diff += ((float)(*pX++) * (float)(*pY++)); + return 1073676289 - diff; } -float DistanceUtils::ComputeCosineDistance_SSE(const float* pX, const float* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_SSE(const float *pX, const float *pY, DimensionType length) { - const float* pEnd16 = pX + ((length >> 4) << 4); - const float* pEnd4 = pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd16 = pX + ((length >> 4) << 4); + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + 
length; __m128 diff128 = _mm_setzero_ps(); while (pX < pEnd16) { REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) } while (pX < pEnd4) { @@ -987,21 +1143,22 @@ float DistanceUtils::ComputeCosineDistance_SSE(const float* pX, const float* pY, } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd1) diff += (*pX++) * (*pY++); + while (pX < pEnd1) + diff += (*pX++) * (*pY++); return 1 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX(const float* pX, const float* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX(const float *pX, const float *pY, DimensionType length) { - const float* pEnd16 = pX + ((length >> 4) << 4); - const float* pEnd4 = pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd16 = pX + ((length >> 4) << 4); + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + length; __m256 diff256 = _mm256_setzero_ps(); while (pX < pEnd16) { REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256) - REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256) + REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256) } __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); while (pX < pEnd4) @@ -1010,20 +1167,22 @@ float DistanceUtils::ComputeCosineDistance_AVX(const float* pX, const float* pY, } float diff = DIFF128[0] + 
DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd1) diff += (*pX++) * (*pY++); + while (pX < pEnd1) + diff += (*pX++) * (*pY++); return 1 - diff; } -float DistanceUtils::ComputeCosineDistance_AVX512(const float* pX, const float* pY, DimensionType length) +float DistanceUtils::ComputeCosineDistance_AVX512(const float *pX, const float *pY, DimensionType length) { - const float* pEnd8 = pX + ((length >> 3) << 3); - const float* pEnd4 = pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd8 = pX + ((length >> 3) << 3); + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + length; #if (!defined _MSC_VER) || (_MSC_VER >= 1920) - const float* pEnd16 = pX + ((length >> 4) << 4); + const float *pEnd16 = pX + ((length >> 4) << 4); __m512 diff512 = _mm512_setzero_ps(); - while (pX < pEnd16) { + while (pX < pEnd16) + { REPEAT(__m512, const float, 16, _mm512_loadu_ps, _mm512_mul_ps, _mm512_add_ps, diff512) } __m256 diff256 = _mm256_add_ps(_mm512_castps512_ps256(diff512), _mm512_extractf32x8_ps(diff512, 1)); @@ -1042,6 +1201,7 @@ float DistanceUtils::ComputeCosineDistance_AVX512(const float* pX, const float* } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; - while (pX < pEnd1) diff += (*pX++) * (*pY++); + while (pX < pEnd1) + diff += (*pX++) * (*pY++); return 1 - diff; } diff --git a/AnnService/src/Core/Common/IQuantizer.cpp b/AnnService/src/Core/Common/IQuantizer.cpp index 6fe2ea641..f76955a40 100644 --- a/AnnService/src/Core/Common/IQuantizer.cpp +++ b/AnnService/src/Core/Common/IQuantizer.cpp @@ -1,127 +1,155 @@ #include -#include #include +#include #include namespace SPTAG { - namespace COMMON +namespace COMMON +{ +std::shared_ptr IQuantizer::LoadIQuantizer(std::shared_ptr p_in) +{ + QuantizerType quantizerType = QuantizerType::Undefined; + VectorValueType reconstructType = VectorValueType::Undefined; + std::shared_ptr ret = nullptr; + if (p_in->ReadBinary(sizeof(QuantizerType), (char 
*)&quantizerType) != sizeof(QuantizerType)) + return ret; + if (p_in->ReadBinary(sizeof(VectorValueType), (char *)&reconstructType) != sizeof(VectorValueType)) + return ret; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading quantizer of type %s with reconstructtype %s.\n", + Helper::Convert::ConvertToString(quantizerType).c_str(), + Helper::Convert::ConvertToString(reconstructType).c_str()); + switch (quantizerType) { - std::shared_ptr IQuantizer::LoadIQuantizer(std::shared_ptr p_in) { - QuantizerType quantizerType = QuantizerType::Undefined; - VectorValueType reconstructType = VectorValueType::Undefined; - std::shared_ptr ret = nullptr; - if (p_in->ReadBinary(sizeof(QuantizerType), (char*)&quantizerType) != sizeof(QuantizerType)) return ret; - if (p_in->ReadBinary(sizeof(VectorValueType), (char*)&reconstructType) != sizeof(VectorValueType)) return ret; - LOG(Helper::LogLevel::LL_Info, "Loading quantizer of type %s with reconstructtype %s.\n", Helper::Convert::ConvertToString(quantizerType).c_str(), Helper::Convert::ConvertToString(reconstructType).c_str()); - switch (quantizerType) { - case QuantizerType::None: - break; - case QuantizerType::Undefined: - break; - case QuantizerType::PQQuantizer: - printf("Resetting Quantizer to type PQQuantizer!\n"); - switch (reconstructType) { - #define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new PQQuantizer()); \ - break; + case QuantizerType::None: + break; + case QuantizerType::Undefined: + break; + case QuantizerType::PQQuantizer: + printf("Resetting Quantizer to type PQQuantizer!\n"); + switch (reconstructType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + ret.reset(new PQQuantizer()); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - default: break; - } - - if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset(); - return ret; - case QuantizerType::OPQQuantizer: - switch (reconstructType) { -#define 
DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new OPQQuantizer()); \ - break; + default: + break; + } + + if (ret->LoadQuantizer(p_in) != ErrorCode::Success) + ret.reset(); + return ret; + case QuantizerType::OPQQuantizer: + switch (reconstructType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + ret.reset(new OPQQuantizer()); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - default: break; - } - if (ret->LoadQuantizer(p_in) != ErrorCode::Success) ret.reset(); - return ret; - } - return ret; + default: + break; } + if (ret->LoadQuantizer(p_in) != ErrorCode::Success) + ret.reset(); + return ret; + } + return ret; +} - std::shared_ptr IQuantizer::LoadIQuantizer(SPTAG::ByteArray bytes) - { - auto raw_bytes = bytes.Data(); - QuantizerType quantizerType = *(QuantizerType*) raw_bytes; - raw_bytes += sizeof(QuantizerType); +std::shared_ptr IQuantizer::LoadIQuantizer(SPTAG::ByteArray bytes) +{ + auto raw_bytes = bytes.Data(); + QuantizerType quantizerType = *(QuantizerType *)raw_bytes; + raw_bytes += sizeof(QuantizerType); - VectorValueType reconstructType = *(VectorValueType*)raw_bytes; - raw_bytes += sizeof(VectorValueType); - LOG(Helper::LogLevel::LL_Info, "Loading quantizer of type %s with reconstructtype %s.\n", Helper::Convert::ConvertToString(quantizerType).c_str(), Helper::Convert::ConvertToString(reconstructType).c_str()); - std::shared_ptr ret = nullptr; + VectorValueType reconstructType = *(VectorValueType *)raw_bytes; + raw_bytes += sizeof(VectorValueType); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading quantizer of type %s with reconstructtype %s.\n", + Helper::Convert::ConvertToString(quantizerType).c_str(), + Helper::Convert::ConvertToString(reconstructType).c_str()); + std::shared_ptr ret = nullptr; - switch (quantizerType) { - case QuantizerType::None: - return ret; - case QuantizerType::Undefined: - return ret; - case QuantizerType::PQQuantizer: - switch 
(reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new PQQuantizer()); \ - break; + switch (quantizerType) + { + case QuantizerType::None: + return ret; + case QuantizerType::Undefined: + return ret; + case QuantizerType::PQQuantizer: + switch (reconstructType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + ret.reset(new PQQuantizer()); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - default: break; - } + default: + break; + } - if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset(); - return ret; - case QuantizerType::OPQQuantizer: - switch (reconstructType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - ret.reset(new OPQQuantizer()); \ - break; + if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) + ret.reset(); + return ret; + case QuantizerType::OPQQuantizer: + switch (reconstructType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + ret.reset(new OPQQuantizer()); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - default: break; - } - - if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) ret.reset(); - return ret; - } - return ret; + default: + break; } - template <> - std::function IQuantizer::DistanceCalcSelector(SPTAG::DistCalcMethod p_method) const - { - if (p_method == SPTAG::DistCalcMethod::L2) - { - return ([this](const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) {return L2Distance(pX, pY); }); - } - else - { - return ([this](const std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) {return CosineDistance(pX, pY); }); - } - } + if (ret->LoadQuantizer(raw_bytes) != ErrorCode::Success) + ret.reset(); + return ret; + } + return ret; +} - template - std::function IQuantizer::DistanceCalcSelector(SPTAG::DistCalcMethod p_method) const - { - return 
SPTAG::COMMON::DistanceCalcSelector(p_method); - } +template <> +std::function IQuantizer::DistanceCalcSelector< + std::uint8_t>(SPTAG::DistCalcMethod p_method) const +{ + if (p_method == SPTAG::DistCalcMethod::L2) + { + return ([this](const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { + return L2Distance(pX, pY); + }); + } + else + { + return ([this](const std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { + return CosineDistance(pX, pY); + }); + } +} + +template +std::function IQuantizer::DistanceCalcSelector( + SPTAG::DistCalcMethod p_method) const +{ + return SPTAG::COMMON::DistanceCalcSelector(p_method); +} -#define DefineVectorValueType(Name, Type) \ -template std::function IQuantizer::DistanceCalcSelector(SPTAG::DistCalcMethod p_method) const; +#define DefineVectorValueType(Name, Type) \ + template std::function IQuantizer::DistanceCalcSelector( \ + SPTAG::DistCalcMethod p_method) const; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - } -} +} // namespace COMMON +} // namespace SPTAG diff --git a/AnnService/src/Core/Common/InstructionUtils.cpp b/AnnService/src/Core/Common/InstructionUtils.cpp index e1ef2cf94..38b849b1e 100644 --- a/AnnService/src/Core/Common/InstructionUtils.cpp +++ b/AnnService/src/Core/Common/InstructionUtils.cpp @@ -2,80 +2,96 @@ #include "inc/Core/Common.h" #ifndef _MSC_VER -void cpuid(int info[4], int InfoType) { +void cpuid(int info[4], int InfoType) +{ __cpuid_count(InfoType, 0, info[0], info[1], info[2], info[3]); } #endif -namespace SPTAG { - namespace COMMON { - const InstructionSet::InstructionSet_Internal InstructionSet::CPU_Rep; +namespace SPTAG +{ +namespace COMMON +{ +const InstructionSet::InstructionSet_Internal InstructionSet::CPU_Rep; - bool InstructionSet::SSE(void) { return CPU_Rep.HW_SSE; } - bool InstructionSet::SSE2(void) { return CPU_Rep.HW_SSE2; } - bool InstructionSet::AVX(void) { return CPU_Rep.HW_AVX; } - bool InstructionSet::AVX2(void) { return 
CPU_Rep.HW_AVX2; } - bool InstructionSet::AVX512(void) { return CPU_Rep.HW_AVX512; } - - void InstructionSet::PrintInstructionSet(void) - { - if (CPU_Rep.HW_AVX512) - LOG(Helper::LogLevel::LL_Info, "Using AVX512 InstructionSet!\n"); - else if (CPU_Rep.HW_AVX2) - LOG(Helper::LogLevel::LL_Info, "Using AVX2 InstructionSet!\n"); - else if (CPU_Rep.HW_AVX) - LOG(Helper::LogLevel::LL_Info, "Using AVX InstructionSet!\n"); - else if (CPU_Rep.HW_SSE2) - LOG(Helper::LogLevel::LL_Info, "Using SSE2 InstructionSet!\n"); - else if (CPU_Rep.HW_SSE) - LOG(Helper::LogLevel::LL_Info, "Using SSE InstructionSet!\n"); - else - LOG(Helper::LogLevel::LL_Info, "Using NONE InstructionSet!\n"); - } +bool InstructionSet::SSE(void) +{ + return CPU_Rep.HW_SSE; +} +bool InstructionSet::SSE2(void) +{ + return CPU_Rep.HW_SSE2; +} +bool InstructionSet::AVX(void) +{ + return CPU_Rep.HW_AVX; +} +bool InstructionSet::AVX2(void) +{ + return CPU_Rep.HW_AVX2; +} +bool InstructionSet::AVX512(void) +{ + return CPU_Rep.HW_AVX512; +} + +void InstructionSet::PrintInstructionSet(void) +{ + if (CPU_Rep.HW_AVX512) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using AVX512 InstructionSet!\n"); + else if (CPU_Rep.HW_AVX2) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using AVX2 InstructionSet!\n"); + else if (CPU_Rep.HW_AVX) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using AVX InstructionSet!\n"); + else if (CPU_Rep.HW_SSE2) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using SSE2 InstructionSet!\n"); + else if (CPU_Rep.HW_SSE) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using SSE InstructionSet!\n"); + else + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using NONE InstructionSet!\n"); +} - // from https://stackoverflow.com/a/7495023/5053214 - InstructionSet::InstructionSet_Internal::InstructionSet_Internal() : - HW_SSE{ false }, - HW_SSE2{ false }, - HW_AVX{ false }, - HW_AVX512{ false }, - HW_AVX2{ false } - { - int info[4]; - cpuid(info, 0); - int nIds = info[0]; +// from https://stackoverflow.com/a/7495023/5053214 
+InstructionSet::InstructionSet_Internal::InstructionSet_Internal() + : HW_SSE{false}, HW_SSE2{false}, HW_AVX{false}, HW_AVX512{false}, HW_AVX2{false} +{ + int info[4]; + cpuid(info, 0); + int nIds = info[0]; - // Detect Features - if (nIds >= 0x00000001) { - cpuid(info, 0x00000001); - HW_SSE = (info[3] & ((int)1 << 25)) != 0; - HW_SSE2 = (info[3] & ((int)1 << 26)) != 0; - HW_AVX = (info[2] & ((int)1 << 28)) != 0; - } - if (nIds >= 0x00000007) { - cpuid(info, 0x00000007); - HW_AVX2 = (info[1] & ((int)1 << 5)) != 0; - HW_AVX512 = (info[1] & (((int)1 << 16) | ((int) 1 << 30))); + // Detect Features + if (nIds >= 0x00000001) + { + cpuid(info, 0x00000001); + HW_SSE = (info[3] & ((int)1 << 25)) != 0; + HW_SSE2 = (info[3] & ((int)1 << 26)) != 0; + HW_AVX = (info[2] & ((int)1 << 28)) != 0; + } + if (nIds >= 0x00000007) + { + cpuid(info, 0x00000007); + HW_AVX2 = (info[1] & ((int)1 << 5)) != 0; + HW_AVX512 = (info[1] & (((int)1 << 16) | ((int)1 << 30))); // If we are not compiling support for AVX-512 due to old compiler version, we should not call it #ifdef _MSC_VER #if _MSC_VER < 1920 - HW_AVX512 = false; + HW_AVX512 = false; #endif #endif - } - if (HW_AVX512) - LOG(Helper::LogLevel::LL_Info, "Using AVX512 InstructionSet!\n"); - else if (HW_AVX2) - LOG(Helper::LogLevel::LL_Info, "Using AVX2 InstructionSet!\n"); - else if (HW_AVX) - LOG(Helper::LogLevel::LL_Info, "Using AVX InstructionSet!\n"); - else if (HW_SSE2) - LOG(Helper::LogLevel::LL_Info, "Using SSE2 InstructionSet!\n"); - else if (HW_SSE) - LOG(Helper::LogLevel::LL_Info, "Using SSE InstructionSet!\n"); - else - LOG(Helper::LogLevel::LL_Info, "Using NONE InstructionSet!\n"); - } } + if (HW_AVX512) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using AVX512 InstructionSet!\n"); + else if (HW_AVX2) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using AVX2 InstructionSet!\n"); + else if (HW_AVX) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using AVX InstructionSet!\n"); + else if (HW_SSE2) + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using SSE2 InstructionSet!\n"); + else if (HW_SSE) + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using SSE InstructionSet!\n"); + else + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using NONE InstructionSet!\n"); } +} // namespace COMMON +} // namespace SPTAG diff --git a/AnnService/src/Core/Common/Kernel.cu b/AnnService/src/Core/Common/Kernel.cu index d2ebee0b8..c31e730d8 100644 --- a/AnnService/src/Core/Common/Kernel.cu +++ b/AnnService/src/Core/Common/Kernel.cu @@ -28,9 +28,12 @@ #include #include #include +#include +//#include "inc/Core/Common/TruthSet.h" #include "inc/Core/Common/cuda/params.h" #include "inc/Core/Common/cuda/TPtree.hxx" +//#include "inc/Core/Common/cuda/KNN.hxx" /***************************************************************************************** * Convert sums to means for each split key diff --git a/AnnService/src/Core/Common/NeighborhoodGraph.cpp b/AnnService/src/Core/Common/NeighborhoodGraph.cpp index 73eb457ad..b57ef9726 100644 --- a/AnnService/src/Core/Common/NeighborhoodGraph.cpp +++ b/AnnService/src/Core/Common/NeighborhoodGraph.cpp @@ -14,7 +14,7 @@ std::shared_ptr NeighborhoodGraph::CreateInstance(std::string { res.reset(new RelativeNeighborhoodGraph); } - else if (type == "NNG") + else if (type == "NNG") { res.reset(new KNearestNeighborhoodGraph); } diff --git a/AnnService/src/Core/Common/SIMDUtils.cpp b/AnnService/src/Core/Common/SIMDUtils.cpp index 428552bb9..ba59fc0e9 100644 --- a/AnnService/src/Core/Common/SIMDUtils.cpp +++ b/AnnService/src/Core/Common/SIMDUtils.cpp @@ -2,188 +2,206 @@ // Licensed under the MIT License. 
#include "inc/Core/Common/SIMDUtils.h" -#include using namespace SPTAG; using namespace SPTAG::COMMON; -void SIMDUtils::ComputeSum_SSE(std::int8_t* pX, const std::int8_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_SSE(std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd1 = pX + length; - while (pX < pEnd16) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd16) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi8(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 16; pY += 16; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_AVX(std::int8_t* pX, const std::int8_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX(std::int8_t *pX, const std::int8_t *pY, DimensionType length) { - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd1 = pX + length; - while (pX < pEnd16) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd16) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi8(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 16; pY += 16; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_AVX512(std::int8_t* pX, const std::int8_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX512(std::int8_t *pX, const std::int8_t *pY, 
DimensionType length) { - const std::int8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::int8_t* pEnd1 = pX + length; + const std::int8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::int8_t *pEnd1 = pX + length; - while (pX < pEnd16) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd16) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi8(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 16; pY += 16; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_SSE(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_SSE(std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd1 = pX + length; - while (pX < pEnd16) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd16) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi8(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 16; pY += 16; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_AVX(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX(std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd1 = pX + length; - while (pX < pEnd16) { - __m128i x_part = 
_mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd16) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi8(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 16; pY += 16; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_AVX512(std::uint8_t* pX, const std::uint8_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX512(std::uint8_t *pX, const std::uint8_t *pY, DimensionType length) { - const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4); - const std::uint8_t* pEnd1 = pX + length; + const std::uint8_t *pEnd16 = pX + ((length >> 4) << 4); + const std::uint8_t *pEnd1 = pX + length; - while (pX < pEnd16) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd16) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi8(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 16; pY += 16; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_SSE(std::int16_t* pX, const std::int16_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_SSE(std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd1 = pX + length; - while (pX < pEnd8) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd8) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi16(x_part, y_part); - 
_mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 8; pY += 8; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_AVX(std::int16_t* pX, const std::int16_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX(std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd1 = pX + length; - while (pX < pEnd8) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd8) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi16(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 8; pY += 8; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_AVX512(std::int16_t* pX, const std::int16_t* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX512(std::int16_t *pX, const std::int16_t *pY, DimensionType length) { - const std::int16_t* pEnd8 = pX + ((length >> 3) << 3); - const std::int16_t* pEnd1 = pX + length; + const std::int16_t *pEnd8 = pX + ((length >> 3) << 3); + const std::int16_t *pEnd1 = pX + length; - while (pX < pEnd8) { - __m128i x_part = _mm_loadu_si128((__m128i*) pX); - __m128i y_part = _mm_loadu_si128((__m128i*) pY); + while (pX < pEnd8) + { + __m128i x_part = _mm_loadu_si128((__m128i *)pX); + __m128i y_part = _mm_loadu_si128((__m128i *)pY); x_part = _mm_add_epi16(x_part, y_part); - _mm_storeu_si128((__m128i*) pX, x_part); + _mm_storeu_si128((__m128i *)pX, x_part); pX += 8; pY += 8; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_SSE(float* pX, const float* pY, DimensionType length) +void 
SIMDUtils::ComputeSum_SSE(float *pX, const float *pY, DimensionType length) { - const float* pEnd4= pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + length; - while (pX < pEnd4) { + while (pX < pEnd4) + { __m128 x_part = _mm_loadu_ps(pX); __m128 y_part = _mm_loadu_ps(pY); x_part = _mm_add_ps(x_part, y_part); @@ -192,17 +210,19 @@ void SIMDUtils::ComputeSum_SSE(float* pX, const float* pY, DimensionType length) pY += 4; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; - } + } } -void SIMDUtils::ComputeSum_AVX(float* pX, const float* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX(float *pX, const float *pY, DimensionType length) { - const float* pEnd4= pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + length; - while (pX < pEnd4) { + while (pX < pEnd4) + { __m128 x_part = _mm_loadu_ps(pX); __m128 y_part = _mm_loadu_ps(pY); x_part = _mm_add_ps(x_part, y_part); @@ -211,17 +231,19 @@ void SIMDUtils::ComputeSum_AVX(float* pX, const float* pY, DimensionType length) pY += 4; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } -void SIMDUtils::ComputeSum_AVX512(float* pX, const float* pY, DimensionType length) +void SIMDUtils::ComputeSum_AVX512(float *pX, const float *pY, DimensionType length) { - const float* pEnd4= pX + ((length >> 2) << 2); - const float* pEnd1 = pX + length; + const float *pEnd4 = pX + ((length >> 2) << 2); + const float *pEnd1 = pX + length; - while (pX < pEnd4) { + while (pX < pEnd4) + { __m128 x_part = _mm_loadu_ps(pX); __m128 y_part = _mm_loadu_ps(pY); x_part = _mm_add_ps(x_part, y_part); @@ -230,7 +252,8 @@ void SIMDUtils::ComputeSum_AVX512(float* pX, const float* pY, DimensionType leng pY += 4; } - while (pX < pEnd1) { + while (pX < pEnd1) + { *pX++ += *pY++; } } diff --git a/AnnService/src/Core/Common/TruthSet.cpp 
b/AnnService/src/Core/Common/TruthSet.cpp new file mode 100644 index 000000000..2d8e88c1d --- /dev/null +++ b/AnnService/src/Core/Common/TruthSet.cpp @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/Common/TruthSet.h" +#include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/VectorIndex.h" + +#if defined(GPU) +#include +#include +#include +#include +#include + +#include "inc/Core/Common/cuda/KNN.hxx" +#include "inc/Core/Common/cuda/params.h" +#endif + +namespace SPTAG +{ +namespace COMMON +{ +#if defined(GPU) +template +void TruthSet::GenerateTruth(std::shared_ptr querySet, std::shared_ptr vectorSet, + const std::string truthFile, const SPTAG::DistCalcMethod distMethod, const int K, + const SPTAG::TruthFileType p_truthFileType, const std::shared_ptr &quantizer) +{ + if (querySet->Dimension() != vectorSet->Dimension() && !quantizer) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "query and vector have different dimensions."); + exit(1); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin to generate truth for query(%d,%d) and doc(%d,%d)...\n", + querySet->Count(), querySet->Dimension(), vectorSet->Count(), vectorSet->Dimension()); + std::vector> truthset(querySet->Count(), std::vector(K, 0)); + std::vector> distset(querySet->Count(), std::vector(K, 0)); + + GenerateTruthGPU(querySet, vectorSet, truthFile, distMethod, K, p_truthFileType, quantizer, truthset, distset); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to write truth file...\n"); + writeTruthFile(truthFile, querySet->Count(), K, truthset, distset, p_truthFileType); + + auto ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize((truthFile + ".dist.bin").c_str(), std::ios::out | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to create the file:%s\n", (truthFile + ".dist.bin").c_str()); + exit(1); + } + + int int32_queryNumber = (int)querySet->Count(); + 
ptr->WriteBinary(4, (char *)&int32_queryNumber); + ptr->WriteBinary(4, (char *)&K); + + for (size_t i = 0; i < int32_queryNumber; i++) + { + for (int k = 0; k < K; k++) + { + if (ptr->WriteBinary(4, (char *)(&(truthset[i][k]))) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth dist file!\n"); + exit(1); + } + if (ptr->WriteBinary(4, (char *)(&(distset[i][k]))) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth dist file!\n"); + exit(1); + } + } + } +} +#else +template +void TruthSet::GenerateTruth(std::shared_ptr querySet, std::shared_ptr vectorSet, + const std::string truthFile, const SPTAG::DistCalcMethod distMethod, const int K, + const SPTAG::TruthFileType p_truthFileType, const std::shared_ptr &quantizer) +{ + if (querySet->Dimension() != vectorSet->Dimension() && !quantizer) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "query and vector have different dimensions."); + exit(1); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin to generate truth for query(%d,%d) and doc(%d,%d)...\n", + querySet->Count(), querySet->Dimension(), vectorSet->Count(), vectorSet->Dimension()); + std::vector> truthset(querySet->Count(), std::vector(K, 0)); + std::vector> distset(querySet->Count(), std::vector(K, 0)); + auto fComputeDistance = + quantizer ? 
quantizer->DistanceCalcSelector(distMethod) : COMMON::DistanceCalcSelector(distMethod); + + std::vector mythreads; + int maxthreads = std::thread::hardware_concurrency(); + mythreads.reserve(maxthreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < maxthreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < querySet->Count()) + { + SPTAG::COMMON::QueryResultSet query((const T *)(querySet->GetVector(i)), K); + query.SetTarget((const T *)(querySet->GetVector(i)), quantizer); + for (SPTAG::SizeType j = 0; j < vectorSet->Count(); j++) + { + float dist = + fComputeDistance(query.GetQuantizedTarget(), reinterpret_cast(vectorSet->GetVector(j)), + vectorSet->Dimension()); + query.AddPoint(j, dist); + } + query.SortResult(); + + for (int k = 0; k < K; k++) + { + truthset[i][k] = (query.GetResult(k))->VID; + distset[i][k] = (query.GetResult(k))->Dist; + } + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to write truth file...\n"); + writeTruthFile(truthFile, querySet->Count(), K, truthset, distset, p_truthFileType); + + auto ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize((truthFile + ".dist.bin").c_str(), std::ios::out | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to create the file:%s\n", (truthFile + ".dist.bin").c_str()); + exit(1); + } + + int int32_queryNumber = (int)querySet->Count(); + ptr->WriteBinary(4, (char *)&int32_queryNumber); + ptr->WriteBinary(4, (char *)&K); + + for (size_t i = 0; i < int32_queryNumber; i++) + { + for (int k = 0; k < K; k++) + { + if (ptr->WriteBinary(4, (char *)(&(truthset[i][k]))) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth dist file!\n"); + exit(1); + } + if (ptr->WriteBinary(4, (char *)(&(distset[i][k]))) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail 
to write the truth dist file!\n"); + exit(1); + } + } + } +} + +#endif // (GPU) + +#define DefineVectorValueType(Name, Type) \ + template void TruthSet::GenerateTruth( \ + std::shared_ptr querySet, std::shared_ptr vectorSet, const std::string truthFile, \ + const SPTAG::DistCalcMethod distMethod, const int K, const SPTAG::TruthFileType p_truthFileType, \ + const std::shared_ptr &quantizer); +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType +} // namespace COMMON +} // namespace SPTAG \ No newline at end of file diff --git a/AnnService/src/Core/KDT/KDTIndex.cpp b/AnnService/src/Core/KDT/KDTIndex.cpp index fef4b6373..f768a6054 100644 --- a/AnnService/src/Core/KDT/KDTIndex.cpp +++ b/AnnService/src/Core/KDT/KDTIndex.cpp @@ -2,554 +2,932 @@ // Licensed under the MIT License. #include "inc/Core/KDT/Index.h" +#include "inc/Core/ResultIterator.h" #include -#pragma warning(disable:4242) // '=' : conversion from 'int' to 'short', possible loss of data -#pragma warning(disable:4244) // '=' : conversion from 'int' to 'short', possible loss of data -#pragma warning(disable:4127) // conditional expression is constant +#pragma warning(disable : 4242) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable : 4244) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable : 4127) // conditional expression is constant namespace SPTAG { - namespace KDT +template thread_local std::unique_ptr COMMON::ThreadLocalWorkSpaceFactory::m_workspace; +namespace KDT +{ + +template ErrorCode Index::LoadConfig(Helper::IniReader &p_reader) +{ +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, p_reader.GetParameter("Index", RepresentStr, std::string(#DefaultValue)).c_str()); + +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + return ErrorCode::Success; +} + +template <> void Index::SetQuantizer(std::shared_ptr quantizer) +{ + 
m_pQuantizer = quantizer; + m_pTrees.m_pQuantizer = quantizer; + if (m_pQuantizer) { - template - ErrorCode Index::LoadConfig(Helper::IniReader& p_reader) - { -#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - SetParameter(RepresentStr, \ - p_reader.GetParameter("Index", \ - RepresentStr, \ - std::string(#DefaultValue)).c_str()); \ + m_fComputeDistance = m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod); + m_iBaseSquare = + (m_iDistCalcMethod == DistCalcMethod::Cosine) ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() : 1; + } + else + { + m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); + m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) + ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() + : 1; + } +} + +template void Index::SetQuantizer(std::shared_ptr quantizer) +{ + m_pQuantizer = quantizer; + m_pTrees.m_pQuantizer = quantizer; + if (quantizer) + { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, + "Set non-null quantizer for index with data type other than BYTE"); + } +} + +template ErrorCode Index::LoadIndexDataFromMemory(const std::vector &p_indexBlobs) +{ + if (p_indexBlobs.size() < 3) + return ErrorCode::LackOfInputs; + + if (m_pSamples.Load((char *)p_indexBlobs[0].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) + return ErrorCode::FailedParseValue; + if (m_pTrees.LoadTrees((char *)p_indexBlobs[1].Data()) != ErrorCode::Success) + return ErrorCode::FailedParseValue; + if (m_pGraph.LoadGraph((char *)p_indexBlobs[2].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) + return ErrorCode::FailedParseValue; + if (p_indexBlobs.size() <= 3) + m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + else if (m_deletedID.Load((char *)p_indexBlobs[3].Data(), m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains) != ErrorCode::Success) + return 
ErrorCode::FailedParseValue; + + if (m_pSamples.R() != m_pGraph.R() || m_pSamples.R() != m_deletedID.R()) + { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, + "Index data is corrupted, please rebuild the index. Samples: %i, Graph: %i, DeletedID: %i.", + m_pSamples.R(), m_pGraph.R(), m_deletedID.R()); + return ErrorCode::FailedParseValue; + } + + m_pGraph.m_iThreadNum = m_iNumberOfThreads; + m_threadPool.init(); + return ErrorCode::Success; +} + +template +ErrorCode Index::LoadIndexData(const std::vector> &p_indexStreams) +{ + if (p_indexStreams.size() < 4) + return ErrorCode::LackOfInputs; + + ErrorCode ret = ErrorCode::Success; + if (p_indexStreams[0] == nullptr || + (ret = m_pSamples.Load(p_indexStreams[0], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) + return ret; + if (p_indexStreams[1] == nullptr || (ret = m_pTrees.LoadTrees(p_indexStreams[1])) != ErrorCode::Success) + return ret; + if (p_indexStreams[2] == nullptr || + (ret = m_pGraph.LoadGraph(p_indexStreams[2], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) + return ret; + if (p_indexStreams[3] == nullptr) + m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + else if ((ret = m_deletedID.Load(p_indexStreams[3], m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains)) != ErrorCode::Success) + return ret; + + if (m_pSamples.R() != m_pGraph.R() || m_pSamples.R() != m_deletedID.R()) + { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, + "Index data is corrupted, please rebuild the index. 
Samples: %i, Graph: %i, DeletedID: %i.", + m_pSamples.R(), m_pGraph.R(), m_deletedID.R()); + return ErrorCode::FailedParseValue; + } + + m_pGraph.m_iThreadNum = m_iNumberOfThreads; + m_threadPool.init(); + return ret; +} + +template ErrorCode Index::SaveConfig(std::shared_ptr p_configOut) +{ + auto workspace = m_workSpaceFactory->GetWorkSpace(); + if (workspace) + { + m_iHashTableExp = workspace->HashTableExponent(); + } + m_workSpaceFactory->ReturnWorkSpace(std::move(workspace)); + +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + IOSTRING(p_configOut, WriteString, \ + (RepresentStr + std::string("=") + GetParameter(RepresentStr) + std::string("\n")).c_str()); #include "inc/Core/KDT/ParameterDefinitionList.h" #undef DefineKDTParameter - return ErrorCode::Success; + + IOSTRING(p_configOut, WriteString, "\n"); + return ErrorCode::Success; +} + +template +ErrorCode Index::SaveIndexData(const std::vector> &p_indexStreams) +{ + if (p_indexStreams.size() < 4) + return ErrorCode::LackOfInputs; + + std::lock_guard lock(m_dataAddLock); + std::unique_lock uniquelock(m_dataDeleteLock); + + ErrorCode ret = ErrorCode::Success; + if ((ret = m_pSamples.Save(p_indexStreams[0])) != ErrorCode::Success) + return ret; + if ((ret = m_pTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) + return ret; + if ((ret = m_pGraph.SaveGraph(p_indexStreams[2])) != ErrorCode::Success) + return ret; + if ((ret = m_deletedID.Save(p_indexStreams[3])) != ErrorCode::Success) + return ret; + return ret; +} + +#pragma region K - NN search +/* +#define Search(CheckDeleted) \ +std::shared_lock lock(*(m_pTrees.m_lock)); \ +m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space); \ +m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); \ +while (!p_space.m_NGQueue.empty()) { \ + NodeDistPair gnode = p_space.m_NGQueue.pop(); \ + const SizeType *node = m_pGraph[gnode.node]; \ + _mm_prefetch((const char 
*)node, _MM_HINT_T0); \ + for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) \ + _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ + CheckDeleted { \ + if (!p_query.AddPoint(gnode.node, gnode.distance) && p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { \ + p_query.SortResult(); return; \ + } \ + } \ + float upperBound = max(p_query.worstDist(), gnode.distance); \ + bool bLocalOpt = true; \ + for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) { \ + SizeType nn_index = node[i]; \ + if (nn_index < 0) break; \ + if (p_space.CheckAndSet(nn_index)) continue; \ + float distance2leaf = m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], GetFeatureDim()); +\ + if (distance2leaf <= upperBound) bLocalOpt = false; \ + p_space.m_iNumberOfCheckedLeaves++; \ + p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); \ + } \ + if (bLocalOpt) p_space.m_iNumOfContinuousNoBetterPropagation++; \ + else p_space.m_iNumOfContinuousNoBetterPropagation = 0; \ + if (p_space.m_iNumOfContinuousNoBetterPropagation > m_iThresholdOfNumberOfContinuousNoBetterPropagation) { \ + if (p_space.m_iNumberOfTreeCheckedLeaves <= p_space.m_iNumberOfCheckedLeaves / 10) { \ + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfOtherDynamicPivots + +p_space.m_iNumberOfCheckedLeaves); \ + } else if (gnode.distance > p_query.worstDist()) { \ + break; \ + } \ + } \ +} \ +p_query.SortResult(); \ +*/ +template +template +void Index::Search(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const +{ + std::shared_lock lock(*(m_pTrees.m_lock)); + m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space); + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); + while (!p_space.m_NGQueue.empty()) + { + NodeDistPair gnode = p_space.m_NGQueue.pop(); + const SizeType *node = m_pGraph[gnode.node]; + _mm_prefetch((const char *)node, 
_MM_HINT_T0); + for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) + { + auto futureNode = node[i]; + if (futureNode < 0) + break; + _mm_prefetch((const char *)(m_pSamples)[futureNode], _MM_HINT_T0); } - template <> - void Index::SetQuantizer(std::shared_ptr quantizer) + if (notDeleted(m_deletedID, gnode.node)) { - m_pQuantizer = quantizer; - m_pTrees.m_pQuantizer = quantizer; - if (m_pQuantizer) + if (!p_query.AddPoint(gnode.node, gnode.distance) && p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { - m_fComputeDistance = m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod); - m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() : 1; - } - else - { - m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); - m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() : 1; + p_query.SortResult(); + return; } } - template - void Index::SetQuantizer(std::shared_ptr quantizer) + float upperBound = max(p_query.worstDist(), gnode.distance); + bool bLocalOpt = true; + for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) + { + SizeType nn_index = node[i]; + if (nn_index < 0) + break; + + if (p_space.CheckAndSet(nn_index)) + continue; + float distance2leaf = + m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], GetFeatureDim()); + if (distance2leaf <= upperBound) + bLocalOpt = false; + p_space.m_iNumberOfCheckedLeaves++; + p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); + } + if (bLocalOpt) + p_space.m_iNumOfContinuousNoBetterPropagation++; + else + p_space.m_iNumOfContinuousNoBetterPropagation = 0; + if (p_space.m_iNumOfContinuousNoBetterPropagation > m_iThresholdOfNumberOfContinuousNoBetterPropagation) { - m_pQuantizer = quantizer; - m_pTrees.m_pQuantizer = quantizer; - if (quantizer) + if (p_space.m_iNumberOfTreeCheckedLeaves <= p_space.m_iNumberOfCheckedLeaves / 10) { - 
LOG(SPTAG::Helper::LogLevel::LL_Error, "Set non-null quantizer for index with data type other than BYTE"); + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, + m_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves); + } + else if (gnode.distance > p_query.worstDist()) + { + break; } } + } + p_query.SortResult(); +} - template - ErrorCode Index::LoadIndexDataFromMemory(const std::vector& p_indexBlobs) - { - if (p_indexBlobs.size() < 3) return ErrorCode::LackOfInputs; - - if (m_pSamples.Load((char*)p_indexBlobs[0].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) return ErrorCode::FailedParseValue; - if (m_pTrees.LoadTrees((char*)p_indexBlobs[1].Data()) != ErrorCode::Success) return ErrorCode::FailedParseValue; - if (m_pGraph.LoadGraph((char*)p_indexBlobs[2].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) return ErrorCode::FailedParseValue; - if (p_indexBlobs.size() <= 3) m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity); - else if (m_deletedID.Load((char*)p_indexBlobs[3].Data(), m_iDataBlockSize, m_iDataCapacity) != ErrorCode::Success) return ErrorCode::FailedParseValue; - - omp_set_num_threads(m_iNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - m_threadPool.init(); - return ErrorCode::Success; - } +namespace StaticDispatch +{ +bool AlwaysTrue(const COMMON::Labelset &deletedIDs, SizeType node) +{ + return true; +} - template - ErrorCode Index::LoadIndexData(const std::vector>& p_indexStreams) +bool CheckIfNotDeleted(const COMMON::Labelset &deletedIDs, SizeType node) +{ + return !deletedIDs.Contains(node); +} +}; // namespace StaticDispatch + +template +template +void Index::SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, bool p_searchDeleted) const +{ + if (m_deletedID.Count() == 0 || p_searchDeleted) + { + 
Search(p_query, p_space); + } + else + { + Search(p_query, p_space); + } + p_query.SetScanned(p_space.m_iNumberOfCheckedLeaves); +} + +template ErrorCode Index::SearchIndex(QueryResult &p_query, bool p_searchDeleted) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; + + auto workSpace = RentWorkSpace(p_query.GetResultNum(), nullptr, m_iMaxCheck); + + COMMON::QueryResultSet *p_results = (COMMON::QueryResultSet *)&p_query; + + if (m_pQuantizer) + { + if (!(p_results->HasQuantizedTarget())) { - if (p_indexStreams.size() < 4) return ErrorCode::LackOfInputs; - - ErrorCode ret = ErrorCode::Success; - if (p_indexStreams[0] == nullptr || (ret = m_pSamples.Load(p_indexStreams[0], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) return ret; - if (p_indexStreams[1] == nullptr || (ret = m_pTrees.LoadTrees(p_indexStreams[1])) != ErrorCode::Success) return ret; - if (p_indexStreams[2] == nullptr || (ret = m_pGraph.LoadGraph(p_indexStreams[2], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) return ret; - if (p_indexStreams[3] == nullptr) m_deletedID.Initialize(m_pSamples.R(), m_iDataBlockSize, m_iDataCapacity); - else if ((ret = m_deletedID.Load(p_indexStreams[3], m_iDataBlockSize, m_iDataCapacity)) != ErrorCode::Success) return ret; - - omp_set_num_threads(m_iNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - m_threadPool.init(); - return ret; + p_results->SetTarget(p_results->GetTarget(), m_pQuantizer); } - - template - ErrorCode Index::SaveConfig(std::shared_ptr p_configOut) + switch (m_pQuantizer->GetReconstructType()) { - auto workSpace = m_workSpacePool->Rent(); - m_iHashTableExp = workSpace->HashTableExponent(); - m_workSpacePool->Return(workSpace); +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + SearchIndex(*p_results, *workSpace, p_searchDeleted); \ + break; -#define 
DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - IOSTRING(p_configOut, WriteString, (RepresentStr + std::string("=") + GetParameter(RepresentStr) + std::string("\n")).c_str()); +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType -#include "inc/Core/KDT/ParameterDefinitionList.h" -#undef DefineKDTParameter - - IOSTRING(p_configOut, WriteString, "\n"); - return ErrorCode::Success; + default: + break; } + } + else + { + SearchIndex(*p_results, *workSpace, p_searchDeleted); + } + + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); - template - ErrorCode Index::SaveIndexData(const std::vector>& p_indexStreams) + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < p_query.GetResultNum(); ++i) { - if (p_indexStreams.size() < 4) return ErrorCode::LackOfInputs; + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); + } + } - std::lock_guard lock(m_dataAddLock); - std::unique_lock uniquelock(m_dataDeleteLock); + return ErrorCode::Success; +} - ErrorCode ret = ErrorCode::Success; - if ((ret = m_pSamples.Save(p_indexStreams[0])) != ErrorCode::Success) return ret; - if ((ret = m_pTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) return ret; - if ((ret = m_pGraph.SaveGraph(p_indexStreams[2])) != ErrorCode::Success) return ret; - if ((ret = m_deletedID.Save(p_indexStreams[3])) != ErrorCode::Success) return ret; - return ret; - } +template +std::shared_ptr Index::GetIterator(const void *p_target, bool p_searchDeleted, std::function p_filterFunc, int p_maxCheck) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ITERATIVE NOT SUPPORT FOR KDT"); + return nullptr; +} -#pragma region K-NN search +template +ErrorCode Index::SearchIndexIterativeNext(QueryResult &p_query, COMMON::WorkSpace *workSpace, int p_batch, + int &resultCount, bool p_isFirst, bool p_searchDeleted) const +{ + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ITERATIVE NOT SUPPORT FOR KDT"); + return ErrorCode::Fail; +} -#define Search(CheckDeleted) \ - if (m_pQuantizer && !p_query.HasQuantizedTarget()) \ - { p_query.SetTarget(p_query.GetTarget(), m_pQuantizer); } \ - std::shared_lock lock(*(m_pTrees.m_lock)); \ - m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space); \ - m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfInitialDynamicPivots); \ - while (!p_space.m_NGQueue.empty()) { \ - NodeDistPair gnode = p_space.m_NGQueue.pop(); \ - const SizeType *node = m_pGraph[gnode.node]; \ - _mm_prefetch((const char *)node, _MM_HINT_T0); \ - for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) \ - _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ - CheckDeleted { \ - if (!p_query.AddPoint(gnode.node, gnode.distance) && p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { \ - p_query.SortResult(); return; \ - } \ - } \ - float upperBound = max(p_query.worstDist(), gnode.distance); \ - bool bLocalOpt = true; \ - for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) { \ - SizeType nn_index = node[i]; \ - if (nn_index < 0) break; \ - if (p_space.CheckAndSet(nn_index)) continue; \ - float distance2leaf = m_fComputeDistance(p_query.GetQuantizedTarget(), (m_pSamples)[nn_index], GetFeatureDim()); \ - if (distance2leaf <= upperBound) bLocalOpt = false; \ - p_space.m_iNumberOfCheckedLeaves++; \ - p_space.m_NGQueue.insert(NodeDistPair(nn_index, distance2leaf)); \ - } \ - if (bLocalOpt) p_space.m_iNumOfContinuousNoBetterPropagation++; \ - else p_space.m_iNumOfContinuousNoBetterPropagation = 0; \ - if (p_space.m_iNumOfContinuousNoBetterPropagation > m_iThresholdOfNumberOfContinuousNoBetterPropagation) { \ - if (p_space.m_iNumberOfTreeCheckedLeaves <= p_space.m_iNumberOfCheckedLeaves / 10) { \ - m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, p_query, p_space, m_iNumberOfOtherDynamicPivots + 
p_space.m_iNumberOfCheckedLeaves); \ - } else if (gnode.distance > p_query.worstDist()) { \ - break; \ - } \ - } \ - } \ - p_query.SortResult(); \ +template ErrorCode Index::SearchIndexIterativeEnd(std::unique_ptr space) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ITERATIVE NOT SUPPORT FOR KDT"); + return ErrorCode::Fail; +} - template - void Index::SearchIndexWithoutDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const - { - Search(if (!m_deletedID.Contains(gnode.node))) - } +template +bool Index::SearchIndexIterativeFromNeareast(QueryResult &p_query, COMMON::WorkSpace *p_space, bool p_isFirst, + bool p_searchDeleted) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SearchIndexIterativeFromNeareast NOT SUPPORT FOR KDT"); + return false; +} - template - void Index::SearchIndexWithDeleted(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space) const +template std::unique_ptr Index::RentWorkSpace(int batch, std::function p_filterFunc, int p_maxCheck) const +{ + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new COMMON::WorkSpace()); + workSpace->Initialize(max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); + } + workSpace->Reset(p_maxCheck > 0? 
p_maxCheck: m_iMaxCheck, batch); + workSpace->m_filterFunc = p_filterFunc; + return std::move(workSpace); +} + +template +ErrorCode Index::SearchIndexWithFilter(QueryResult &p_query, std::function filterFunc, + int maxCheck, bool p_searchDeleted) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Not Support Filter on KDT Index!\n"); + return ErrorCode::Fail; +} + +template ErrorCode Index::RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted) const +{ + auto workSpace = RentWorkSpace(p_query.GetResultNum(), nullptr, m_pGraph.m_iMaxCheckForRefineGraph); + + COMMON::QueryResultSet *p_results = (COMMON::QueryResultSet *)&p_query; + + if (m_pQuantizer) + { + if (!(p_results->HasQuantizedTarget())) { - Search(;) + p_results->SetTarget(p_results->GetTarget(), m_pQuantizer); } - - template - ErrorCode - Index::SearchIndex(QueryResult &p_query, bool p_searchDeleted) const + switch (m_pQuantizer->GetReconstructType()) { - if (!m_bReady) return ErrorCode::EmptyIndex; +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + SearchIndex(*p_results, *workSpace, p_searchDeleted); \ + break; - auto workSpace = m_workSpacePool->Rent(); - workSpace->Reset(m_iMaxCheck, p_query.GetResultNum()); +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType - if (m_deletedID.Count() == 0 || p_searchDeleted) - SearchIndexWithDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace); - else - SearchIndexWithoutDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace); + default: + break; + } + } + else + { + SearchIndex(*p_results, *workSpace, p_searchDeleted); + } + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); - m_workSpacePool->Return(workSpace); + return ErrorCode::Success; +} - if (p_query.WithMeta() && nullptr != m_pMetadata) - { - for (int i = 0; i < p_query.GetResultNum(); ++i) - { - SizeType result = p_query.GetResult(i)->VID; - p_query.SetMetadata(i, (result < 0) ? 
ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); - } - } - return ErrorCode::Success; - } +template ErrorCode Index::SearchTree(QueryResult &p_query) const +{ + auto workSpace = RentWorkSpace(p_query.GetResultNum(), nullptr, m_pGraph.m_iMaxCheckForRefineGraph); + + COMMON::QueryResultSet *p_results = (COMMON::QueryResultSet *)&p_query; - template - ErrorCode Index::RefineSearchIndex(QueryResult &p_query, bool p_searchDeleted) const + if (m_pQuantizer) + { + if (!(p_results->HasQuantizedTarget())) + { + p_results->SetTarget(p_results->GetTarget(), m_pQuantizer); + } + switch (m_pQuantizer->GetReconstructType()) { - auto workSpace = m_workSpacePool->Rent(); - workSpace->Reset(m_pGraph.m_iMaxCheckForRefineGraph, p_query.GetResultNum()); +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace); \ + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace, \ + m_iNumberOfInitialDynamicPivots); \ + break; - if (m_deletedID.Count() == 0 || p_searchDeleted) - SearchIndexWithDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace); - else - SearchIndexWithoutDeleted(*((COMMON::QueryResultSet*)&p_query), *workSpace); +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType - m_workSpacePool->Return(workSpace); - return ErrorCode::Success; + default: + break; } + } + else + { + m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace); + m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace, + m_iNumberOfInitialDynamicPivots); + } - template - ErrorCode Index::SearchTree(QueryResult& p_query) const - { - auto workSpace = m_workSpacePool->Rent(); - workSpace->Reset(m_pGraph.m_iMaxCheckForRefineGraph, p_query.GetResultNum()); - - COMMON::QueryResultSet* p_results = (COMMON::QueryResultSet*)&p_query; - m_pTrees.InitSearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace); - 
m_pTrees.SearchTrees(m_pSamples, m_fComputeDistance, *p_results, *workSpace, m_iNumberOfInitialDynamicPivots); - BasicResult * res = p_query.GetResults(); - for (int i = 0; i < p_query.GetResultNum(); i++) - { - auto& cell = workSpace->m_NGQueue.pop(); - res[i].VID = cell.node; - res[i].Dist = cell.distance; - } - m_workSpacePool->Return(workSpace); - return ErrorCode::Success; - } + BasicResult *res = p_query.GetResults(); + for (int i = 0; i < p_query.GetResultNum(); i++) + { + auto &cell = workSpace->m_NGQueue.pop(); + res[i].VID = cell.node; + res[i].Dist = cell.distance; + } + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + + return ErrorCode::Success; +} #pragma endregion - template - ErrorCode Index::BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized, bool p_shareOwnership) - { - if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) return ErrorCode::EmptyData; +template +ErrorCode Index::BuildIndex(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized, + bool p_shareOwnership) +{ + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; - omp_set_num_threads(m_iNumberOfThreads); + m_pGraph.m_iThreadNum = m_iNumberOfThreads; - m_pSamples.Initialize(p_vectorNum, p_dimension, m_iDataBlockSize, m_iDataCapacity, (T*)p_data, p_shareOwnership); - m_deletedID.Initialize(p_vectorNum, m_iDataBlockSize, m_iDataCapacity); + m_pSamples.Initialize(p_vectorNum, p_dimension, m_iDataBlockSize, m_iDataCapacity, (T *)p_data, p_shareOwnership); + m_deletedID.Initialize(p_vectorNum, m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); - if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) - { - int base = m_pQuantizer ? 
m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); -#pragma omp parallel for - for (SizeType i = 0; i < GetNumSamples(); i++) { - COMMON::Utils::Normalize(m_pSamples[i], GetFeatureDim(), base); - } - } + if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) + { + int base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); + COMMON::Utils::BatchNormalize(m_pSamples[0], p_vectorNum, p_dimension, base, m_iNumberOfThreads); + } - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - m_threadPool.init(); + m_threadPool.init(); - auto t1 = std::chrono::high_resolution_clock::now(); + auto t1 = std::chrono::high_resolution_clock::now(); - m_pTrees.BuildTrees(m_pSamples, m_iNumberOfThreads); + m_pTrees.BuildTrees(m_pSamples, m_iNumberOfThreads); - auto t2 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Build Tree time (s): %lld\n", std::chrono::duration_cast(t2 - t1).count()); - m_pGraph.BuildGraph(this); - auto t3 = std::chrono::high_resolution_clock::now(); - LOG(Helper::LogLevel::LL_Info, "Build Graph time (s): %lld\n", std::chrono::duration_cast(t3 - t2).count()); + auto t2 = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build Tree time (s): %lld\n", + std::chrono::duration_cast(t2 - t1).count()); + m_pGraph.BuildGraph(this); + auto t3 = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build Graph time (s): %lld\n", + std::chrono::duration_cast(t3 - t2).count()); - m_bReady = true; - return ErrorCode::Success; - } + m_bReady = true; + return ErrorCode::Success; +} - template - ErrorCode Index::RefineIndex(std::shared_ptr& p_newIndex) - { - p_newIndex.reset(new Index()); - Index* ptr = (Index*)p_newIndex.get(); +template ErrorCode Index::RefineIndex(std::shared_ptr &p_newIndex) +{ + p_newIndex.reset(new Index()); + 
Index *ptr = (Index *)p_newIndex.get(); -#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - ptr->VarName = VarName; \ +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) ptr->VarName = VarName; #include "inc/Core/KDT/ParameterDefinitionList.h" #undef DefineKDTParameter - std::lock_guard lock(m_dataAddLock); - std::unique_lock uniquelock(m_dataDeleteLock); - - SizeType newR = GetNumSamples(); + std::lock_guard lock(m_dataAddLock); + std::unique_lock uniquelock(m_dataDeleteLock); - std::vector indices; - std::vector reverseIndices(newR); - for (SizeType i = 0; i < newR; i++) { - if (!m_deletedID.Contains(i)) { - indices.push_back(i); - reverseIndices[i] = i; - } - else { - while (m_deletedID.Contains(newR - 1) && newR > i) newR--; - if (newR == i) break; - indices.push_back(newR - 1); - reverseIndices[newR - 1] = i; - newR--; - } - } + SizeType newR = GetNumSamples(); - LOG(Helper::LogLevel::LL_Info, "Refine... from %d -> %d\n", GetNumSamples(), newR); - if (newR == 0) return ErrorCode::EmptyIndex; + std::vector indices; + std::vector reverseIndices(newR); + for (SizeType i = 0; i < newR; i++) + { + if (!m_deletedID.Contains(i)) + { + indices.push_back(i); + reverseIndices[i] = i; + } + else + { + while (m_deletedID.Contains(newR - 1) && newR > i) + newR--; + if (newR == i) + break; + indices.push_back(newR - 1); + reverseIndices[newR - 1] = i; + newR--; + } + } - ptr->m_workSpacePool.reset(new COMMON::WorkSpacePool()); - ptr->m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - ptr->m_threadPool.init(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine... 
from %d -> %d\n", GetNumSamples(), newR); + if (newR == 0) + return ErrorCode::EmptyIndex; + + ptr->m_threadPool.init(); + + ErrorCode ret = ErrorCode::Success; + if ((ret = m_pSamples.Refine(indices, ptr->m_pSamples)) != ErrorCode::Success) + return ret; + if (nullptr != m_pMetadata && + (ret = m_pMetadata->RefineMetadata(indices, ptr->m_pMetadata, m_iDataBlockSize, m_iDataCapacity, + m_iMetaRecordSize)) != ErrorCode::Success) + return ret; + + ptr->m_deletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + COMMON::KDTree *newtree = &(ptr->m_pTrees); + + (*newtree).BuildTrees(ptr->m_pSamples, m_iNumberOfThreads); + m_pGraph.RefineGraph(this, indices, reverseIndices, nullptr, &(ptr->m_pGraph)); + if (HasMetaMapping()) + ptr->BuildMetaMapping(false); + ptr->m_bReady = true; + return ret; +} - ErrorCode ret = ErrorCode::Success; - if ((ret = m_pSamples.Refine(indices, ptr->m_pSamples)) != ErrorCode::Success) return ret; - if (nullptr != m_pMetadata && (ret = m_pMetadata->RefineMetadata(indices, ptr->m_pMetadata, m_iDataBlockSize, m_iDataCapacity, m_iMetaRecordSize)) != ErrorCode::Success) return ret; +template +ErrorCode Index::RefineIndex(const std::vector> &p_indexStreams, + IAbortOperation *p_abort, std::vector *p_mapping) +{ + std::lock_guard lock(m_dataAddLock); + std::unique_lock uniquelock(m_dataDeleteLock); - ptr->m_deletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity); - COMMON::KDTree* newtree = &(ptr->m_pTrees); + SizeType newR = GetNumSamples(); - (*newtree).BuildTrees(ptr->m_pSamples, omp_get_num_threads()); - m_pGraph.RefineGraph(this, indices, reverseIndices, nullptr, &(ptr->m_pGraph)); - if (HasMetaMapping()) ptr->BuildMetaMapping(false); - ptr->m_bReady = true; - return ret; + std::vector indices; + std::vector reverseIndices; + if (p_mapping == nullptr) + { + p_mapping = &reverseIndices; + } + p_mapping->resize(newR); + for (SizeType i = 0; i < newR; i++) + { + if 
(!m_deletedID.Contains(i)) + { + indices.push_back(i); + (*p_mapping)[i] = i; } - - template - ErrorCode Index::RefineIndex(const std::vector>& p_indexStreams, IAbortOperation* p_abort) + else { - std::lock_guard lock(m_dataAddLock); - std::unique_lock uniquelock(m_dataDeleteLock); + while (m_deletedID.Contains(newR - 1) && newR > i) + newR--; + if (newR == i) + break; + indices.push_back(newR - 1); + (*p_mapping)[newR - 1] = i; + newR--; + } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Refine... from %d -> %d\n", GetNumSamples(), newR); + if (newR == 0) + return ErrorCode::EmptyIndex; - SizeType newR = GetNumSamples(); + ErrorCode ret = ErrorCode::Success; + if ((ret = m_pSamples.Refine(indices, p_indexStreams[0])) != ErrorCode::Success) + return ret; - std::vector indices; - std::vector reverseIndices(newR); - for (SizeType i = 0; i < newR; i++) { - if (!m_deletedID.Contains(i)) { - indices.push_back(i); - reverseIndices[i] = i; + if (p_abort != nullptr && p_abort->ShouldAbort()) + return ErrorCode::ExternalAbort; + + COMMON::KDTree newTrees(m_pTrees); + newTrees.BuildTrees(m_pSamples, m_iNumberOfThreads, &indices); + std::vector mythreads; + mythreads.reserve(m_iNumberOfThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iNumberOfThreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < newTrees.size()) + { + if (newTrees[i].left < 0) + newTrees[i].left = -(*p_mapping)[-newTrees[i].left - 1] - 1; + if (newTrees[i].right < 0) + newTrees[i].right = -(*p_mapping)[-newTrees[i].right - 1] - 1; } - else { - while (m_deletedID.Contains(newR - 1) && newR > i) newR--; - if (newR == i) break; - indices.push_back(newR - 1); - reverseIndices[newR - 1] = i; - newR--; + else + { + return; } } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + + if ((ret = newTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) + return ret; - 
LOG(Helper::LogLevel::LL_Info, "Refine... from %d -> %d\n", GetNumSamples(), newR); - if (newR == 0) return ErrorCode::EmptyIndex; + if (p_abort != nullptr && p_abort->ShouldAbort()) + return ErrorCode::ExternalAbort; - ErrorCode ret = ErrorCode::Success; - if ((ret = m_pSamples.Refine(indices, p_indexStreams[0])) != ErrorCode::Success) return ret; + if ((ret = m_pGraph.RefineGraph(this, indices, (*p_mapping), p_indexStreams[2], nullptr)) != + ErrorCode::Success) + return ret; - if (p_abort != nullptr && p_abort->ShouldAbort()) return ErrorCode::ExternalAbort; + COMMON::Labelset newDeletedID; + newDeletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity, + COMMON::Labelset::InvalidIDBehavior::AlwaysContains); + if ((ret = newDeletedID.Save(p_indexStreams[3])) != ErrorCode::Success) + return ret; - COMMON::KDTree newTrees(m_pTrees); - newTrees.BuildTrees(m_pSamples, omp_get_num_threads(), &indices); -#pragma omp parallel for - for (SizeType i = 0; i < newTrees.size(); i++) { - if (newTrees[i].left < 0) - newTrees[i].left = -reverseIndices[-newTrees[i].left - 1] - 1; - if (newTrees[i].right < 0) - newTrees[i].right = -reverseIndices[-newTrees[i].right - 1] - 1; + if (nullptr != m_pMetadata) + { + if (p_indexStreams.size() < 6) + return ErrorCode::LackOfInputs; + if ((ret = m_pMetadata->RefineMetadata(indices, p_indexStreams[4], p_indexStreams[5])) != ErrorCode::Success) + return ret; + } + return ret; +} + +template ErrorCode Index::DeleteIndex(const void *p_vectors, SizeType p_vectorNum) +{ + const T *ptr_v = (const T *)p_vectors; + std::vector mythreads; + mythreads.reserve(m_iNumberOfThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_iNumberOfThreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < p_vectorNum) + { + COMMON::QueryResultSet query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF); + SearchIndex(query); + + for (int j = 0; j < m_pGraph.m_iCEF; j++) + { + if 
(query.GetResult(j)->Dist < 1e-6) + { + DeleteIndex(query.GetResult(j)->VID); + } + } + } + else + { + return; + } } - if ((ret = newTrees.SaveTrees(p_indexStreams[1])) != ErrorCode::Success) return ret; + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + return ErrorCode::Success; +} - if (p_abort != nullptr && p_abort->ShouldAbort()) return ErrorCode::ExternalAbort; +template ErrorCode Index::DeleteIndex(const SizeType &p_id) +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; - if ((ret = m_pGraph.RefineGraph(this, indices, reverseIndices, p_indexStreams[2], nullptr)) != ErrorCode::Success) return ret; + std::shared_lock sharedlock(m_dataDeleteLock); + if (m_deletedID.Insert(p_id)) + return ErrorCode::Success; + return ErrorCode::VectorNotFound; +} - COMMON::Labelset newDeletedID; - newDeletedID.Initialize(newR, m_iDataBlockSize, m_iDataCapacity); - if ((ret = newDeletedID.Save(p_indexStreams[3])) != ErrorCode::Success) return ret; +template +ErrorCode Index::AddIndex(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, + std::shared_ptr p_metadataSet, bool p_withMetaIndex, bool p_normalized) +{ + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; - if (nullptr != m_pMetadata) { - if (p_indexStreams.size() < 6) return ErrorCode::LackOfInputs; - if ((ret = m_pMetadata->RefineMetadata(indices, p_indexStreams[4], p_indexStreams[5])) != ErrorCode::Success) return ret; - } - return ret; - } + SizeType begin, end; + ErrorCode ret; + { + std::lock_guard lock(m_dataAddLock); - template - ErrorCode Index::DeleteIndex(const void* p_vectors, SizeType p_vectorNum) { - const T* ptr_v = (const T*)p_vectors; -#pragma omp parallel for schedule(dynamic) - for (SizeType i = 0; i < p_vectorNum; i++) { - COMMON::QueryResultSet query(ptr_v + i * GetFeatureDim(), m_pGraph.m_iCEF); - SearchIndex(query); - - for (int i = 0; i < m_pGraph.m_iCEF; i++) { - if (query.GetResult(i)->Dist < 1e-6) { - 
DeleteIndex(query.GetResult(i)->VID); - } - } + begin = GetNumSamples(); + end = begin + p_vectorNum; + + if (begin == 0) + { + if (p_metadataSet != nullptr) + { + m_pMetadata.reset(new MemMetadataSet(m_iDataBlockSize, m_iDataCapacity, m_iMetaRecordSize)); + m_pMetadata->AddBatch(*p_metadataSet); + if (p_withMetaIndex) + BuildMetaMapping(false); } + if ((ret = BuildIndex(p_data, p_vectorNum, p_dimension, p_normalized)) != ErrorCode::Success) + return ret; return ErrorCode::Success; } - template - ErrorCode Index::DeleteIndex(const SizeType& p_id) { - if (!m_bReady) return ErrorCode::EmptyIndex; + if (p_dimension != GetFeatureDim()) + return ErrorCode::DimensionSizeMismatch; - std::shared_lock sharedlock(m_dataDeleteLock); - if (m_deletedID.Insert(p_id)) return ErrorCode::Success; - return ErrorCode::VectorNotFound; + if (m_pSamples.AddBatch(p_vectorNum, (const T *)p_data) != ErrorCode::Success || + m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success || + m_deletedID.AddBatch(p_vectorNum) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Memory Error: Cannot alloc space for vectors!\n"); + m_pSamples.SetR(begin); + m_pGraph.SetR(begin); + m_deletedID.SetR(begin); + return ErrorCode::MemoryOverFlow; } - template - ErrorCode Index::AddIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, std::shared_ptr p_metadataSet, bool p_withMetaIndex, bool p_normalized) + if (m_pMetadata != nullptr) { - if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) return ErrorCode::EmptyData; - - SizeType begin, end; - ErrorCode ret; + if (p_metadataSet != nullptr) { - std::lock_guard lock(m_dataAddLock); - - begin = GetNumSamples(); - end = begin + p_vectorNum; - - if (begin == 0) { - if (p_metadataSet != nullptr) { - m_pMetadata.reset(new MemMetadataSet(m_iDataBlockSize, m_iDataCapacity, m_iMetaRecordSize)); - m_pMetadata->AddBatch(*p_metadataSet); - if (p_withMetaIndex) BuildMetaMapping(false); + 
m_pMetadata->AddBatch(*p_metadataSet); + if (HasMetaMapping()) + { + for (SizeType i = begin; i < end; i++) + { + ByteArray meta = m_pMetadata->GetMetadata(i); + std::string metastr((char *)meta.Data(), meta.Length()); + UpdateMetaMapping(metastr, i); } - if ((ret = BuildIndex(p_data, p_vectorNum, p_dimension, p_normalized)) != ErrorCode::Success) return ret; - return ErrorCode::Success; } + } + else + { + for (SizeType i = begin; i < end; i++) + m_pMetadata->Add(ByteArray::c_empty); + } + } + } - if (p_dimension != GetFeatureDim()) return ErrorCode::DimensionSizeMismatch; + if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) + { + int base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); + for (SizeType i = begin; i < end; i++) + { + COMMON::Utils::Normalize((T *)m_pSamples[i], GetFeatureDim(), base); + } + } - if (m_pSamples.AddBatch((const T*)p_data, p_vectorNum) != ErrorCode::Success || - m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success || - m_deletedID.AddBatch(p_vectorNum) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Memory Error: Cannot alloc space for vectors!\n"); - m_pSamples.SetR(begin); - m_pGraph.SetR(begin); - m_deletedID.SetR(begin); - return ErrorCode::MemoryOverFlow; - } + if (end - m_pTrees.sizePerTree() >= m_addCountForRebuild && m_threadPool.jobsize() == 0) + { + m_threadPool.add(new RebuildJob(&m_pSamples, &m_pTrees, &m_pGraph)); + } - if (m_pMetadata != nullptr) { - if (p_metadataSet != nullptr) { - m_pMetadata->AddBatch(*p_metadataSet); - if (HasMetaMapping()) { - for (SizeType i = begin; i < end; i++) { - ByteArray meta = m_pMetadata->GetMetadata(i); - std::string metastr((char*)meta.Data(), meta.Length()); - UpdateMetaMapping(metastr, i); - } - } - } - else { - for (SizeType i = begin; i < end; i++) m_pMetadata->Add(ByteArray::c_empty); - } - } - } + for (SizeType node = begin; node < end; node++) + { + m_pGraph.RefineNode(this, node, true, true, m_pGraph.m_iAddCEF); + } + return 
ErrorCode::Success; +} - if (DistCalcMethod::Cosine == m_iDistCalcMethod && !p_normalized) - { - int base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); - for (SizeType i = begin; i < end; i++) { - COMMON::Utils::Normalize((T*)m_pSamples[i], GetFeatureDim(), base); - } - } +template +ErrorCode Index::AddIndexId(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, int &beginHead, + int &endHead) +{ + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; - if (end - m_pTrees.sizePerTree() >= m_addCountForRebuild && m_threadPool.jobsize() == 0) { - m_threadPool.add(new RebuildJob(&m_pSamples, &m_pTrees, &m_pGraph)); - } + SizeType begin, end; + { + std::lock_guard lock(m_dataAddLock); - for (SizeType node = begin; node < end; node++) - { - m_pGraph.RefineNode(this, node, true, true, m_pGraph.m_iAddCEF); - } - return ErrorCode::Success; - } + begin = GetNumSamples(); + end = begin + p_vectorNum; - template - ErrorCode - Index::UpdateIndex() + if (begin == 0) { - omp_set_num_threads(m_iNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_iNumberOfThreads, max(m_iMaxCheck, m_pGraph.m_iMaxCheckForRefineGraph), m_iHashTableExp); - return ErrorCode::Success; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Index Error: No vector in Index!\n"); + return ErrorCode::EmptyIndex; } - template - ErrorCode - Index::SetParameter(const char* p_param, const char* p_value, const char* p_section) + if (p_dimension != GetFeatureDim()) + return ErrorCode::DimensionSizeMismatch; + + if (m_pSamples.AddBatch(p_vectorNum, (const T *)p_data) != ErrorCode::Success || + m_pGraph.AddBatch(p_vectorNum) != ErrorCode::Success || + m_deletedID.AddBatch(p_vectorNum) != ErrorCode::Success) { - if (nullptr == p_param || nullptr == p_value) return ErrorCode::Fail; - -#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - else if 
(SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ - { \ - LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ - VarType tmp; \ - if (SPTAG::Helper::Convert::ConvertStringTo(p_value, tmp)) \ - { \ - VarName = tmp; \ - } \ - } \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Memory Error: Cannot alloc space for vectors!\n"); + m_pSamples.SetR(begin); + m_pGraph.SetR(begin); + m_deletedID.SetR(begin); + return ErrorCode::MemoryOverFlow; + } + } + beginHead = begin; + endHead = end; + return ErrorCode::Success; +} + +template ErrorCode Index::AddIndexIdx(SizeType begin, SizeType end) +{ + // if (end - m_pTrees.sizePerTree() >= m_addCountForRebuild && m_threadPool.jobsize() == 0) { + // m_threadPool.add(new RebuildJob(&m_pSamples, &m_pTrees, &m_pGraph, m_iDistCalcMethod)); + // } + + for (SizeType node = begin; node < end; node++) + { + m_pGraph.RefineNode(this, node, true, true, m_pGraph.m_iAddCEF); + } + return ErrorCode::Success; +} + +template ErrorCode Index::UpdateIndex() +{ + m_pGraph.m_iThreadNum = m_iNumberOfThreads; + return ErrorCode::Success; +} + +template ErrorCode Index::SetParameter(const char *p_param, const char *p_value, const char *p_section) +{ + if (nullptr == p_param || nullptr == p_value) + return ErrorCode::Fail; + +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting %s with value %s\n", RepresentStr, p_value); \ + VarType tmp; \ + if (SPTAG::Helper::Convert::ConvertStringTo(p_value, tmp)) \ + { \ + VarName = tmp; \ + } \ + } #include "inc/Core/KDT/ParameterDefinitionList.h" #undef DefineKDTParameter - if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "DistCalcMethod")) { - m_fComputeDistance = m_pQuantizer ? 
m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod) : COMMON::DistanceCalcSelector(m_iDistCalcMethod); - auto base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); - m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? base * base : 1; - } - return ErrorCode::Success; - } - + if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "DistCalcMethod")) + { + m_fComputeDistance = m_pQuantizer ? m_pQuantizer->DistanceCalcSelector(m_iDistCalcMethod) + : COMMON::DistanceCalcSelector(m_iDistCalcMethod); + auto base = m_pQuantizer ? m_pQuantizer->GetBase() : COMMON::Utils::GetBase(); + m_iBaseSquare = (m_iDistCalcMethod == DistCalcMethod::Cosine) ? base * base : 1; + } + return ErrorCode::Success; +} - template - std::string - Index::GetParameter(const char* p_param, const char* p_section) const - { - if (nullptr == p_param) return std::string(); +template std::string Index::GetParameter(const char *p_param, const char *p_section) const +{ + if (nullptr == p_param) + return std::string(); -#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ - { \ - return SPTAG::Helper::Convert::ConvertToString(VarName); \ - } \ +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + else if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, RepresentStr)) \ + { \ + return SPTAG::Helper::Convert::ConvertToString(VarName); \ + } #include "inc/Core/KDT/ParameterDefinitionList.h" #undef DefineKDTParameter - return std::string(); - } - } + return std::string(); } +} // namespace KDT +} // namespace SPTAG -#define DefineVectorValueType(Name, Type) \ -template class SPTAG::KDT::Index; \ +#define DefineVectorValueType(Name, Type) template class SPTAG::KDT::Index; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - - diff --git a/AnnService/src/Core/MetadataSet.cpp b/AnnService/src/Core/MetadataSet.cpp index 
c6ab25049..bcef57bb5 100644 --- a/AnnService/src/Core/MetadataSet.cpp +++ b/AnnService/src/Core/MetadataSet.cpp @@ -2,70 +2,75 @@ // Licensed under the MIT License. #include "inc/Core/MetadataSet.h" - -#include +#include "inc/Core/VectorIndex.h" #include +#include using namespace SPTAG; #include "inc/Helper/LockFree.h" typedef typename SPTAG::Helper::LockFree::LockFreeVector MetadataOffsets; -ErrorCode -MetadataSet::RefineMetadata(std::vector& indices, std::shared_ptr& p_newMetadata, - std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) const +ErrorCode MetadataSet::RefineMetadata(std::vector &indices, std::shared_ptr &p_newMetadata, + std::uint64_t p_blockSize, std::uint64_t p_capacity, + std::uint64_t p_metaSize) const { p_newMetadata.reset(new MemMetadataSet(p_blockSize, p_capacity, p_metaSize)); - for (SizeType& t : indices) { + for (SizeType &t : indices) + { p_newMetadata->Add(GetMetadata(t)); } return ErrorCode::Success; } -ErrorCode -MetadataSet::RefineMetadata(std::vector& indices, std::shared_ptr p_metaOut, std::shared_ptr p_metaIndexOut) const +ErrorCode MetadataSet::RefineMetadata(std::vector &indices, std::shared_ptr p_metaOut, + std::shared_ptr p_metaIndexOut) const { SizeType R = (SizeType)indices.size(); - IOBINARY(p_metaIndexOut, WriteBinary, sizeof(SizeType), (const char*)&R); + IOBINARY(p_metaIndexOut, WriteBinary, sizeof(SizeType), (const char *)&R); std::uint64_t offset = 0; - for (SizeType i = 0; i < R; i++) { - IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t), (const char*)&offset); + for (SizeType i = 0; i < R; i++) + { + IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t), (const char *)&offset); ByteArray meta = GetMetadata(indices[i]); offset += meta.Length(); } - IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t), (const char*)&offset); + IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t), (const char *)&offset); - for (SizeType i = 0; i < R; i++) { + for (SizeType i = 0; i 
< R; i++) + { ByteArray meta = GetMetadata(indices[i]); - IOBINARY(p_metaOut, WriteBinary, sizeof(uint8_t) * meta.Length(), (const char*)meta.Data()); + IOBINARY(p_metaOut, WriteBinary, sizeof(uint8_t) * meta.Length(), (const char *)meta.Data()); } - LOG(Helper::LogLevel::LL_Info, "Save MetaIndex(%d) Meta(%llu)\n", R, offset); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save MetaIndex(%d) Meta(%llu)\n", R, offset); return ErrorCode::Success; } - -ErrorCode -MetadataSet::RefineMetadata(std::vector& indices, const std::string& p_metaFile, const std::string& p_metaindexFile) const +ErrorCode MetadataSet::RefineMetadata(std::vector &indices, const std::string &p_metaFile, + const std::string &p_metaindexFile) const { { std::shared_ptr ptrMeta = f_createIO(), ptrMetaIndex = f_createIO(); - if (ptrMeta == nullptr || ptrMetaIndex == nullptr || !ptrMeta->Initialize((p_metaFile + "_tmp").c_str(), std::ios::binary | std::ios::out) || !ptrMetaIndex->Initialize((p_metaindexFile + "_tmp").c_str(), std::ios::binary | std::ios::out)) - return ErrorCode::FailedCreateFile; + if (ptrMeta == nullptr || ptrMetaIndex == nullptr || + !ptrMeta->Initialize((p_metaFile + "_tmp").c_str(), std::ios::binary | std::ios::out) || + !ptrMetaIndex->Initialize((p_metaindexFile + "_tmp").c_str(), std::ios::binary | std::ios::out)) + return ErrorCode::FailedCreateFile; ErrorCode ret = RefineMetadata(indices, ptrMeta, ptrMetaIndex); - if (ret != ErrorCode::Success) return ret; + if (ret != ErrorCode::Success) + return ret; } - if (fileexists(p_metaFile.c_str())) std::remove(p_metaFile.c_str()); - if (fileexists(p_metaindexFile.c_str())) std::remove(p_metaindexFile.c_str()); + if (fileexists(p_metaFile.c_str())) + std::remove(p_metaFile.c_str()); + if (fileexists(p_metaindexFile.c_str())) + std::remove(p_metaindexFile.c_str()); std::rename((p_metaFile + "_tmp").c_str(), p_metaFile.c_str()); std::rename((p_metaindexFile + "_tmp").c_str(), p_metaindexFile.c_str()); return ErrorCode::Success; } - -void 
-MetadataSet::AddBatch(MetadataSet& data) +void MetadataSet::AddBatch(MetadataSet &data) { for (SizeType i = 0; i < data.Count(); i++) { @@ -73,384 +78,389 @@ MetadataSet::AddBatch(MetadataSet& data) } } - MetadataSet::MetadataSet() { } - MetadataSet::~MetadataSet() { } -bool MetadataSet::GetMetadataOffsets(const std::uint8_t* p_meta, const std::uint64_t p_metaLength, std::uint64_t* p_offsets, std::uint64_t p_offsetLength, char p_delimiter) +bool MetadataSet::GetMetadataOffsets(const std::uint8_t *p_meta, const std::uint64_t p_metaLength, + std::uint64_t *p_offsets, std::uint64_t p_offsetLength, char p_delimiter) { std::uint64_t current = 0; p_offsets[current++] = 0; - for (std::uint64_t i = 0; i < p_metaLength && current < p_offsetLength; i++) { + for (std::uint64_t i = 0; i < p_metaLength && current < p_offsetLength; i++) + { if ((char)(p_meta[i]) == p_delimiter) p_offsets[current++] = (std::uint64_t)(i + 1); } if ((char)(p_meta[p_metaLength - 1]) != p_delimiter && current < p_offsetLength) p_offsets[current++] = p_metaLength; - if (current < p_offsetLength) { - LOG(Helper::LogLevel::LL_Error, "The metadata(%d) and vector(%d) numbers are not match! Check whether it is unicode encoding issue.\n", current-1, p_offsetLength-1); + if (current < p_offsetLength) + { + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, + "The metadata(%d) and vector(%d) numbers are not match! 
Check whether it is unicode encoding issue.\n", + current - 1, p_offsetLength - 1); return false; } return true; } - -FileMetadataSet::FileMetadataSet(const std::string& p_metafile, const std::string& p_metaindexfile, std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) +FileMetadataSet::FileMetadataSet(const std::string &p_metafile, const std::string &p_metaindexfile, + std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) { m_fp = f_createIO(); auto fpidx = f_createIO(); - if (m_fp == nullptr || fpidx == nullptr || !m_fp->Initialize(p_metafile.c_str(), std::ios::binary | std::ios::in) || !fpidx->Initialize(p_metaindexfile.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open meta files %s or %s!\n", p_metafile.c_str(), p_metaindexfile.c_str()); + if (m_fp == nullptr || fpidx == nullptr || !m_fp->Initialize(p_metafile.c_str(), std::ios::binary | std::ios::in) || + !fpidx->Initialize(p_metaindexfile.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open meta files %s or %s!\n", p_metafile.c_str(), + p_metaindexfile.c_str()); throw std::runtime_error("Cannot open meta files"); } - if (fpidx->ReadBinary(sizeof(m_count), (char*)&m_count) != sizeof(m_count)) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read FileMetadataSet!\n"); + if (fpidx->ReadBinary(sizeof(m_count), (char *)&m_count) != sizeof(m_count)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read FileMetadataSet!\n"); throw std::runtime_error("Cannot read meta files"); } m_offsets.reserve(p_blockSize); m_offsets.resize(m_count + 1); - if (fpidx->ReadBinary(sizeof(std::uint64_t) * (m_count + 1), (char*)m_offsets.data()) != sizeof(std::uint64_t) * (m_count + 1)) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read FileMetadataSet!\n"); + if (fpidx->ReadBinary(sizeof(std::uint64_t) * (m_count + 1), (char *)m_offsets.data()) != + 
sizeof(std::uint64_t) * (m_count + 1)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read FileMetadataSet!\n"); throw std::runtime_error("Cannot read meta files"); } m_newdata.reserve(p_blockSize * p_metaSize); m_lock.reset(new std::shared_timed_mutex, std::default_delete()); - LOG(Helper::LogLevel::LL_Info, "Load MetaIndex(%d) Meta(%llu)\n", m_count, m_offsets[m_count]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load MetaIndex(%d) Meta(%llu)\n", m_count, m_offsets[m_count]); } - FileMetadataSet::~FileMetadataSet() { } - -ByteArray -FileMetadataSet::GetMetadata(SizeType p_vectorID) const +ByteArray FileMetadataSet::GetMetadata(SizeType p_vectorID) const { - std::unique_lock lock(*static_cast(m_lock.get())); + std::unique_lock lock(*static_cast(m_lock.get())); std::uint64_t startoff = m_offsets[p_vectorID]; std::uint64_t bytes = m_offsets[p_vectorID + 1] - startoff; - if (p_vectorID < m_count) { + if (p_vectorID < m_count) + { ByteArray b = ByteArray::Alloc(bytes); - m_fp->ReadBinary(bytes, (char*)b.Data(), startoff); + m_fp->ReadBinary(bytes, (char *)b.Data(), startoff); return b; } - else { - return ByteArray((std::uint8_t*)m_newdata.data() + startoff - m_offsets[m_count], bytes, false); + else + { + return ByteArray((std::uint8_t *)m_newdata.data() + startoff - m_offsets[m_count], bytes, false); } } - -ByteArray -FileMetadataSet::GetMetadataCopy(SizeType p_vectorID) const +ByteArray FileMetadataSet::GetMetadataCopy(SizeType p_vectorID) const { - std::unique_lock lock(*static_cast(m_lock.get())); + std::unique_lock lock(*static_cast(m_lock.get())); std::uint64_t startoff = m_offsets[p_vectorID]; std::uint64_t bytes = m_offsets[p_vectorID + 1] - startoff; - if (p_vectorID < m_count) { + if (p_vectorID < m_count) + { ByteArray b = ByteArray::Alloc(bytes); - m_fp->ReadBinary(bytes, (char*)b.Data(), startoff); + m_fp->ReadBinary(bytes, (char *)b.Data(), startoff); return b; } - else { + else + { ByteArray b = ByteArray::Alloc(bytes); memcpy(b.Data(), 
m_newdata.data() + (startoff - m_offsets[m_count]), bytes); return b; } } - -SizeType -FileMetadataSet::Count() const +SizeType FileMetadataSet::Count() const { - std::shared_lock lock(*static_cast(m_lock.get())); + std::shared_lock lock(*static_cast(m_lock.get())); return static_cast(m_offsets.size() - 1); } - -bool -FileMetadataSet::Available() const +bool FileMetadataSet::Available() const { - std::shared_lock lock(*static_cast(m_lock.get())); + std::shared_lock lock(*static_cast(m_lock.get())); return m_fp != nullptr && m_offsets.size() > 1; } - -std::pair -FileMetadataSet::BufferSize() const +std::pair FileMetadataSet::BufferSize() const { - std::shared_lock lock(*static_cast(m_lock.get())); - return std::make_pair(m_offsets.back(), - sizeof(SizeType) + sizeof(std::uint64_t) * m_offsets.size()); + std::shared_lock lock(*static_cast(m_lock.get())); + return std::make_pair(m_offsets.back(), sizeof(SizeType) + sizeof(std::uint64_t) * m_offsets.size()); } - -void -FileMetadataSet::Add(const ByteArray& data) +void FileMetadataSet::Add(const ByteArray &data) { - std::unique_lock lock(*static_cast(m_lock.get())); + std::unique_lock lock(*static_cast(m_lock.get())); m_newdata.insert(m_newdata.end(), data.Data(), data.Data() + data.Length()); m_offsets.push_back(m_offsets.back() + data.Length()); } - -ErrorCode -FileMetadataSet::SaveMetadata(std::shared_ptr p_metaOut, std::shared_ptr p_metaIndexOut) +ErrorCode FileMetadataSet::SaveMetadata(std::shared_ptr p_metaOut, + std::shared_ptr p_metaIndexOut) { - std::shared_lock lock(*static_cast(m_lock.get())); + std::shared_lock lock(*static_cast(m_lock.get())); SizeType count = Count(); - IOBINARY(p_metaIndexOut, WriteBinary, sizeof(SizeType), (const char*)&count); - IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t) * m_offsets.size(), (const char*)m_offsets.data()); + IOBINARY(p_metaIndexOut, WriteBinary, sizeof(SizeType), (const char *)&count); + IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t) * 
m_offsets.size(), (const char *)m_offsets.data()); std::uint64_t bufsize = 1000000; - char* buf = new char[bufsize]; + char *buf = new char[bufsize]; auto readsize = m_fp->ReadBinary(bufsize, buf, 0); - while (readsize > 0) { + while (readsize > 0) + { IOBINARY(p_metaOut, WriteBinary, readsize, buf); readsize = m_fp->ReadBinary(bufsize, buf); } delete[] buf; - - if (m_newdata.size() > 0) { - IOBINARY(p_metaOut, WriteBinary, m_newdata.size(), (const char*)m_newdata.data()); + + if (m_newdata.size() > 0) + { + IOBINARY(p_metaOut, WriteBinary, m_newdata.size(), (const char *)m_newdata.data()); } - LOG(Helper::LogLevel::LL_Info, "Save MetaIndex(%llu) Meta(%llu)\n", m_offsets.size() - 1, m_offsets.back()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save MetaIndex(%llu) Meta(%llu)\n", m_offsets.size() - 1, + m_offsets.back()); return ErrorCode::Success; } - -ErrorCode -FileMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) +ErrorCode FileMetadataSet::SaveMetadata(const std::string &p_metaFile, const std::string &p_metaindexFile) { { std::shared_ptr metaOut = f_createIO(), metaIndexOut = f_createIO(); - if (metaOut == nullptr || metaIndexOut == nullptr || !metaOut->Initialize((p_metaFile + "_tmp").c_str(), std::ios::binary | std::ios::out) || !metaIndexOut->Initialize((p_metaindexFile + "_tmp").c_str(), std::ios::binary | std::ios::out)) + if (metaOut == nullptr || metaIndexOut == nullptr || + !metaOut->Initialize((p_metaFile + "_tmp").c_str(), std::ios::binary | std::ios::out) || + !metaIndexOut->Initialize((p_metaindexFile + "_tmp").c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; ErrorCode ret = SaveMetadata(metaOut, metaIndexOut); - if (ret != ErrorCode::Success) return ret; + if (ret != ErrorCode::Success) + return ret; } { - std::unique_lock lock(*static_cast(m_lock.get())); + std::unique_lock lock(*static_cast(m_lock.get())); m_fp->ShutDown(); - if (fileexists(p_metaFile.c_str())) 
std::remove(p_metaFile.c_str()); - if (fileexists(p_metaindexFile.c_str())) std::remove(p_metaindexFile.c_str()); + if (fileexists(p_metaFile.c_str())) + std::remove(p_metaFile.c_str()); + if (fileexists(p_metaindexFile.c_str())) + std::remove(p_metaindexFile.c_str()); std::rename((p_metaFile + "_tmp").c_str(), p_metaFile.c_str()); std::rename((p_metaindexFile + "_tmp").c_str(), p_metaindexFile.c_str()); - if (!m_fp->Initialize(p_metaFile.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; + if (!m_fp->Initialize(p_metaFile.c_str(), std::ios::binary | std::ios::in)) + return ErrorCode::FailedOpenFile; m_count = static_cast(m_offsets.size() - 1); m_newdata.clear(); } return ErrorCode::Success; } - -MemMetadataSet::MemMetadataSet(std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize): m_count(0), m_metadataHolder(ByteArray::c_empty) +MemMetadataSet::MemMetadataSet(std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) + : m_count(0), m_metadataHolder(ByteArray::c_empty) { m_pOffsets.reset(new MetadataOffsets, std::default_delete()); - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); m_offsets.reserve(p_blockSize, p_capacity); m_offsets.push_back(0); m_newdata.reserve(p_blockSize * p_metaSize); m_lock.reset(new std::shared_timed_mutex, std::default_delete()); } - -ErrorCode -MemMetadataSet::Init(std::shared_ptr p_metain, std::shared_ptr p_metaindexin, - std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) +ErrorCode MemMetadataSet::Init(std::shared_ptr p_metain, std::shared_ptr p_metaindexin, + std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) { - IOBINARY(p_metaindexin, ReadBinary, sizeof(m_count), (char*)&m_count); + IOBINARY(p_metaindexin, ReadBinary, sizeof(m_count), (char *)&m_count); m_pOffsets.reset(new MetadataOffsets, std::default_delete()); - auto& m_offsets = 
*static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); m_offsets.reserve(p_blockSize, p_capacity); { std::vector tmp(m_count + 1, 0); - IOBINARY(p_metaindexin, ReadBinary, sizeof(std::uint64_t) * (m_count + 1), (char*)tmp.data()); + IOBINARY(p_metaindexin, ReadBinary, sizeof(std::uint64_t) * (m_count + 1), (char *)tmp.data()); m_offsets.assign(tmp.data(), tmp.data() + tmp.size()); } m_metadataHolder = ByteArray::Alloc(m_offsets[m_count]); - IOBINARY(p_metain, ReadBinary, m_metadataHolder.Length(), (char*)m_metadataHolder.Data()); + IOBINARY(p_metain, ReadBinary, m_metadataHolder.Length(), (char *)m_metadataHolder.Data()); m_newdata.reserve(p_blockSize * p_metaSize); m_lock.reset(new std::shared_timed_mutex, std::default_delete()); - LOG(Helper::LogLevel::LL_Info, "Load MetaIndex(%d) Meta(%llu)\n", m_count, m_offsets[m_count]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load MetaIndex(%d) Meta(%llu)\n", m_count, m_offsets[m_count]); return ErrorCode::Success; } - MemMetadataSet::MemMetadataSet(std::shared_ptr p_metain, std::shared_ptr p_metaindexin, - std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) + std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) { - if (Init(p_metain, p_metaindexin, p_blockSize, p_capacity, p_metaSize) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read MemMetadataSet!\n"); + if (Init(p_metain, p_metaindexin, p_blockSize, p_capacity, p_metaSize) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read MemMetadataSet!\n"); throw std::runtime_error("Cannot read MemMetadataSet"); } } -MemMetadataSet::MemMetadataSet(const std::string& p_metafile, const std::string& p_metaindexfile, - std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) +MemMetadataSet::MemMetadataSet(const std::string &p_metafile, const std::string &p_metaindexfile, + std::uint64_t p_blockSize, std::uint64_t 
p_capacity, std::uint64_t p_metaSize) { std::shared_ptr ptrMeta = f_createIO(), ptrMetaIndex = f_createIO(); - if (ptrMeta == nullptr || ptrMetaIndex == nullptr || !ptrMeta->Initialize(p_metafile.c_str(), std::ios::binary | std::ios::in) || !ptrMetaIndex->Initialize(p_metaindexfile.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open meta files %s or %s!\n", p_metafile.c_str(), p_metaindexfile.c_str()); + if (ptrMeta == nullptr || ptrMetaIndex == nullptr || + !ptrMeta->Initialize(p_metafile.c_str(), std::ios::binary | std::ios::in) || + !ptrMetaIndex->Initialize(p_metaindexfile.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open meta files %s or %s!\n", p_metafile.c_str(), + p_metaindexfile.c_str()); throw std::runtime_error("Cannot open MemMetadataSet files"); } - if (Init(ptrMeta, ptrMetaIndex, p_blockSize, p_capacity, p_metaSize) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read MemMetadataSet!\n"); + if (Init(ptrMeta, ptrMetaIndex, p_blockSize, p_capacity, p_metaSize) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot read MemMetadataSet!\n"); throw std::runtime_error("Cannot read MemMetadataSet"); } } - MemMetadataSet::MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count) : m_count(p_count), m_metadataHolder(std::move(p_metadata)) { m_pOffsets.reset(new MetadataOffsets, std::default_delete()); - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); m_offsets.reserve(p_count + 1, p_count + 1); - const std::uint64_t* newdata = reinterpret_cast(p_offsets.Data()); + const std::uint64_t *newdata = reinterpret_cast(p_offsets.Data()); m_offsets.assign(newdata, newdata + p_count + 1); m_lock.reset(new std::shared_timed_mutex, std::default_delete()); } -MemMetadataSet::MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType 
p_count, - std::uint64_t p_blockSize, std::uint64_t p_capacity, std::uint64_t p_metaSize) +MemMetadataSet::MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count, std::uint64_t p_blockSize, + std::uint64_t p_capacity, std::uint64_t p_metaSize) : m_count(p_count), m_metadataHolder(std::move(p_metadata)) { m_pOffsets.reset(new MetadataOffsets, std::default_delete()); - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); m_offsets.reserve(p_blockSize, p_capacity); - const std::uint64_t* newdata = reinterpret_cast(p_offsets.Data()); + const std::uint64_t *newdata = reinterpret_cast(p_offsets.Data()); m_offsets.assign(newdata, newdata + p_count + 1); m_newdata.reserve(p_blockSize * p_metaSize); m_lock.reset(new std::shared_timed_mutex, std::default_delete()); } - MemMetadataSet::~MemMetadataSet() { } - -ByteArray -MemMetadataSet::GetMetadata(SizeType p_vectorID) const +ByteArray MemMetadataSet::GetMetadata(SizeType p_vectorID) const { - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); std::uint64_t startoff = m_offsets[p_vectorID]; std::uint64_t bytes = m_offsets[p_vectorID + 1] - startoff; - if (p_vectorID < m_count) { + if (p_vectorID < m_count) + { return ByteArray(m_metadataHolder.Data() + startoff, bytes, false); - } else { - std::shared_lock lock(*static_cast(m_lock.get())); - return ByteArray((std::uint8_t*)m_newdata.data() + startoff - m_offsets[m_count], bytes, false); + } + else + { + std::shared_lock lock(*static_cast(m_lock.get())); + return ByteArray((std::uint8_t *)m_newdata.data() + startoff - m_offsets[m_count], bytes, false); } } - -ByteArray -MemMetadataSet::GetMetadataCopy(SizeType p_vectorID) const +ByteArray MemMetadataSet::GetMetadataCopy(SizeType p_vectorID) const { - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); std::uint64_t startoff = m_offsets[p_vectorID]; 
std::uint64_t bytes = m_offsets[p_vectorID + 1] - startoff; - if (p_vectorID < m_count) { + if (p_vectorID < m_count) + { ByteArray b = ByteArray::Alloc(bytes); memcpy(b.Data(), m_metadataHolder.Data() + startoff, bytes); return b; - } else { + } + else + { ByteArray b = ByteArray::Alloc(bytes); - std::shared_lock lock(*static_cast(m_lock.get())); + std::shared_lock lock(*static_cast(m_lock.get())); memcpy(b.Data(), m_newdata.data() + (startoff - m_offsets[m_count]), bytes); return b; } } - -SizeType -MemMetadataSet::Count() const +SizeType MemMetadataSet::Count() const { - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); return static_cast(m_offsets.size() - 1); } - -bool -MemMetadataSet::Available() const +bool MemMetadataSet::Available() const { - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); return m_offsets.size() > 1; } - -std::pair -MemMetadataSet::BufferSize() const +std::pair MemMetadataSet::BufferSize() const { - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); auto n = m_offsets.size(); - return std::make_pair(m_offsets[n-1], - sizeof(SizeType) + sizeof(std::uint64_t) * n); + return std::make_pair(m_offsets[n - 1], sizeof(SizeType) + sizeof(std::uint64_t) * n); } - -void -MemMetadataSet::Add(const ByteArray& data) +void MemMetadataSet::Add(const ByteArray &data) { - auto& m_offsets = *static_cast(m_pOffsets.get()); - std::unique_lock lock(*static_cast(m_lock.get())); + auto &m_offsets = *static_cast(m_pOffsets.get()); + std::unique_lock lock(*static_cast(m_lock.get())); m_newdata.insert(m_newdata.end(), data.Data(), data.Data() + data.Length()); - if (!m_offsets.push_back(m_offsets.back() + data.Length())) { - LOG(Helper::LogLevel::LL_Error, "Insert MetaIndex error! 
DataCapacity overflow!\n"); + if (!m_offsets.push_back(m_offsets.back() + data.Length())) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Insert MetaIndex error! DataCapacity overflow!\n"); m_newdata.resize(m_newdata.size() - data.Length()); } } - -ErrorCode -MemMetadataSet::SaveMetadata(std::shared_ptr p_metaOut, std::shared_ptr p_metaIndexOut) +ErrorCode MemMetadataSet::SaveMetadata(std::shared_ptr p_metaOut, + std::shared_ptr p_metaIndexOut) { - auto& m_offsets = *static_cast(m_pOffsets.get()); + auto &m_offsets = *static_cast(m_pOffsets.get()); SizeType count = Count(); - IOBINARY(p_metaIndexOut, WriteBinary, sizeof(SizeType), (const char*)&count); - for (SizeType i = 0; i <= count; i++) { - IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t), (const char*)(&m_offsets[i])); + IOBINARY(p_metaIndexOut, WriteBinary, sizeof(SizeType), (const char *)&count); + for (SizeType i = 0; i <= count; i++) + { + IOBINARY(p_metaIndexOut, WriteBinary, sizeof(std::uint64_t), (const char *)(&m_offsets[i])); } - IOBINARY(p_metaOut, WriteBinary, m_metadataHolder.Length(), reinterpret_cast(m_metadataHolder.Data())); - if (m_newdata.size() > 0) { - std::shared_lock lock(*static_cast(m_lock.get())); - IOBINARY(p_metaOut, WriteBinary, m_offsets[count] - m_offsets[m_count], (const char*)m_newdata.data()); + IOBINARY(p_metaOut, WriteBinary, m_offsets[m_count], + reinterpret_cast(m_metadataHolder.Data())); + if (m_newdata.size() > 0) + { + std::shared_lock lock(*static_cast(m_lock.get())); + IOBINARY(p_metaOut, WriteBinary, m_offsets[count] - m_offsets[m_count], (const char *)m_newdata.data()); } - LOG(Helper::LogLevel::LL_Info, "Save MetaIndex(%llu) Meta(%llu)\n", m_offsets.size() - 1, m_offsets.back()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Save MetaIndex(%llu) Meta(%llu)\n", m_offsets.size() - 1, + m_offsets.back()); return ErrorCode::Success; } - - -ErrorCode -MemMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) +ErrorCode 
MemMetadataSet::SaveMetadata(const std::string &p_metaFile, const std::string &p_metaindexFile) { { std::shared_ptr metaOut = f_createIO(), metaIndexOut = f_createIO(); - if (metaOut == nullptr || metaIndexOut == nullptr || !metaOut->Initialize((p_metaFile + "_tmp").c_str(), std::ios::binary | std::ios::out) || !metaIndexOut->Initialize((p_metaindexFile + "_tmp").c_str(), std::ios::binary | std::ios::out)) + if (metaOut == nullptr || metaIndexOut == nullptr || + !metaOut->Initialize((p_metaFile + "_tmp").c_str(), std::ios::binary | std::ios::out) || + !metaIndexOut->Initialize((p_metaindexFile + "_tmp").c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; ErrorCode ret = SaveMetadata(metaOut, metaIndexOut); - if (ret != ErrorCode::Success) return ret; + if (ret != ErrorCode::Success) + return ret; } - if (fileexists(p_metaFile.c_str())) std::remove(p_metaFile.c_str()); - if (fileexists(p_metaindexFile.c_str())) std::remove(p_metaindexFile.c_str()); + if (fileexists(p_metaFile.c_str())) + std::remove(p_metaFile.c_str()); + if (fileexists(p_metaindexFile.c_str())) + std::remove(p_metaindexFile.c_str()); std::rename((p_metaFile + "_tmp").c_str(), p_metaFile.c_str()); std::rename((p_metaindexFile + "_tmp").c_str(), p_metaindexFile.c_str()); return ErrorCode::Success; } - diff --git a/AnnService/src/Core/MultiIndexScan.cpp b/AnnService/src/Core/MultiIndexScan.cpp new file mode 100644 index 000000000..d237ff7e6 --- /dev/null +++ b/AnnService/src/Core/MultiIndexScan.cpp @@ -0,0 +1,205 @@ +#include "inc/Core/MultiIndexScan.h" + +namespace SPTAG +{ +MultiIndexScan::MultiIndexScan() +{ +} +MultiIndexScan::MultiIndexScan(std::vector> vectorIndices, std::vector p_targets, + unsigned int k, float (*rankFunction)(std::vector), bool useTimer, + int termCondVal, int searchLimit) +{ + // copy parameters + this->fwdLUTs = std::vector>(vectorIndices); + this->k = k; + this->func = rankFunction; + // terminate related initialization + this->useTimer = 
useTimer; + this->termCondVal = termCondVal; + this->searchLimit = searchLimit; + this->t_start = std::chrono::high_resolution_clock::now(); + this->consecutive_drops = 0; + // internal states + // topK status + this->pq = std::priority_queue, pq_item_compare>(); + this->seenSet = std::unordered_set(); + // output dataStrucure + this->outputStk = std::stack(); + this->terminate = false; + + for (int i = 0; i < fwdLUTs.size(); i++) + { + indexIters.push_back(vectorIndices[i]->GetIterator(p_targets[i], true)); + } + printf("(%zu, %zu, %zu)\n", fwdLUTs.size(), p_targets.size(), indexIters.size()); +} +float rankFunc(std::vector in) +{ + return (float)std::accumulate(in.begin(), in.end(), 0.0f); +} + +float MultiIndexScan::WeightedRankFunc(std::vector in) +{ + float result = 0; + for (int i = 0; i < in.size(); i++) + { + result += in[i] * weight[i]; + } + return result; +} + +void MultiIndexScan::Init(std::vector> vectorIndices, std::vector p_targets, + std::vector weight, unsigned int k, bool useTimer, int termCondVal, int searchLimit) +{ + this->fwdLUTs = vectorIndices; + this->k = k; + this->func = nullptr; + this->weight = weight; + // terminate related initialization + this->useTimer = useTimer; + this->termCondVal = termCondVal; + this->searchLimit = searchLimit; + this->t_start = std::chrono::high_resolution_clock::now(); + this->consecutive_drops = 0; + // internal states + // topK status + this->pq = std::priority_queue, pq_item_compare>(); + this->seenSet = std::unordered_set(); + // output dataStrucure + this->outputStk = std::stack(); + this->terminate = false; + + for (int i = 0; i < fwdLUTs.size(); i++) + { + p_data_array.push_back(p_targets[i]); + indexIters.push_back(vectorIndices[i]->GetIterator(p_data_array[i].Data(), true)); + } + printf("(%zu, %zu, %zu)\n", fwdLUTs.size(), p_targets.size(), indexIters.size()); +} + +MultiIndexScan::~MultiIndexScan() +{ + Close(); +} + +bool MultiIndexScan::Next(BasicResult &result) +{ + int numCols = 
(int)indexIters.size(); + while (!terminate) + { + + for (int i = 0; i < numCols; i++) + { + auto result_iter = indexIters[i]; + // printf("probing index %d, %s\n", i, fwdLUTs[i]->GetIndexName().c_str()); + auto results = result_iter->Next(1); + if (results->GetResultNum() == 0) + { + printf("index %d no more result!! Terminating\n", i); + terminate = true; + break; + } + + auto vid = results->GetResult(0)->VID; + auto dist = results->GetResult(0)->Dist; + // we ignore meta for now:: auto meta = curr_result.Meta; + + // printf("vid = %lu, dist = %f from index %d\n", vid, dist, i); + + // insert into heap only if it has NOT been seen + if (seenSet.find(vid) == seenSet.end()) + { + + std::vector dists(numCols, 0); + dists[i] = dist; + + // lookup scores using foward index + for (int j = 0; j < numCols; j++) + { + if (i != j) + { + dists[j] = fwdLUTs[j]->GetDistance(indexIters[j]->GetTarget(), vid); + } + } + + // using UDF to calculate the score + float score; + if (func == nullptr) + { + score = WeightedRankFunc(dists); + } + else + { + score = func(dists); + } + + // printf("vid = %d, not seen! score = %f\n", vid, score); + + // insert into heap only if it is smaller than the largest score + // in our ANN case, we keep the k smallest score/dists + if (pq.size() == k && pq.top().first <= score) + { + // printf("%f >= largest score in heap %f, drop", score, pq.top().first); + consecutive_drops++; + } + else + { + consecutive_drops = 0; + if (pq.size() == k) + pq.pop(); // the largest score is squeezed out by the current score + pq.push(std::make_pair(score, vid)); + } + + // insert it into our seen set; + seenSet.insert(vid); + } + } + + // checkout terminate condition and check if we need to start stream out output + // why we check term cond outside for loop? 
because we don't want to check that often~ + if (useTimer) + { + auto t_end = std::chrono::high_resolution_clock::now(); + terminate = (std::chrono::duration(t_end - t_start).count() >= termCondVal); + } + else if (!useTimer && consecutive_drops >= termCondVal) + { + terminate = true; + } + else if (seenSet.size() >= searchLimit) + { + terminate = true; + } + } + // put back the output into the correct order, i.e. smallest score first + if (pq.size() > 0) + { + while (!pq.empty()) + { + outputStk.push(pq.top()); + pq.pop(); + } + } + + // if no outputs + if (outputStk.size() == 0) + return false; + + // pop-out outputs + result.VID = outputStk.top().second; + result.Dist = outputStk.top().first; + outputStk.pop(); + + return true; +} + +// Add end into destructor. +void MultiIndexScan::Close() +{ + for (auto resultIter : indexIters) + { + resultIter->Close(); + } +} + +} // namespace SPTAG diff --git a/AnnService/src/Core/ResultIterator.cpp b/AnnService/src/Core/ResultIterator.cpp new file mode 100644 index 000000000..9c33b021b --- /dev/null +++ b/AnnService/src/Core/ResultIterator.cpp @@ -0,0 +1,87 @@ +#include "inc/Core/ResultIterator.h" + +struct UniqueHandler +{ + std::unique_ptr m_handler; +}; + +ResultIterator::ResultIterator(const void *p_index, const void *p_target, bool p_searchDeleted, int p_workspaceBatch, std::function p_filterFunc, int p_maxCheck) + : m_index((const VectorIndex *)p_index), m_target(p_target), m_searchDeleted(p_searchDeleted) +{ + m_workspace = new UniqueHandler; + ((UniqueHandler *)m_workspace)->m_handler = std::move(m_index->RentWorkSpace(p_workspaceBatch, p_filterFunc, p_maxCheck)); + m_isFirstResult = true; +} + +ResultIterator::~ResultIterator() +{ + Close(); +} + +void *ResultIterator::GetWorkSpace() +{ + if (m_workspace == nullptr) + return nullptr; + return (((UniqueHandler *)m_workspace)->m_handler).get(); +} + +std::shared_ptr ResultIterator::Next(int batch) +{ + if (m_queryResult == nullptr) + { + m_queryResult = 
std::make_unique(m_target, batch, true); + } + else if (batch <= m_queryResult->GetResultNum()) + { + m_queryResult->SetResultNum(batch); + } + else + { + batch = m_queryResult->GetResultNum(); + } + + m_queryResult->Reset(); + if (m_workspace == nullptr) + return m_queryResult; + + int resultCount = 0; + m_index->SearchIndexIterativeNext(*m_queryResult, (((UniqueHandler *)m_workspace)->m_handler).get(), batch, + resultCount, m_isFirstResult, m_searchDeleted); + m_isFirstResult = false; + for (int i = 0; i < resultCount; i++) + { + m_queryResult->GetResult(i)->RelaxedMono = (((UniqueHandler *)m_workspace)->m_handler)->m_relaxedMono; + } + m_queryResult->SetResultNum(resultCount); + return m_queryResult; +} + +bool ResultIterator::GetRelaxedMono() +{ + if (m_workspace == nullptr) + return false; + + return (((UniqueHandler *)m_workspace)->m_handler)->m_relaxedMono; +} + +SPTAG::ErrorCode ResultIterator::GetErrorCode() +{ + return SPTAG::ErrorCode::Success; +} + + // Add end into destructor. 
+void ResultIterator::Close() +{ + if (m_workspace != nullptr) + { + m_index->SearchIndexIterativeEnd(std::move(((UniqueHandler *)m_workspace)->m_handler)); + delete (UniqueHandler *)m_workspace; + m_workspace = nullptr; + } +} + +const void *ResultIterator::GetTarget() +{ + return m_target; +} +// namespace SPTAG diff --git a/AnnService/src/Core/SPANN/ExtraFileController.cpp b/AnnService/src/Core/SPANN/ExtraFileController.cpp new file mode 100644 index 000000000..ea69c0c38 --- /dev/null +++ b/AnnService/src/Core/SPANN/ExtraFileController.cpp @@ -0,0 +1,474 @@ +#include "inc/Core/SPANN/ExtraFileController.h" + +namespace SPTAG::SPANN +{ +extern std::function(void)> f_createAsyncIO; + +bool FileIO::BlockController::Initialize(SPANN::Options &p_opt) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Initialize(%s, %d)\n", + p_opt.m_ssdMappingFile.c_str(), p_opt.m_spdkBatchSize); + + m_growthThreshold = p_opt.m_growThreshold; + m_growthBlocks = ((std::uint64_t)p_opt.m_growthFileSize) << (30 - PageSizeEx); + m_maxBlocks = ((std::uint64_t)p_opt.m_maxFileSize) << (30 - PageSizeEx); + m_batchSize = p_opt.m_spdkBatchSize; + strcpy(m_filePath, (p_opt.m_indexDirectory + FolderSep + p_opt.m_ssdMappingFile + "_postings").c_str()); + m_disableCheckpoint = p_opt.m_disableCheckpoint; + m_startTime = std::chrono::high_resolution_clock::now(); + + int numblocks = m_batchSize; + m_fileHandle = f_createAsyncIO(); + if (m_fileHandle == nullptr || + !m_fileHandle->Initialize( + m_filePath, +#ifndef _MSC_VER + O_RDWR | O_DIRECT, numblocks, 2, 2, + max(p_opt.m_ioThreads, (2 * max(p_opt.m_searchThreadNum, p_opt.m_iSSDNumberOfThreads) + + p_opt.m_insertThreadNum + p_opt.m_reassignThreadNum + p_opt.m_appendThreadNum)), + ((std::uint64_t)p_opt.m_startFileSize) << 30 +#else + GENERIC_READ | GENERIC_WRITE, numblocks, 2, 2, + (std::uint16_t)(p_opt.m_ioThreads), + ((std::uint64_t)p_opt.m_startFileSize) << 30 +#endif + )) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, 
"FileIO::BlockController::Initialize failed\n"); + return false; + } + std::string blockpoolPath = (p_opt.m_recovery) + ? p_opt.m_persistentBufferPath + FolderSep + p_opt.m_ssdMappingFile + "_postings" + : m_filePath; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController:Loading block pool from file:%s\n", + blockpoolPath.c_str()); + ErrorCode ret = LoadBlockPool(blockpoolPath, ((std::uint64_t)p_opt.m_startFileSize) << (30 - PageSizeEx), + !p_opt.m_recovery, p_opt.m_datasetRowsInBlock, p_opt.m_datasetCapacity); + if (ErrorCode::Success != ret) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController:Loading block pool failed!\n"); + return false; + } + return true; +} + +bool FileIO::BlockController::NeedsExpansion(int p_size) +{ + // TODO: Replace RemainBlocks() which internally uses tbb::concurrent_queue::unsafe_size(). + // unsafe_size() is *not thread-safe* and may yield inconsistent results under concurrent access. + size_t available = RemainBlocks(); + size_t total = m_totalAllocatedBlocks.load(); + float ratio = static_cast(available) / static_cast(total); + + return (available < p_size) || (ratio < m_growthThreshold); +} + +bool FileIO::BlockController::ExpandFile(AddressType blocksToAdd) +{ + AddressType currentTotal = m_totalAllocatedBlocks.load(); + + // Cap the growth so we do not exceed the limit + AddressType allowedBlocks = min(blocksToAdd, m_maxBlocks - currentTotal); + if (allowedBlocks <= 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "[ExpandFile] File is already at max capacity (%llu blocks = %.2f GB). 
Expansion aborted.\n", + static_cast(m_maxBlocks), + static_cast(m_maxBlocks >> (30 - PageSizeEx))); + } + + if (!m_fileHandle->ExpandFile(allowedBlocks * PageSize)) + return false; + + m_available.AddBatch(allowedBlocks); + for (AddressType i = 0; i < allowedBlocks; ++i) + { + m_blockAddresses.push(currentTotal + i); + m_available.Insert(currentTotal + i); + } + + AddressType newTotal = currentTotal + allowedBlocks; + m_totalAllocatedBlocks.store(newTotal); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "[ExpandFile] Expanded file from %llu to %llu blocks (added %llu blocks = %.2f GB)\n", + static_cast(currentTotal), static_cast(newTotal), + static_cast(allowedBlocks), + static_cast(allowedBlocks >> (30 - PageSizeEx))); + + if (allowedBlocks < blocksToAdd) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "[ExpandFile] Requested %llu blocks, but only %llu were allowed due to cap (%llu max)\n", + static_cast(blocksToAdd), static_cast(allowedBlocks), + static_cast(m_maxBlocks)); + } + + return true; +} + +bool FileIO::BlockController::GetBlocks(AddressType *p_data, int p_size) +{ + // Trigger expansion if we're below threshold + if (NeedsExpansion(p_size)) + { + std::unique_lock lock( + m_expandLock); // ensure only one thread tries expandingAdd commentMore actions + if (NeedsExpansion(p_size)) + { // recheck inside lock + AddressType growBy = max(static_cast(m_growthBlocks), (AddressType)p_size); + AddressType total = m_totalAllocatedBlocks.load(); + if (total + growBy <= m_maxBlocks) + { + if (!ExpandFile(growBy)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "FileIO::BlockController::GetBlocks: expansion failed\n"); + return false; + } + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "FileIO::BlockController::GetBlocks: cannot expand beyond cap (%llu)\n", + static_cast(m_maxBlocks)); + } + } + } + + for (int i = 0; i < p_size; i++) + { + int retry = 0; + AddressType currBlockAddress = 0xffffffffffffffff; + while (retry < 3 && 
!m_blockAddresses.try_pop(currBlockAddress)) retry++; + if (retry < 3) + { + p_data[i] = currBlockAddress; + if (!m_available.Reset(currBlockAddress)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "GetBlocks: Allocate a used block id:%lld\n", + currBlockAddress); + } + } + else + return false; + } + return true; +} + +bool FileIO::BlockController::ReleaseBlocks(AddressType *p_data, int p_size) +{ + for (int i = 0; i < p_size; i++) + { + if (!m_available.Insert(p_data[i])) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "ReleaseBlocks: Double release a free block id:%lld\n", + p_data[i]); + continue; + } + if (m_disableCheckpoint) + { + m_blockAddresses.push(p_data[i]); + } + else + { + m_blockAddresses_reserve.push(p_data[i]); + } + } + return true; +} + +bool FileIO::BlockController::ReadBlocks(AddressType *p_data, std::string *p_value, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + if ((uintptr_t)p_data == 0xffffffffffffffff) + { + p_value->resize(0); + return true; + } + + const int64_t postingSize = (int64_t)(p_data[0]); + auto blockNum = (postingSize + PageSize - 1) >> PageSizeEx; + if (blockNum > reqs->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "FileIO::BlockController::ReadBlocks: %d > %u\n", (int)blockNum, + reqs->size()); + return false; + } + + AddressType currOffset = 0; + AddressType dataIdx = 1; + for (int i = 0; i < blockNum; i++) + { + Helper::AsyncReadRequest &curr = reqs->at(i); + curr.m_readSize = (postingSize - currOffset) < PageSize ? 
(postingSize - currOffset) : PageSize; + curr.m_offset = p_data[dataIdx] * PageSize; + currOffset += PageSize; + dataIdx++; + } + + std::uint32_t totalReads = m_fileHandle->BatchReadFile(reqs->data(), blockNum, timeout, m_batchSize); + read_submit_vec += blockNum; + read_complete_vec += totalReads; + + if (totalReads < blockNum) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "FileIO::BlockController::ReadBlocks: %u < %u\n", totalReads, + blockNum); + m_batchReadTimeouts++; + return false; + } + + p_value->resize(postingSize); + for (int i = 0; i < blockNum; i++) + { + Helper::AsyncReadRequest &curr = reqs->at(i); + memcpy(p_value->data() + i * PageSize, curr.m_buffer, curr.m_readSize); + } + return true; +} + +bool FileIO::BlockController::ReadBlocks(const std::vector &p_data, std::vector *p_values, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + m_batchReadTimes++; + + std::uint32_t reqcount = 0; + for (size_t i = 0; i < p_data.size(); i++) + { + AddressType *p_data_i = p_data[i]; + + if (p_data_i == nullptr || (uintptr_t)p_data_i == 0xffffffffffffffff) + { + if (p_data_i != nullptr) (*p_values)[i].resize(0); + continue; + } + + std::size_t postingSize = (std::size_t)p_data_i[0]; + AddressType currOffset = 0; + AddressType dataIdx = 1; + while (currOffset < postingSize) + { + if (reqcount >= reqs->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "FileIO::BlockController::ReadBlocks: %u >= %u\n", reqcount, + reqs->size()); + return false; + } + + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = (postingSize - currOffset) < PageSize ? 
(postingSize - currOffset) : PageSize; + curr.m_offset = p_data_i[dataIdx] * PageSize; + currOffset += PageSize; + dataIdx++; + reqcount++; + } + } + + std::uint32_t totalReads = m_fileHandle->BatchReadFile(reqs->data(), reqcount, timeout, m_batchSize); + read_submit_vec += reqcount; + read_complete_vec += totalReads; + + if (totalReads < reqcount) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "FileIO::BlockController::ReadBlocks: %u < %u\n", totalReads, + reqcount); + m_batchReadTimeouts++; + return false; + } + + std::uint32_t reqdone = 0; + for (size_t i = 0; i < p_data.size(); i++) + { + AddressType *p_data_i = p_data[i]; + + if (p_data_i == nullptr || (uintptr_t)p_data_i == 0xffffffffffffffff) + { + continue; + } + + std::size_t postingSize = (std::size_t)p_data_i[0]; + ((*p_values)[i]).resize(postingSize); + AddressType currOffset = 0; + while (currOffset < postingSize) + { + Helper::AsyncReadRequest &curr = reqs->at(reqdone); + memcpy(((*p_values)[i]).data() + currOffset, curr.m_buffer, curr.m_readSize); + currOffset += PageSize; + reqdone++; + } + } + return true; +} + +bool FileIO::BlockController::ReadBlocks(const std::vector &p_data, + std::vector> &p_values, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + m_batchReadTimes++; + std::uint32_t reqcount = 0; + std::uint32_t emptycount = 0; + for (size_t i = 0; i < p_data.size(); i++) + { + AddressType *p_data_i = p_data[i]; + int numPages = (p_values[i].GetPageSize() >> PageSizeEx); + + if (p_data_i == nullptr || (uintptr_t)p_data_i == 0xffffffffffffffff) + { + if (p_data_i != nullptr) p_values[i].SetAvailableSize(0); + for (std::uint32_t r = 0; r < numPages; r++) + { + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = 0; + reqcount++; + emptycount++; + } + continue; + } + + std::size_t postingSize = (std::size_t)p_data_i[0]; + p_values[i].SetAvailableSize(postingSize); + AddressType currOffset = 0; + AddressType dataIdx = 1; + while (currOffset < postingSize) 
+ { + if (dataIdx > numPages) + { + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, + "FileIO::BlockController::ReadBlocks: block (%d) pos:%llu/%llu >= buffer page size (%d)\n", + dataIdx - 1, currOffset, postingSize, numPages); + break; + } + + if (reqcount >= reqs->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "FileIO::BlockController::ReadBlocks: req (%u) >= req array size (%u)\n", reqcount, + reqs->size()); + return false; + } + + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = (postingSize - currOffset) < PageSize ? (postingSize - currOffset) : PageSize; + curr.m_offset = p_data_i[dataIdx] * PageSize; + currOffset += PageSize; + dataIdx++; + reqcount++; + } + + while (dataIdx - 1 < numPages) + { + if (reqcount >= reqs->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "FileIO::BlockController::ReadBlocks: req (%u) >= req array size (%u)\n", reqcount, + reqs->size()); + return false; + } + + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = 0; + dataIdx++; + reqcount++; + emptycount++; + } + } + + std::uint32_t totalReads = m_fileHandle->BatchReadFile(reqs->data(), reqcount, timeout, m_batchSize); + read_submit_vec += reqcount - emptycount; + read_complete_vec += totalReads; + if (totalReads < reqcount - emptycount) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "FileIO::BlockController::ReadBlocks: %u < %u\n", totalReads, + reqcount - emptycount); + m_batchReadTimeouts++; + return false; + } + return true; +} + +bool FileIO::BlockController::WriteBlocks(AddressType *p_data, int p_size, const std::string &p_value, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "WriteBlocks to string with blocknum=%d size=%u!\n", p_size, + // p_value.size()); + AddressType currOffset = 0; + int totalSize = p_value.size(); + for (int i = 0; i < p_size; i++) + { + Helper::AsyncReadRequest &curr = reqs->at(i); + curr.m_readSize = (totalSize - currOffset) 
< PageSize ? (totalSize - currOffset) : PageSize; + curr.m_offset = p_data[i] * PageSize; + memcpy(curr.m_buffer, p_value.data() + currOffset, curr.m_readSize); + currOffset += PageSize; + } + + std::uint32_t totalWrites = m_fileHandle->BatchWriteFile(reqs->data(), p_size, timeout, m_batchSize); + write_submit_vec += p_size; + write_complete_vec += totalWrites; + if (totalWrites < p_size) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "FileIO::BlockController::WriteBlocks: %u < %d\n", totalWrites, + p_size); + return false; + } + return true; +} + +bool FileIO::BlockController::IOStatistics() +{ + int64_t currReadCount = read_complete_vec; + int64_t read_submit_count = read_submit_vec; + int64_t currWriteCount = write_complete_vec; + int64_t write_submit_count = write_submit_vec; + int64_t read_blocks_time = read_blocks_time_vec; + int64_t read_bytes_count = read_bytes_vec; + int64_t write_bytes_count = write_bytes_vec; + + int currIOCount = currReadCount + currWriteCount; + int diffIOCount = currIOCount - m_preIOCompleteCount; + m_preIOCompleteCount = currIOCount; + + int64_t currBytesCount = read_bytes_count + write_bytes_count; + // int64_t diffBytesCount = currBytesCount - m_preIOBytes; + m_preIOBytes = currBytesCount; + + auto currTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(currTime - m_preTime); + m_preTime = currTime; + + double currIOPS = (double)diffIOCount * 1000 / duration.count(); + double currBandWidth = (double)diffIOCount * PageSize / 1024 * 1000 / 1024 * 1000 / duration.count(); + // double currBandWidth = (double)diffBytesCount / 1024 * 1000 / 1024 * 1000 / duration.count(); + + std::cout << "Diff IO Count: " << diffIOCount << " Time: " << duration.count() << "us" << std::endl; + std::cout << "IOPS: " << currIOPS << "k Bandwidth: " << currBandWidth << "MB/s" << std::endl; + std::cout << "Read Count: " << currReadCount << " Write Count: " << currWriteCount + << " Read Submit Count: " << 
read_submit_count << " Write Submit Count: " << write_submit_count + << std::endl; + std::cout << "Read Bytes Count: " << read_bytes_count << " Write Bytes Count: " << write_bytes_count << std::endl; + std::cout << "Read Blocks Time: " << read_blocks_time << "ns" << std::endl; + std::cout << "Batch Read Times: " << m_batchReadTimes.load() + << " Batch Read Timeouts: " << m_batchReadTimeouts.load() << std::endl; + return true; +} + +bool FileIO::BlockController::ShutDown() +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO:BlockController:ShutDown!\n"); + Checkpoint(m_filePath); + while (!m_blockAddresses.empty()) + { + AddressType currBlockAddress; + m_blockAddresses.try_pop(currBlockAddress); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "FileIO::BlockController::Close file handler\n"); + m_fileHandle->ShutDown(); + return true; +} + +} // namespace SPTAG::SPANN diff --git a/AnnService/src/Core/SPANN/ExtraSPDKController.cpp b/AnnService/src/Core/SPANN/ExtraSPDKController.cpp new file mode 100644 index 000000000..6427701db --- /dev/null +++ b/AnnService/src/Core/SPANN/ExtraSPDKController.cpp @@ -0,0 +1,810 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#ifndef _MSC_VER +#include "inc/Core/SPANN/ExtraSPDKController.h" +#include "inc/Helper/AsyncFileReader.h" + +namespace SPTAG::SPANN +{ +int SPDKIO::BlockController::m_ssdInflight = 0; +int SPDKIO::BlockController::m_ioCompleteCount = 0; +std::unique_ptr SPDKIO::BlockController::m_memBuffer; + +void SPDKIO::BlockController::SpdkBdevEventCallback(enum spdk_bdev_event_type type, struct spdk_bdev *bdev, + void *event_ctx) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SpdkBdevEventCallback: supported bdev event type %d\n", type); +} + +void SPDKIO::BlockController::SpdkBdevIoCallback(struct spdk_bdev_io *bdev_io, bool success, void *cb_arg) +{ + Helper::AsyncReadRequest *currSubIo = (Helper::AsyncReadRequest *)cb_arg; + if (success) + { + m_ioCompleteCount++; + spdk_bdev_free_io(bdev_io); + reinterpret_cast(currSubIo->m_extension)->push(currSubIo); + m_ssdInflight--; + SpdkIoLoop(currSubIo->m_ctrl); + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SpdkBdevIoCallback: I/O failed %p\n", currSubIo); + spdk_app_stop(-1); + } +} + +void SPDKIO::BlockController::SpdkStop(void *arg) +{ + SPDKIO::BlockController *ctrl = (SPDKIO::BlockController *)arg; + // Close I/O channel and bdev + spdk_put_io_channel(ctrl->m_ssdSpdkBdevIoChannel); + spdk_bdev_close(ctrl->m_ssdSpdkBdevDesc); + fprintf(stdout, "SPDKIO::BlockController::SpdkStop: finalized\n"); +} + +void SPDKIO::BlockController::SpdkIoLoop(void *arg) +{ + SPDKIO::BlockController *ctrl = (SPDKIO::BlockController *)arg; + int rc = 0; + Helper::AsyncReadRequest *currSubIo = nullptr; + while (!ctrl->m_ssdSpdkThreadExiting) + { + if (ctrl->m_submittedSubIoRequests.try_pop(currSubIo)) + { + if ((currSubIo->myiocb).aio_lio_opcode == IOCB_CMD_PREAD) + { + rc = spdk_bdev_read(ctrl->m_ssdSpdkBdevDesc, ctrl->m_ssdSpdkBdevIoChannel, currSubIo->m_buffer, + currSubIo->m_offset, PageSize, SpdkBdevIoCallback, currSubIo); + } + else + { + rc = spdk_bdev_write(ctrl->m_ssdSpdkBdevDesc, ctrl->m_ssdSpdkBdevIoChannel, 
currSubIo->m_buffer, + currSubIo->m_offset, PageSize, SpdkBdevIoCallback, currSubIo); + } + if (rc && rc != -ENOMEM) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "SPDKIO::BlockController::SpdkStart %s failed: %d, shutting down, offset: %ld\n", + ((currSubIo->myiocb).aio_lio_opcode == IOCB_CMD_PREAD) ? "spdk_bdev_read" + : "spdk_bdev_write", + rc, currSubIo->m_offset); + spdk_app_stop(-1); + break; + } + else + { + m_ssdInflight++; + } + } + else if (m_ssdInflight) + { + break; + } + } + if (ctrl->m_ssdSpdkThreadExiting) + { + SpdkStop(ctrl); + } +} + +void SPDKIO::BlockController::SpdkStart(void *arg) +{ + SPDKIO::BlockController *ctrl = (SPDKIO::BlockController *)arg; + + fprintf(stdout, "SPDKIO::BlockController::SpdkStart: using bdev %s\n", ctrl->m_ssdSpdkBdevName); + + int rc = 0; + ctrl->m_ssdSpdkBdev = NULL; + ctrl->m_ssdSpdkBdevDesc = NULL; + + // Open bdev + rc = spdk_bdev_open_ext(ctrl->m_ssdSpdkBdevName, true, SpdkBdevEventCallback, NULL, &ctrl->m_ssdSpdkBdevDesc); + if (rc) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::SpdkStart: spdk_bdev_open_ext failed, %d\n", + rc); + ctrl->m_ssdSpdkThreadStartFailed = true; + spdk_app_stop(-1); + return; + } + ctrl->m_ssdSpdkBdev = spdk_bdev_desc_get_bdev(ctrl->m_ssdSpdkBdevDesc); + + // Open I/O channel + ctrl->m_ssdSpdkBdevIoChannel = spdk_bdev_get_io_channel(ctrl->m_ssdSpdkBdevDesc); + if (ctrl->m_ssdSpdkBdevIoChannel == NULL) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "SPDKIO::BlockController::SpdkStart: spdk_bdev_get_io_channel failed\n"); + spdk_bdev_close(ctrl->m_ssdSpdkBdevDesc); + ctrl->m_ssdSpdkThreadStartFailed = true; + spdk_app_stop(-1); + return; + } + + ctrl->m_ssdSpdkThreadReady = true; + m_ssdInflight = 0; + + SpdkIoLoop(ctrl); +} + +void *SPDKIO::BlockController::InitializeSpdk(void *arg) +{ + SPDKIO::BlockController *ctrl = (SPDKIO::BlockController *)arg; + + struct spdk_app_opts opts; + spdk_app_opts_init(&opts, sizeof(opts)); + opts.name = "spfresh"; + const 
char *spdkConf = getenv(kSpdkConfEnv); + opts.json_config_file = spdkConf ? spdkConf : ""; + const char *spdkBdevName = getenv(kSpdkBdevNameEnv); + ctrl->m_ssdSpdkBdevName = spdkBdevName ? spdkBdevName : ""; + + int rc; + rc = spdk_app_start(&opts, &SPTAG::SPANN::SPDKIO::BlockController::SpdkStart, arg); + if (rc) + { + ctrl->m_ssdSpdkThreadStartFailed = true; + } + else + { + spdk_app_fini(); + } + pthread_exit(NULL); +} + +bool SPDKIO::BlockController::Initialize(SPANN::Options &p_opt) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::Initialize(%s, %d)\n", + p_opt.m_ssdMappingFile.c_str(), p_opt.m_spdkBatchSize); + + const char *useMemImplEnvStr = getenv(kUseMemImplEnv); + m_useMemImpl = useMemImplEnvStr && !strcmp(useMemImplEnvStr, "1"); + const char *useSsdImplEnvStr = getenv(kUseSsdImplEnv); + m_useSsdImpl = useSsdImplEnvStr && !strcmp(useSsdImplEnvStr, "1"); + if (m_useMemImpl) + { + std::uint64_t memoryBlocks = ((std::uint64_t)p_opt.m_startFileSize) << (30 - PageSizeEx); + if (m_memBuffer == nullptr) + { + m_memBuffer.reset(new char[memoryBlocks >> PageSizeEx]); + } + for (AddressType i = 0; i < memoryBlocks; i++) + { + m_blockAddresses.push(i); + } + return true; + } + else if (m_useSsdImpl) + { + m_growthThreshold = p_opt.m_growThreshold; + m_growthBlocks = ((std::uint64_t)p_opt.m_growthFileSize) << (30 - PageSizeEx); + m_maxBlocks = ((std::uint64_t)p_opt.m_maxFileSize) << (30 - PageSizeEx); + m_batchSize = p_opt.m_spdkBatchSize; + m_filePath = p_opt.m_indexDirectory + FolderSep + p_opt.m_ssdMappingFile + "_postings"; + pthread_create(&m_ssdSpdkTid, NULL, &InitializeSpdk, this); + while (!m_ssdSpdkThreadReady && !m_ssdSpdkThreadStartFailed) + ; + if (m_ssdSpdkThreadStartFailed) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::Initialize failed\n"); + return false; + } + std::string blockpoolPath = + (p_opt.m_recovery) ? 
p_opt.m_persistentBufferPath + FolderSep + p_opt.m_ssdMappingFile + "_postings" + : m_filePath; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController:Loading block pool from file:%s\n", + blockpoolPath.c_str()); + ErrorCode ret = LoadBlockPool(blockpoolPath, ((std::uint64_t)p_opt.m_startFileSize) << (30 - PageSizeEx), + !p_opt.m_recovery); + if (ErrorCode::Success != ret) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController:Loading block pool failed!\n"); + return false; + } + return true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::Initialize failed\n"); + return false; + } +} + +bool SPDKIO::BlockController::NeedsExpansion() +{ + // TODO: Replace RemainBlocks() which internally uses tbb::concurrent_queue::unsafe_size(). + // unsafe_size() is *not thread-safe* and may yield inconsistent results under concurrent access. + size_t available = RemainBlocks(); + size_t total = m_totalAllocatedBlocks.load(); + float ratio = static_cast(available) / static_cast(total); + + return (ratio < m_growthThreshold); +} + +bool SPDKIO::BlockController::ExpandFile(AddressType blocksToAdd) +{ + AddressType currentTotal = m_totalAllocatedBlocks.load(); + + // Cap the growth so we do not exceed the limit + AddressType allowedBlocks = min(blocksToAdd, m_maxBlocks - currentTotal); + if (allowedBlocks <= 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "[ExpandFile] File is already at max capacity (%llu blocks = %.2f GB). 
Expansion aborted.\n", + static_cast(m_maxBlocks), + static_cast(m_maxBlocks >> (30 - PageSizeEx))); + } + + for (AddressType i = 0; i < allowedBlocks; ++i) + { + m_blockAddresses.push(currentTotal + i); + } + + AddressType newTotal = currentTotal + allowedBlocks; + m_totalAllocatedBlocks.store(newTotal); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "[ExpandFile] Expanded file from %llu to %llu blocks (added %llu blocks = %.2f GB)\n", + static_cast(currentTotal), static_cast(newTotal), + static_cast(allowedBlocks), + static_cast(allowedBlocks >> (30 - PageSizeEx))); + + if (allowedBlocks < blocksToAdd) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "[ExpandFile] Requested %llu blocks, but only %llu were allowed due to cap (%llu max)\n", + static_cast(blocksToAdd), static_cast(allowedBlocks), + static_cast(m_maxBlocks)); + } + + return true; +} + +// get p_size blocks from front, and fill in p_data array +bool SPDKIO::BlockController::GetBlocks(AddressType *p_data, int p_size) +{ + + if (NeedsExpansion()) + { + std::unique_lock lock( + m_expandLock); // ensure only one thread tries expandingAdd commentMore actions + if (NeedsExpansion()) + { // recheck inside lock + AddressType growBy = static_cast(m_growthBlocks); + AddressType total = m_totalAllocatedBlocks.load(); + if (total + growBy <= m_maxBlocks) + { + if (!ExpandFile(growBy)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::GetBlocks: expansion failed\n"); + return false; + } + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, + "SPDKIO::BlockController::GetBlocks: cannot expand beyond cap (%llu)\n", + static_cast(m_maxBlocks)); + } + } + } + + AddressType currBlockAddress = 0; + if (m_useMemImpl || m_useSsdImpl) + { + for (int i = 0; i < p_size; i++) + { + while (!m_blockAddresses.try_pop(currBlockAddress)) + ; + p_data[i] = currBlockAddress; + } + return true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::GetBlocks failed\n"); + 
return false; + } +} + +// release p_size blocks, put them at the end of the queue +bool SPDKIO::BlockController::ReleaseBlocks(AddressType *p_data, int p_size) +{ + if (m_useMemImpl || m_useSsdImpl) + { + for (int i = 0; i < p_size; i++) + { + // m_blockAddresses.push(p_data[i]); + m_blockAddresses_reserve.push(p_data[i]); + } + return true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::ReleaseBlocks failed\n"); + return false; + } +} + +// read a posting list. p_data[0] is the total data size, +// p_data[1], p_data[2], ..., p_data[((p_data[0] + PageSize - 1) >> PageSizeEx)] are the addresses of the blocks +// concat all the block contents together into p_value string. +bool SPDKIO::BlockController::ReadBlocks(AddressType *p_data, std::string *p_value, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + if (m_useMemImpl) + { + p_value->resize(p_data[0]); + AddressType currOffset = 0; + AddressType dataIdx = 1; + while (currOffset < p_data[0]) + { + AddressType readSize = (p_data[0] - currOffset) < PageSize ? 
(p_data[0] - currOffset) : PageSize; + memcpy((char *)p_value->data() + currOffset, m_memBuffer.get() + p_data[dataIdx] * PageSize, readSize); + currOffset += PageSize; + dataIdx++; + } + return true; + } + else if (m_useSsdImpl) + { + + // Clear timeout I/Os + Helper::AsyncReadRequest *currSubIo; + Helper::RequestQueue *completeQueue = (Helper::RequestQueue *)(reqs->at(0).m_extension); + while (completeQueue->try_pop(currSubIo)) + ; + + size_t postingSize = (size_t)p_data[0]; + p_value->resize(postingSize); + + AddressType currOffset = 0; + AddressType dataIdx = 1; + int in_flight = 0; + auto t1 = std::chrono::high_resolution_clock::now(); + // Submit all I/Os + while (currOffset < postingSize || in_flight) + { + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) + { + return false; + } + // Try submit + if (currOffset < postingSize) + { + Helper::AsyncReadRequest &curr = reqs->at(dataIdx - 1); + curr.m_readSize = (postingSize - currOffset) < PageSize ? (postingSize - currOffset) : PageSize; + curr.m_offset = p_data[dataIdx] * PageSize; + curr.m_payload = (char *)p_value->data() + currOffset; + curr.m_ctrl = this; + curr.myiocb.aio_lio_opcode = IOCB_CMD_PREAD; + m_submittedSubIoRequests.push(&curr); + currOffset += PageSize; + dataIdx++; + in_flight++; + } + // Try complete + if (in_flight && completeQueue->try_pop(currSubIo)) + { + memcpy((char *)(currSubIo->m_payload), currSubIo->m_buffer, currSubIo->m_readSize); + in_flight--; + } + else + { + std::this_thread::sleep_for(std::chrono::microseconds(1)); // Sleep for a short time + } + } + + if (in_flight > 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Read timeout happens! Skip %d requests...\n", in_flight); + } + return true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::ReadBlocks single failed\n"); + return false; + } +} + +// parallel read a list of posting lists. 
+bool SPDKIO::BlockController::ReadBlocks(std::vector &p_data, std::vector *p_values, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + if (m_useMemImpl) + { + p_values->resize(p_data.size()); + for (size_t i = 0; i < p_data.size(); i++) + { + ReadBlocks(p_data[i], &((*p_values)[i]), timeout, reqs); + } + return true; + } + else if (m_useSsdImpl) + { + // Temporarily disable timeout + + // Convert request format to SubIoRequests + auto t1 = std::chrono::high_resolution_clock::now(); + p_values->resize(p_data.size()); + + std::uint32_t reqcount = 0; + for (size_t i = 0; i < p_data.size(); i++) + { + AddressType *p_data_i = p_data[i]; + std::string *p_value = &((*p_values)[i]); + + if (p_data_i == nullptr) + { + continue; + } + + std::size_t postingSize = (std::size_t)p_data_i[0]; + p_value->resize(postingSize); + AddressType currOffset = 0; + AddressType dataIdx = 1; + while (currOffset < postingSize) + { + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = (postingSize - currOffset) < PageSize ? (postingSize - currOffset) : PageSize; + curr.m_offset = p_data_i[dataIdx] * PageSize; + curr.m_payload = (char *)p_value->data() + currOffset; + curr.m_ctrl = this; + curr.myiocb.aio_lio_opcode = IOCB_CMD_PREAD; + currOffset += PageSize; + dataIdx++; + reqcount++; + if (reqcount >= reqs->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::ReadBlocks: %u >= %u\n", + reqcount, reqs->size()); + return false; + } + } + } + + // Clear timeout I/Os + Helper::AsyncReadRequest *currSubIo; + Helper::RequestQueue *completeQueue = (Helper::RequestQueue *)(reqs->at(0).m_extension); + while (completeQueue->try_pop(currSubIo)) + ; + + for (int currSubIoStartId = 0; currSubIoStartId < reqcount; currSubIoStartId += m_batchSize) + { + int currSubIoEndId = + (currSubIoStartId + m_batchSize) > reqcount ? 
reqcount : currSubIoStartId + m_batchSize; + int currSubIoIdx = currSubIoStartId; + + int in_flight = 0; + while (currSubIoIdx < currSubIoEndId || in_flight) + { + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) + { + break; + } + + // Try submit + if (currSubIoIdx < currSubIoEndId) + { + Helper::AsyncReadRequest &curr = reqs->at(currSubIoIdx); + m_submittedSubIoRequests.push(&curr); + in_flight++; + currSubIoIdx++; + } + + // Try complete + if (in_flight && completeQueue->try_pop(currSubIo)) + { + memcpy((char *)(currSubIo->m_payload), currSubIo->m_buffer, currSubIo->m_readSize); + in_flight--; + } + else + { + std::this_thread::sleep_for(std::chrono::microseconds(1)); // 休眠一个非常短的时间 + } + } + + if (in_flight > 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Read timeout happens! Skip %d requests...\n", + (reqcount - currSubIoEndId + in_flight)); + } + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) + { + break; + } + } + return true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::ReadBlocks batch failed\n"); + return false; + } +} + +bool SPDKIO::BlockController::ReadBlocks(std::vector &p_data, + std::vector> &p_values, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + if (m_useMemImpl) + { + for (size_t i = 0; i < p_data.size(); i++) + { + std::uint8_t *p_value = p_values[i].GetBuffer(); + auto &p_data_i = p_data[i]; + AddressType currOffset = 0; + AddressType dataIdx = 1; + while (currOffset < p_data_i[0]) + { + AddressType readSize = (p_data_i[0] - currOffset) < PageSize ? 
(p_data_i[0] - currOffset) : PageSize; + memcpy((char *)p_value + currOffset, m_memBuffer.get() + p_data_i[dataIdx] * PageSize, readSize); + currOffset += PageSize; + dataIdx++; + } + } + return true; + } + else if (m_useSsdImpl) + { + // Temporarily disable timeout + + // Convert request format to SubIoRequests + auto t1 = std::chrono::high_resolution_clock::now(); + std::uint32_t reqcount = 0; + std::uint32_t emptycount = 0; + for (size_t i = 0; i < p_data.size(); i++) + { + AddressType *p_data_i = p_data[i]; + int numPages = (p_values[i].GetPageSize() >> PageSizeEx); + + if (p_data_i == nullptr) + { + p_values[i].SetAvailableSize(0); + for (std::uint32_t r = 0; r < numPages; r++) + { + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = 0; + reqcount++; + emptycount++; + } + continue; + } + + std::size_t postingSize = (std::size_t)p_data_i[0]; + p_values[i].SetAvailableSize(postingSize); + AddressType currOffset = 0; + AddressType dataIdx = 1; + while (currOffset < postingSize) + { + if (dataIdx > numPages) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "SPDKIO::BlockController::ReadBlocks: block (%d) >= buffer page size (%d)\n", + dataIdx - 1, numPages); + return false; + } + + if (reqcount >= reqs->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "SPDKIO::BlockController::ReadBlocks: req (%u) >= req array size (%u)\n", reqcount, + reqs->size()); + return false; + } + + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = (postingSize - currOffset) < PageSize ? 
(postingSize - currOffset) : PageSize; + curr.m_offset = p_data_i[dataIdx] * PageSize; + curr.m_ctrl = this; + curr.myiocb.aio_lio_opcode = IOCB_CMD_PREAD; + currOffset += PageSize; + dataIdx++; + reqcount++; + } + + while (dataIdx - 1 < numPages) + { + if (reqcount >= reqs->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "SPDKIO::BlockController::ReadBlocks: req (%u) >= req array size (%u)\n", reqcount, + reqs->size()); + return false; + } + + Helper::AsyncReadRequest &curr = reqs->at(reqcount); + curr.m_readSize = 0; + dataIdx++; + reqcount++; + emptycount++; + } + } + + // Clear timeout I/Os + Helper::AsyncReadRequest *currSubIo; + Helper::RequestQueue *completeQueue = (Helper::RequestQueue *)(reqs->at(0).m_extension); + while (completeQueue->try_pop(currSubIo)) + ; + + int realcount = reqcount - emptycount; + int reqIdx = 0; + for (int currSubIoStartId = 0; currSubIoStartId < realcount; currSubIoStartId += m_batchSize) + { + int currSubIoEndId = + (currSubIoStartId + m_batchSize) > realcount ? realcount : currSubIoStartId + m_batchSize; + int currSubIoIdx = currSubIoStartId; + + int in_flight = 0; + while (currSubIoIdx < currSubIoEndId || in_flight) + { + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) + { + break; + } + + // Try submit + if (currSubIoIdx < currSubIoEndId) + { + while (reqIdx < reqcount && (reqs->at(reqIdx)).m_readSize == 0) + reqIdx++; + Helper::AsyncReadRequest &curr = reqs->at(reqIdx); + m_submittedSubIoRequests.push(&curr); + in_flight++; + currSubIoIdx++; + reqIdx++; + } + + // Try complete + if (in_flight && completeQueue->try_pop(currSubIo)) + { + in_flight--; + } + else + { + std::this_thread::sleep_for(std::chrono::microseconds(1)); // 休眠一个非常短的时间 + } + } + + if (in_flight > 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Read timeout happens! 
Skip %d requests...\n", + (realcount - currSubIoEndId + in_flight)); + } + auto t2 = std::chrono::high_resolution_clock::now(); + if (std::chrono::duration_cast(t2 - t1) > timeout) + { + break; + } + } + return true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::ReadBlocks batch failed\n"); + return false; + } +} + +// write p_value into p_size blocks start from p_data +bool SPDKIO::BlockController::WriteBlocks(AddressType *p_data, int p_size, const std::string &p_value, + const std::chrono::microseconds &timeout, + std::vector *reqs) +{ + if (m_useMemImpl) + { + for (int i = 0; i < p_size; i++) + { + memcpy(m_memBuffer.get() + p_data[i] * PageSize, p_value.data() + i * PageSize, PageSize); + } + return true; + } + else if (m_useSsdImpl) + { + Helper::AsyncReadRequest *currSubIo; + Helper::RequestQueue *completeQueue = (Helper::RequestQueue *)(reqs->at(0).m_extension); + + int totalSize = p_value.size(); + AddressType currBlockIdx = 0; + int inflight = 0; + // Submit all I/Os + while (currBlockIdx < p_size || inflight) + { + // Try submit + if (currBlockIdx < p_size) + { + Helper::AsyncReadRequest &curr = reqs->at(currBlockIdx); + int currOffset = currBlockIdx * PageSize; + curr.m_readSize = (totalSize - currOffset) < PageSize ? 
(totalSize - currOffset) : PageSize; + curr.m_offset = p_data[currBlockIdx] * PageSize; + memcpy(curr.m_buffer, const_cast(p_value.data()) + currOffset, curr.m_readSize); + curr.m_ctrl = this; + curr.myiocb.aio_lio_opcode = IOCB_CMD_PWRITE; + m_submittedSubIoRequests.push(&curr); + currBlockIdx++; + inflight++; + } + + // Try complete + if (inflight && completeQueue->try_pop(currSubIo)) + { + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO::BlockController::WriteBlocks: completed %d inflight + // requests\n", inflight); + inflight--; + } + else + { + std::this_thread::sleep_for(std::chrono::microseconds(1)); // 休眠一个非常短的时间 + } + } + return true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::ReadBlocks single failed\n"); + return false; + } +} + +bool SPDKIO::BlockController::IOStatistics() +{ + int currIOCount = m_ioCompleteCount; + int diffIOCount = currIOCount - m_preIOCompleteCount; + m_preIOCompleteCount = currIOCount; + + auto currTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(currTime - m_preTime); + m_preTime = currTime; + + double currIOPS = (double)diffIOCount * 1000 / duration.count(); + double currBandWidth = (double)diffIOCount * PageSize / 1024 * 1000 / 1024 * 1000 / duration.count(); + + std::cout << "IOPS: " << currIOPS << "k Bandwidth: " << currBandWidth << "MB/s" << std::endl; + + return true; +} + +bool SPDKIO::BlockController::ShutDown() +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SPDKIO:BlockController:ShutDown!\n"); + if (m_useMemImpl) + { + while (!m_blockAddresses.empty()) + { + AddressType currBlockAddress; + m_blockAddresses.try_pop(currBlockAddress); + } + return true; + } + else if (m_useSsdImpl) + { + m_ssdSpdkThreadExiting = true; + spdk_app_start_shutdown(); + pthread_join(m_ssdSpdkTid, NULL); + Checkpoint(m_filePath); + while (!m_blockAddresses.empty()) + { + AddressType currBlockAddress; + m_blockAddresses.try_pop(currBlockAddress); + } + return 
true; + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPDKIO::BlockController::ShutDown failed\n"); + return false; + } +} + +} // namespace SPTAG::SPANN +#endif diff --git a/AnnService/src/Core/SPANN/SPANNIndex.cpp b/AnnService/src/Core/SPANN/SPANNIndex.cpp index 900493196..3c52c08f2 100644 --- a/AnnService/src/Core/SPANN/SPANNIndex.cpp +++ b/AnnService/src/Core/SPANN/SPANNIndex.cpp @@ -3,835 +3,1666 @@ #include "inc/Core/SPANN/Index.h" #include "inc/Helper/VectorSetReaders/MemoryReader.h" -#include "inc/Core/SPANN/ExtraFullGraphSearcher.h" +#include "inc/Core/SPANN/ExtraDynamicSearcher.h" +#include "inc/Core/SPANN/ExtraStaticSearcher.h" #include +#include +#include -#pragma warning(disable:4242) // '=' : conversion from 'int' to 'short', possible loss of data -#pragma warning(disable:4244) // '=' : conversion from 'int' to 'short', possible loss of data -#pragma warning(disable:4127) // conditional expression is constant +#include "inc/Core/ResultIterator.h" +#include "inc/Core/SPANN/SPANNResultIterator.h" + +#pragma warning(disable : 4242) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable : 4244) // '=' : conversion from 'int' to 'short', possible loss of data +#pragma warning(disable : 4127) // conditional expression is constant namespace SPTAG { - namespace SPANN +template thread_local std::unique_ptr COMMON::ThreadLocalWorkSpaceFactory::m_workspace; +namespace SPANN +{ +EdgeCompare Selection::g_edgeComparer; + +std::function(void)> f_createAsyncIO = []() -> std::shared_ptr { + return std::shared_ptr(new Helper::AsyncFileIO()); +}; + +template bool Index::CheckHeadIndexType() +{ + SPTAG::VectorValueType v1 = m_index->GetVectorValueType(), v2 = GetEnumValueType(); + if (v1 != v2) { - std::atomic_int ExtraWorkSpace::g_spaceCount(0); - EdgeCompare Selection::g_edgeComparer; + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, "Head index and vectors don't have the same value types, which are %s %s\n", + 
SPTAG::Helper::Convert::ConvertToString(v1).c_str(), SPTAG::Helper::Convert::ConvertToString(v2).c_str()); + if (!m_pQuantizer) + return false; + } + return true; +} - std::function(void)> f_createAsyncIO = []() -> std::shared_ptr { return std::shared_ptr(new Helper::AsyncFileIO()); }; +template void Index::SetQuantizer(std::shared_ptr quantizer) +{ + m_pQuantizer = quantizer; + if (m_pQuantizer) + { + m_fComputeDistance = m_pQuantizer->DistanceCalcSelector(m_options.m_distCalcMethod); + m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) + ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() + : 1; + } + else + { + m_fComputeDistance = COMMON::DistanceCalcSelector(m_options.m_distCalcMethod); + m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) + ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() + : 1; + } + if (m_index) + { + m_index->SetQuantizer(quantizer); + } +} - template - bool Index::CheckHeadIndexType() { - SPTAG::VectorValueType v1 = m_index->GetVectorValueType(), v2 = GetEnumValueType(); - if (v1 != v2) { - LOG(Helper::LogLevel::LL_Error, "Head index and vectors don't have the same value types, which are %s %s\n", - SPTAG::Helper::Convert::ConvertToString(v1).c_str(), - SPTAG::Helper::Convert::ConvertToString(v2).c_str() - ); - if (!m_pQuantizer) return false; - } - return true; - } +template ErrorCode Index::LoadConfig(Helper::IniReader &p_reader) +{ + IndexAlgoType algoType = p_reader.GetParameter("Base", "IndexAlgoType", IndexAlgoType::Undefined); + VectorValueType valueType = p_reader.GetParameter("Base", "ValueType", VectorValueType::Undefined); + if ((m_index = CreateInstance(algoType, valueType)) == nullptr) + return ErrorCode::FailedParseValue; - template - void Index::SetQuantizer(std::shared_ptr quantizer) + std::string sections[] = {"Base", "SelectHead", "BuildHead", "BuildSSDIndex"}; + for (int i = 0; i < 4; i++) + { + auto parameters = p_reader.GetParameters(sections[i].c_str()); + for (auto iter = 
parameters.begin(); iter != parameters.end(); iter++) { - m_pQuantizer = quantizer; - if (m_pQuantizer) - { - m_fComputeDistance = m_pQuantizer->DistanceCalcSelector(m_options.m_distCalcMethod); - m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() : 1; - } - else - { - m_fComputeDistance = COMMON::DistanceCalcSelector(m_options.m_distCalcMethod); - m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() : 1; - } - if (m_index) - { - m_index->SetQuantizer(quantizer); - } + SetParameter(iter->first.c_str(), iter->second.c_str(), sections[i].c_str()); } + } - template - ErrorCode Index::LoadConfig(Helper::IniReader& p_reader) - { - IndexAlgoType algoType = p_reader.GetParameter("Base", "IndexAlgoType", IndexAlgoType::Undefined); - VectorValueType valueType = p_reader.GetParameter("Base", "ValueType", VectorValueType::Undefined); - if ((m_index = CreateInstance(algoType, valueType)) == nullptr) return ErrorCode::FailedParseValue; + if (m_pQuantizer) + { + m_pQuantizer->SetEnableADC(m_options.m_enableADC); + } - std::string sections[] = { "Base", "SelectHead", "BuildHead", "BuildSSDIndex" }; - for (int i = 0; i < 4; i++) { - auto parameters = p_reader.GetParameters(sections[i].c_str()); - for (auto iter = parameters.begin(); iter != parameters.end(); iter++) { - SetParameter(iter->first.c_str(), iter->second.c_str(), sections[i].c_str()); - } - } + return ErrorCode::Success; +} - if (m_pQuantizer) - { - m_pQuantizer->SetEnableADC(m_options.m_enableADC); - } +template ErrorCode Index::LoadIndexDataFromMemory(const std::vector &p_indexBlobs) +{ + /** Need to modify **/ + m_index->SetQuantizer(m_pQuantizer); + if (!m_options.m_persistentBufferPath.empty() && !direxists(m_options.m_persistentBufferPath.c_str())) + mkdir(m_options.m_persistentBufferPath.c_str()); - return ErrorCode::Success; - } + if 
(m_index->LoadIndexDataFromMemory(p_indexBlobs) != ErrorCode::Success) + return ErrorCode::Fail; - template - ErrorCode Index::LoadIndexDataFromMemory(const std::vector& p_indexBlobs) - { - m_index->SetQuantizer(m_pQuantizer); - if (m_index->LoadIndexDataFromMemory(p_indexBlobs) != ErrorCode::Success) return ErrorCode::Fail; + m_index->SetParameter("NumberOfThreads", std::to_string(m_options.m_iSSDNumberOfThreads)); + // m_index->SetParameter("MaxCheck", std::to_string(m_options.m_maxCheck)); + // m_index->SetParameter("HashTableExponent", std::to_string(m_options.m_hashExp)); + m_index->UpdateIndex(); + m_index->SetReady(true); - m_index->SetParameter("NumberOfThreads", std::to_string(m_options.m_iSSDNumberOfThreads)); - m_index->SetParameter("MaxCheck", std::to_string(m_options.m_maxCheck)); - m_index->SetParameter("HashTableExponent", std::to_string(m_options.m_hashExp)); - m_index->UpdateIndex(); - m_index->SetReady(true); + if (m_options.m_storage == Storage::STATIC) + { + if (m_pQuantizer) + m_extraSearcher.reset(new ExtraStaticSearcher()); + else + m_extraSearcher.reset(new ExtraStaticSearcher()); + } + else + { + if (m_pQuantizer) + m_extraSearcher.reset(new ExtraDynamicSearcher(m_options)); + else + m_extraSearcher.reset(new ExtraDynamicSearcher(m_options)); + } - if (m_pQuantizer) - { - m_extraSearcher.reset(new ExtraFullGraphSearcher()); - } - else - { - m_extraSearcher.reset(new ExtraFullGraphSearcher()); - } - - if (!m_extraSearcher->LoadIndex(m_options)) return ErrorCode::Fail; + if (!m_extraSearcher->LoadIndex(m_options, m_versionMap, m_vectorTranslateMap, m_index)) + return ErrorCode::Fail; - m_vectorTranslateMap.reset((std::uint64_t*)(p_indexBlobs.back().Data()), [=](std::uint64_t* ptr) {}); - - omp_set_num_threads(m_options.m_iSSDNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, 
max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression)); - return ErrorCode::Success; - } + m_vectorTranslateMap.Initialize(m_index->GetNumSamples(), 1, m_index->m_iDataBlockSize, m_index->m_iDataCapacity, + p_indexBlobs.back().Data(), false); - template - ErrorCode Index::LoadIndexData(const std::vector>& p_indexStreams) - { - m_index->SetQuantizer(m_pQuantizer); - if (m_index->LoadIndexData(p_indexStreams) != ErrorCode::Success) return ErrorCode::Fail; + return ErrorCode::Success; +} - m_index->SetParameter("NumberOfThreads", std::to_string(m_options.m_iSSDNumberOfThreads)); - m_index->SetParameter("MaxCheck", std::to_string(m_options.m_maxCheck)); - m_index->SetParameter("HashTableExponent", std::to_string(m_options.m_hashExp)); - m_index->UpdateIndex(); - m_index->SetReady(true); +template +ErrorCode Index::LoadIndexData(const std::vector> &p_indexStreams) +{ + m_index->SetQuantizer(m_pQuantizer); + if (!m_options.m_persistentBufferPath.empty() && !direxists(m_options.m_persistentBufferPath.c_str())) + mkdir(m_options.m_persistentBufferPath.c_str()); - if (m_pQuantizer) - { - m_extraSearcher.reset(new ExtraFullGraphSearcher()); - } - else + auto headfiles = m_index->GetIndexFiles(); + if (m_options.m_recovery) + { + std::shared_ptr> files(new std::vector); + auto headfiles = m_index->GetIndexFiles(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recovery: Loading another in-memory index\n"); + std::string filename = m_options.m_persistentBufferPath + FolderSep + m_options.m_headIndexFolder; + for (auto file : *headfiles) + { + files->push_back(filename + FolderSep + file); + } + std::vector> handles; + for (std::string &f : *files) + { + auto ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize(f.c_str(), std::ios::binary | std::ios::in)) { - m_extraSearcher.reset(new ExtraFullGraphSearcher()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open file %s!\n", f.c_str()); + 
ptr = nullptr; } + handles.push_back(std::move(ptr)); + } + if (m_index->LoadIndexData(handles) != ErrorCode::Success) + return ErrorCode::Fail; + } + else if (m_index->LoadIndexData(p_indexStreams) != ErrorCode::Success) + return ErrorCode::Fail; + + m_index->SetParameter("NumberOfThreads", std::to_string(m_options.m_iSSDNumberOfThreads)); + m_index->SetParameter("MaxCheck", std::to_string(m_options.m_maxCheck)); + m_index->SetParameter("HashTableExponent", std::to_string(m_options.m_hashExp)); + m_index->UpdateIndex(); + m_index->SetReady(true); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading headID map\n"); + m_vectorTranslateMap.Load(p_indexStreams[m_index->GetIndexFiles()->size()], m_index->m_iDataBlockSize, + m_index->m_iDataCapacity); + + std::string kvpath = m_options.m_indexDirectory + FolderSep + m_options.m_KVFile; + std::string ssdmappingpath = m_options.m_indexDirectory + FolderSep + m_options.m_ssdMappingFile; + if (m_options.m_recovery) + { + kvpath = m_options.m_persistentBufferPath + FolderSep + m_options.m_KVFile; + ssdmappingpath = m_options.m_persistentBufferPath + FolderSep + m_options.m_ssdMappingFile; + } - if (!m_extraSearcher->LoadIndex(m_options)) return ErrorCode::Fail; + if (m_options.m_storage == Storage::STATIC) + { + if (m_pQuantizer) + m_extraSearcher.reset(new ExtraStaticSearcher()); + else + m_extraSearcher.reset(new ExtraStaticSearcher()); + } + else + { + if (m_pQuantizer) + m_extraSearcher.reset(new ExtraDynamicSearcher(m_options)); + else + m_extraSearcher.reset(new ExtraDynamicSearcher(m_options)); + } - m_vectorTranslateMap.reset(new std::uint64_t[m_index->GetNumSamples()], std::default_delete()); - IOBINARY(p_indexStreams[m_index->GetIndexFiles()->size()], ReadBinary, sizeof(std::uint64_t) * m_index->GetNumSamples(), reinterpret_cast(m_vectorTranslateMap.get())); + if (m_options.m_storage != Storage::STATIC && !m_extraSearcher->Available()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Extrasearcher is not available 
and failed to initialize.\n"); + return ErrorCode::Fail; + } - omp_set_num_threads(m_options.m_iSSDNumberOfThreads); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression)); - return ErrorCode::Success; - } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading storage\n"); + if (!m_extraSearcher->LoadIndex(m_options, m_versionMap, m_vectorTranslateMap, m_index)) + return ErrorCode::Fail; - template - ErrorCode Index::SaveConfig(std::shared_ptr p_configOut) - { - IOSTRING(p_configOut, WriteString, "[Base]\n"); -#define DefineBasicParameter(VarName, VarType, DefaultValue, RepresentStr) \ - IOSTRING(p_configOut, WriteString, (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + std::string("\n")).c_str()); \ + if ((m_options.m_storage != Storage::STATIC) && m_options.m_preReassign) + { + if (m_extraSearcher->RefineIndex(m_index) != ErrorCode::Success) + return ErrorCode::Fail; + } + + return ErrorCode::Success; +} + +template ErrorCode Index::SaveConfig(std::shared_ptr p_configOut) +{ + IOSTRING(p_configOut, WriteString, "[Base]\n"); +#define DefineBasicParameter(VarName, VarType, DefaultValue, RepresentStr) \ + IOSTRING(p_configOut, WriteString, \ + (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + \ + std::string("\n")) \ + .c_str()); #include "inc/Core/SPANN/ParameterDefinitionList.h" #undef DefineBasicParameter - IOSTRING(p_configOut, WriteString, "[SelectHead]\n"); -#define DefineSelectHeadParameter(VarName, VarType, DefaultValue, RepresentStr) \ - IOSTRING(p_configOut, WriteString, (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + std::string("\n")).c_str()); \ + 
IOSTRING(p_configOut, WriteString, "[SelectHead]\n"); +#define DefineSelectHeadParameter(VarName, VarType, DefaultValue, RepresentStr) \ + IOSTRING(p_configOut, WriteString, \ + (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + \ + std::string("\n")) \ + .c_str()); #include "inc/Core/SPANN/ParameterDefinitionList.h" #undef DefineSelectHeadParameter - IOSTRING(p_configOut, WriteString, "[BuildHead]\n"); -#define DefineBuildHeadParameter(VarName, VarType, DefaultValue, RepresentStr) \ - IOSTRING(p_configOut, WriteString, (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + std::string("\n")).c_str()); \ + IOSTRING(p_configOut, WriteString, "[BuildHead]\n"); +#define DefineBuildHeadParameter(VarName, VarType, DefaultValue, RepresentStr) \ + IOSTRING(p_configOut, WriteString, \ + (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + \ + std::string("\n")) \ + .c_str()); #include "inc/Core/SPANN/ParameterDefinitionList.h" #undef DefineBuildHeadParameter - m_index->SaveConfig(p_configOut); + m_index->SaveConfig(p_configOut); - Helper::Convert::ConvertStringTo(m_index->GetParameter("HashTableExponent").c_str(), m_options.m_hashExp); - IOSTRING(p_configOut, WriteString, "[BuildSSDIndex]\n"); -#define DefineSSDParameter(VarName, VarType, DefaultValue, RepresentStr) \ - IOSTRING(p_configOut, WriteString, (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + std::string("\n")).c_str()); \ + Helper::Convert::ConvertStringTo(m_index->GetParameter("HashTableExponent").c_str(), m_options.m_hashExp); + IOSTRING(p_configOut, WriteString, "[BuildSSDIndex]\n"); +#define DefineSSDParameter(VarName, VarType, DefaultValue, RepresentStr) \ + IOSTRING(p_configOut, WriteString, \ + (RepresentStr + std::string("=") + SPTAG::Helper::Convert::ConvertToString(m_options.VarName) + \ + std::string("\n")) \ + .c_str()); 
#include "inc/Core/SPANN/ParameterDefinitionList.h" #undef DefineSSDParameter - IOSTRING(p_configOut, WriteString, "\n"); - return ErrorCode::Success; - } - - template - ErrorCode Index::SaveIndexData(const std::vector>& p_indexStreams) - { - if (m_index == nullptr || m_vectorTranslateMap == nullptr) return ErrorCode::EmptyIndex; - - ErrorCode ret; - if ((ret = m_index->SaveIndexData(p_indexStreams)) != ErrorCode::Success) return ret; - - IOBINARY(p_indexStreams[m_index->GetIndexFiles()->size()], WriteBinary, sizeof(std::uint64_t) * m_index->GetNumSamples(), (char*)(m_vectorTranslateMap.get())); - return ErrorCode::Success; - } - -#pragma region K-NN search - - template - ErrorCode Index::SearchIndex(QueryResult &p_query, bool p_searchDeleted) const - { - if (!m_bReady) return ErrorCode::EmptyIndex; - - COMMON::QueryResultSet* p_queryResults; - if (p_query.GetResultNum() >= m_options.m_searchInternalResultNum) - p_queryResults = (COMMON::QueryResultSet*) & p_query; - else - p_queryResults = new COMMON::QueryResultSet((const T*)p_query.GetTarget(), m_options.m_searchInternalResultNum); - - m_index->SearchIndex(*p_queryResults); - - if (m_pQuantizer.get() != m_index->m_pQuantizer.get()) { p_queryResults->SetTarget(p_queryResults->GetTarget(), m_index->m_pQuantizer); } - std::shared_ptr workSpace = nullptr; - if (m_extraSearcher != nullptr) { - workSpace = m_workSpacePool->Rent(); - workSpace->m_deduper.clear(); - workSpace->m_postingIDs.clear(); - - float limitDist = p_queryResults->GetResult(0)->Dist * m_options.m_maxDistRatio; - for (int i = 0; i < p_queryResults->GetResultNum(); ++i) - { - auto res = p_queryResults->GetResult(i); - if (res->VID == -1) break; - - auto postingID = res->VID; - res->VID = static_cast((m_vectorTranslateMap.get())[res->VID]); - if (res->VID == MaxSize) res->Dist = MaxDist; - - // Don't do disk reads for irrelevant pages - if (workSpace->m_postingIDs.size() >= m_options.m_searchInternalResultNum || - (limitDist > 0.1 && res->Dist > 
limitDist) || - !m_extraSearcher->CheckValidPosting(postingID)) - continue; - workSpace->m_postingIDs.emplace_back(postingID); - } - - p_queryResults->Reverse(); - m_extraSearcher->SearchIndex(workSpace.get(), *p_queryResults, m_index, nullptr); - p_queryResults->SortResult(); - m_workSpacePool->Return(workSpace); - } - - if (p_query.GetResultNum() < m_options.m_searchInternalResultNum) { - std::copy(p_queryResults->GetResults(), p_queryResults->GetResults() + p_query.GetResultNum(), p_query.GetResults()); - delete p_queryResults; - } + IOSTRING(p_configOut, WriteString, "\n"); + return ErrorCode::Success; +} - if (p_query.WithMeta() && nullptr != m_pMetadata) - { - for (int i = 0; i < p_query.GetResultNum(); ++i) - { - SizeType result = p_query.GetResult(i)->VID; - p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); - } - } - return ErrorCode::Success; - } +template +ErrorCode Index::SaveIndexData(const std::vector> &p_indexStreams) +{ + if (m_index == nullptr || m_versionMap.Count() == 0) + return ErrorCode::EmptyIndex; - template - ErrorCode Index::DebugSearchDiskIndex(QueryResult& p_query, int p_subInternalResultNum, int p_internalResultNum, - SearchStats* p_stats, std::set* truth, std::map>* found) - { - if (nullptr == m_extraSearcher) return ErrorCode::EmptyIndex; + while (!AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } - COMMON::QueryResultSet newResults(*((COMMON::QueryResultSet*)&p_query)); - for (int i = 0; i < newResults.GetResultNum(); ++i) - { - auto res = newResults.GetResult(i); - if (res->VID == -1) break; + ErrorCode ret; + if ((ret = m_index->SaveIndexData(p_indexStreams)) != ErrorCode::Success) + return ret; - auto global_VID = static_cast((m_vectorTranslateMap.get())[res->VID]); - if (truth && truth->count(global_VID)) (*found)[res->VID].insert(global_VID); - res->VID = global_VID; - if (res->VID == MaxSize) res->Dist = MaxDist; - } - newResults.Reverse(); + if 
((ret = m_vectorTranslateMap.Save(p_indexStreams[m_index->GetIndexFiles()->size()])) != ErrorCode::Success) + return ret; - auto auto_ws = m_workSpacePool->Rent(); - auto_ws->m_deduper.clear(); + if ((ret = m_extraSearcher->Checkpoint(m_options.m_indexDirectory)) != ErrorCode::Success) + return ret; + return ErrorCode::Success; +} - int partitions = (p_internalResultNum + p_subInternalResultNum - 1) / p_subInternalResultNum; - float limitDist = p_query.GetResult(0)->Dist * m_options.m_maxDistRatio; - for (SizeType p = 0; p < partitions; p++) { - int subInternalResultNum = min(p_subInternalResultNum, p_internalResultNum - p_subInternalResultNum * p); +#pragma region K - NN search - auto_ws->m_postingIDs.clear(); +template ErrorCode Index::SearchIndex(QueryResult &p_query, bool p_searchDeleted) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; - for (int i = p * p_subInternalResultNum; i < p * p_subInternalResultNum + subInternalResultNum; i++) - { - auto res = p_query.GetResult(i); - if (res->VID == -1 || (limitDist > 0.1 && res->Dist > limitDist)) break; - auto_ws->m_postingIDs.emplace_back(res->VID); - } + COMMON::QueryResultSet *p_queryResults; + if (p_query.GetResultNum() >= m_options.m_searchInternalResultNum) + p_queryResults = (COMMON::QueryResultSet *)&p_query; + else + p_queryResults = + new COMMON::QueryResultSet((const T *)p_query.GetTarget(), m_options.m_searchInternalResultNum); - m_extraSearcher->SearchIndex(auto_ws.get(), newResults, m_index, p_stats, truth, found); - } - - m_workSpacePool->Return(auto_ws); + ErrorCode ret; + if ((ret = m_index->SearchIndex(*p_queryResults)) != ErrorCode::Success) + return ret; - newResults.SortResult(); - std::copy(newResults.GetResults(), newResults.GetResults() + newResults.GetResultNum(), p_query.GetResults()); - return ErrorCode::Success; + if (m_extraSearcher != nullptr) + { + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new ExtraWorkSpace()); + 
m_extraSearcher->InitWorkSpace(workSpace.get(), false); } -#pragma endregion - - template - void Index::SelectHeadAdjustOptions(int p_vectorCount) { - LOG(Helper::LogLevel::LL_Info, "Begin Adjust Parameters...\n"); + else + { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } + workSpace->m_deduper.clear(); + workSpace->m_postingIDs.clear(); - if (m_options.m_headVectorCount != 0) m_options.m_ratio = m_options.m_headVectorCount * 1.0 / p_vectorCount; - int headCnt = static_cast(std::round(m_options.m_ratio * p_vectorCount)); - if (headCnt == 0) + float limitDist = p_queryResults->GetResult(0)->Dist * m_options.m_maxDistRatio; + int i = 0; + for (; i < m_options.m_searchInternalResultNum; ++i) + { + auto res = p_queryResults->GetResult(i); + if (res->VID == -1 || (limitDist > 0.1 && res->Dist > limitDist)) + break; + if (m_extraSearcher->CheckValidPosting(res->VID)) { - for (double minCnt = 1; headCnt == 0; minCnt += 0.2) - { - m_options.m_ratio = minCnt / p_vectorCount; - headCnt = static_cast(std::round(m_options.m_ratio * p_vectorCount)); - } - - LOG(Helper::LogLevel::LL_Info, "Setting requires to select none vectors as head, adjusted it to %d vectors\n", headCnt); + workSpace->m_postingIDs.emplace_back(res->VID); } - - if (m_options.m_iBKTKmeansK > headCnt) + if (m_vectorTranslateMap.R() != 0) + res->VID = static_cast(*(m_vectorTranslateMap[res->VID])); + else { - m_options.m_iBKTKmeansK = headCnt; - LOG(Helper::LogLevel::LL_Info, "Setting of cluster number is less than head count, adjust it to %d\n", headCnt); + res->VID = -1; + res->Dist = MaxDist; } - - if (m_options.m_selectThreshold == 0) + if (res->VID == MaxSize) { - m_options.m_selectThreshold = min(p_vectorCount - 1, static_cast(1 / m_options.m_ratio)); - LOG(Helper::LogLevel::LL_Info, "Set SelectThreshold to %d\n", m_options.m_selectThreshold); + res->VID = -1; + res->Dist = MaxDist; } + } + + for (; i < p_queryResults->GetResultNum(); ++i) + { + auto res = p_queryResults->GetResult(i); + if 
(res->VID == -1) + break; - if (m_options.m_splitThreshold == 0) + if (m_vectorTranslateMap.R() != 0) + res->VID = static_cast(*(m_vectorTranslateMap[res->VID])); + else { - m_options.m_splitThreshold = min(p_vectorCount - 1, static_cast(m_options.m_selectThreshold * 2)); - LOG(Helper::LogLevel::LL_Info, "Set SplitThreshold to %d\n", m_options.m_splitThreshold); + res->VID = -1; + res->Dist = MaxDist; } - - if (m_options.m_splitFactor == 0) + if (res->VID == MaxSize) { - m_options.m_splitFactor = min(p_vectorCount - 1, static_cast(std::round(1 / m_options.m_ratio) + 0.5)); - LOG(Helper::LogLevel::LL_Info, "Set SplitFactor to %d\n", m_options.m_splitFactor); + res->VID = -1; + res->Dist = MaxDist; } } + p_queryResults->Reverse(); + if ((ret = m_extraSearcher->SearchIndex(workSpace.get(), *p_queryResults, m_index, nullptr)) != + ErrorCode::Success) + return ret; + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + p_queryResults->SortResult(); + } + + if (p_query.GetResultNum() < m_options.m_searchInternalResultNum) + { + std::copy(p_queryResults->GetResults(), p_queryResults->GetResults() + p_query.GetResultNum(), + p_query.GetResults()); + p_query.SetScanned(p_queryResults->GetScanned()); + delete p_queryResults; + } - template - int Index::SelectHeadDynamicallyInternal(const std::shared_ptr p_tree, int p_nodeID, - const Options& p_opts, std::vector& p_selected) + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < p_query.GetResultNum(); ++i) { - typedef std::pair CSPair; - std::vector children; - int childrenSize = 1; - const auto& node = (*p_tree)[p_nodeID]; - if (node.childStart >= 0) - { - children.reserve(node.childEnd - node.childStart); - for (int i = node.childStart; i < node.childEnd; ++i) - { - int cs = SelectHeadDynamicallyInternal(p_tree, i, p_opts, p_selected); - if (cs > 0) - { - children.emplace_back(i, cs); - childrenSize += cs; - } - } - } + SizeType result = p_query.GetResult(i)->VID; + // if (result > 
m_pMetadata->Count()) SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vid return is beyond the + // metadata set:(%d > (%d, %d))\n", result, GetNumSamples(), m_pMetadata->Count()); + p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); + } + } + return ErrorCode::Success; +} - if (childrenSize >= p_opts.m_selectThreshold) - { - if (node.centerid < (*p_tree)[0].centerid) - { - p_selected.push_back(node.centerid); - } +template +ErrorCode Index::SearchIndexIterative(QueryResult &p_headQuery, QueryResult &p_query, + COMMON::WorkSpace *p_indexWorkspace, ExtraWorkSpace *p_extraWorkspace, + int p_batch, int &resultCount, bool first) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; - if (childrenSize > p_opts.m_splitThreshold) - { - std::sort(children.begin(), children.end(), [](const CSPair& a, const CSPair& b) - { - return a.second > b.second; - }); + COMMON::QueryResultSet *p_headQueryResults = (COMMON::QueryResultSet *)&p_headQuery; + COMMON::QueryResultSet *p_queryResults = (COMMON::QueryResultSet *)&p_query; - size_t selectCnt = static_cast(std::ceil(childrenSize * 1.0 / p_opts.m_splitFactor) + 0.5); - //if (selectCnt > 1) selectCnt -= 1; - for (size_t i = 0; i < selectCnt && i < children.size(); ++i) - { - p_selected.push_back((*p_tree)[children[i].first].centerid); - } - } + if (first) + { + p_headQueryResults->SetResultNum(m_options.m_searchInternalResultNum); + p_headQueryResults->Reset(); + m_index->SearchIndexIterativeFromNeareast(*p_headQueryResults, p_indexWorkspace, true); + p_extraWorkspace->m_loadPosting = true; + } - return 0; - } + bool continueSearch = true; + resultCount = 0; + while (continueSearch && resultCount < p_batch) + { + bool oldRelaxedMono = p_extraWorkspace->m_relaxedMono; + ErrorCode ret = SearchDiskIndexIterative(p_headQuery, p_query, p_extraWorkspace); + bool found = (ret == ErrorCode::Success); + if (!found && ret != ErrorCode::VectorNotFound) return ret; + p_extraWorkspace->m_loadPosting 
= false; + if (!found) + { + p_headQueryResults->SetResultNum(m_options.m_headBatch); + p_headQueryResults->Reset(); + continueSearch = m_index->SearchIndexIterativeFromNeareast(*p_headQueryResults, p_indexWorkspace, false); + p_extraWorkspace->m_loadPosting = true; - return childrenSize; + if (!oldRelaxedMono && p_extraWorkspace->m_relaxedMono) + continueSearch = false; } + else + resultCount++; + } + p_queryResults->SortResult(); - template - void Index::SelectHeadDynamically(const std::shared_ptr p_tree, int p_vectorCount, std::vector& p_selected) { - p_selected.clear(); - p_selected.reserve(p_vectorCount); - - if (static_cast(std::round(m_options.m_ratio * p_vectorCount)) >= p_vectorCount) - { - for (int i = 0; i < p_vectorCount; ++i) - { - p_selected.push_back(i); - } + if (p_query.WithMeta() && nullptr != m_pMetadata) + { + for (int i = 0; i < resultCount; ++i) + { + SizeType result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadataCopy(result)); + } + } + return ErrorCode::Success; +} - return; - } - Options opts = m_options; +template +std::shared_ptr Index::GetIterator(const void *p_target, bool p_searchDeleted, std::function p_filterFunc, int p_maxCheck) const +{ + if (!m_bReady) + return nullptr; + auto extraWorkspace = m_workSpaceFactory->GetWorkSpace(); + if (!extraWorkspace) + { + extraWorkspace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(extraWorkspace.get(), false); + } + else + { + m_extraSearcher->InitWorkSpace(extraWorkspace.get(), true); + } + extraWorkspace->m_filterFunc = p_filterFunc; + extraWorkspace->m_relaxedMono = false; + extraWorkspace->m_loadedPostingNum = 0; + extraWorkspace->m_deduper.clear(); + extraWorkspace->m_postingIDs.clear(); + std::shared_ptr resultIterator = std::make_shared>( + this, m_index.get(), p_target, std::move(extraWorkspace), + max(m_options.m_headBatch, m_options.m_searchInternalResultNum), p_maxCheck); + return resultIterator; +} - 
int selectThreshold = m_options.m_selectThreshold; - int splitThreshold = m_options.m_splitThreshold; +template +ErrorCode Index::SearchIndexIterativeNext(QueryResult &p_query, COMMON::WorkSpace *workSpace, int p_batch, + int &resultCount, bool p_isFirst, bool p_searchDeleted) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ITERATIVE NOT SUPPORT FOR SPANN"); + return ErrorCode::Undefined; +} - double minDiff = 100; - for (int select = 2; select <= m_options.m_selectThreshold; ++select) - { - opts.m_selectThreshold = select; - opts.m_splitThreshold = m_options.m_splitThreshold; +template ErrorCode Index::SearchIndexIterativeEnd(std::unique_ptr space) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SearchIndexIterativeEnd NOT SUPPORT FOR SPANN"); + return ErrorCode::Fail; +} - int l = m_options.m_splitFactor; - int r = m_options.m_splitThreshold; +template +ErrorCode Index::SearchIndexIterativeEnd(std::unique_ptr extraWorkspace) const +{ + if (!m_bReady) + return ErrorCode::EmptyIndex; - while (l < r - 1) - { - opts.m_splitThreshold = (l + r) / 2; - p_selected.clear(); + if (extraWorkspace != nullptr) + m_workSpaceFactory->ReturnWorkSpace(std::move(extraWorkspace)); - SelectHeadDynamicallyInternal(p_tree, 0, opts, p_selected); - std::sort(p_selected.begin(), p_selected.end()); - p_selected.erase(std::unique(p_selected.begin(), p_selected.end()), p_selected.end()); + return ErrorCode::Success; +} - double diff = static_cast(p_selected.size()) / p_vectorCount - m_options.m_ratio; +template +bool Index::SearchIndexIterativeFromNeareast(QueryResult &p_query, COMMON::WorkSpace *p_space, bool p_isFirst, + bool p_searchDeleted) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SearchIndexIterativeFromNeareast NOT SUPPORT FOR SPANN"); + return false; +} - LOG(Helper::LogLevel::LL_Info, - "Select Threshold: %d, Split Threshold: %d, diff: %.2lf%%.\n", - opts.m_selectThreshold, - opts.m_splitThreshold, - diff * 100.0); +template +ErrorCode 
Index::SearchIndexWithFilter(QueryResult &p_query, std::function filterFunc, + int maxCheck, bool p_searchDeleted) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Not Support Filter on SPANN Index!\n"); + return ErrorCode::Fail; +} - if (minDiff > fabs(diff)) - { - minDiff = fabs(diff); +template ErrorCode Index::SearchDiskIndex(QueryResult &p_query, SearchStats *p_stats) const +{ + if (nullptr == m_extraSearcher) + return ErrorCode::EmptyIndex; - selectThreshold = opts.m_selectThreshold; - splitThreshold = opts.m_splitThreshold; - } + COMMON::QueryResultSet *p_queryResults = (COMMON::QueryResultSet *)&p_query; - if (diff > 0) - { - l = (l + r) / 2; - } - else - { - r = (l + r) / 2; - } - } - } + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else + { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } - opts.m_selectThreshold = selectThreshold; - opts.m_splitThreshold = splitThreshold; + workSpace->m_deduper.clear(); + workSpace->m_postingIDs.clear(); - LOG(Helper::LogLevel::LL_Info, - "Final Select Threshold: %d, Split Threshold: %d.\n", - opts.m_selectThreshold, - opts.m_splitThreshold); + float limitDist = p_queryResults->GetResult(0)->Dist * m_options.m_maxDistRatio; + int i = 0; + for (; i < m_options.m_searchInternalResultNum; ++i) + { + auto res = p_queryResults->GetResult(i); + if (res->VID == -1 || (limitDist > 0.1 && res->Dist > limitDist)) + break; + if (m_extraSearcher->CheckValidPosting(res->VID)) + { + workSpace->m_postingIDs.emplace_back(res->VID); + } - p_selected.clear(); - SelectHeadDynamicallyInternal(p_tree, 0, opts, p_selected); - std::sort(p_selected.begin(), p_selected.end()); - p_selected.erase(std::unique(p_selected.begin(), p_selected.end()), p_selected.end()); + if (m_vectorTranslateMap.R() != 0) + res->VID = static_cast(*(m_vectorTranslateMap[res->VID])); + else + { + res->VID = -1; + 
res->Dist = MaxDist; + } + if (res->VID == MaxSize) + { + res->VID = -1; + res->Dist = MaxDist; } + } - template - template - bool Index::SelectHeadInternal(std::shared_ptr& p_reader) { - std::shared_ptr vectorset = p_reader->GetVectorSet(); - if (m_options.m_distCalcMethod == DistCalcMethod::Cosine && !p_reader->IsNormalized()) - vectorset->Normalize(m_options.m_iSelectHeadNumberOfThreads); - LOG(Helper::LogLevel::LL_Info, "Begin initial data (%d,%d)...\n", vectorset->Count(), vectorset->Dimension()); - - COMMON::Dataset data(vectorset->Count(), vectorset->Dimension(), vectorset->Count(), vectorset->Count() + 1, (InternalDataType*)vectorset->GetData()); - - auto t1 = std::chrono::high_resolution_clock::now(); - SelectHeadAdjustOptions(data.R()); - std::vector selected; - if (data.R() == 1) { - selected.push_back(0); - } - else if (Helper::StrUtils::StrEqualIgnoreCase(m_options.m_selectType.c_str(), "Random")) { - LOG(Helper::LogLevel::LL_Info, "Start generating Random head.\n"); - selected.resize(data.R()); - for (int i = 0; i < data.R(); i++) selected[i] = i; - std::shuffle(selected.begin(), selected.end(), rg); - int headCnt = static_cast(std::round(m_options.m_ratio * data.R())); - selected.resize(headCnt); - } - else if (Helper::StrUtils::StrEqualIgnoreCase(m_options.m_selectType.c_str(), "BKT")) { - LOG(Helper::LogLevel::LL_Info, "Start generating BKT.\n"); - std::shared_ptr bkt = std::make_shared(); - bkt->m_iBKTKmeansK = m_options.m_iBKTKmeansK; - bkt->m_iBKTLeafSize = m_options.m_iBKTLeafSize; - bkt->m_iSamples = m_options.m_iSamples; - bkt->m_iTreeNumber = m_options.m_iTreeNumber; - bkt->m_fBalanceFactor = m_options.m_fBalanceFactor; - LOG(Helper::LogLevel::LL_Info, "Start invoking BuildTrees.\n"); - LOG(Helper::LogLevel::LL_Info, "BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d.\n", - bkt->m_iBKTKmeansK, bkt->m_iBKTLeafSize, bkt->m_iSamples, bkt->m_fBalanceFactor, bkt->m_iTreeNumber, 
m_options.m_iSelectHeadNumberOfThreads); - - bkt->BuildTrees(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads, nullptr, nullptr, true); - auto t2 = std::chrono::high_resolution_clock::now(); - double elapsedSeconds = std::chrono::duration_cast(t2 - t1).count(); - LOG(Helper::LogLevel::LL_Info, "End invoking BuildTrees.\n"); - LOG(Helper::LogLevel::LL_Info, "Invoking BuildTrees used time: %.2lf minutes (about %.2lf hours).\n", elapsedSeconds / 60.0, elapsedSeconds / 3600.0); - - if (m_options.m_saveBKT) { - std::stringstream bktFileNameBuilder; - bktFileNameBuilder << m_options.m_vectorPath << ".bkt." << m_options.m_iBKTKmeansK << "_" - << m_options.m_iBKTLeafSize << "_" << m_options.m_iTreeNumber << "_" << m_options.m_iSamples << "_" - << static_cast(m_options.m_distCalcMethod) << ".bin"; - bkt->SaveTrees(bktFileNameBuilder.str()); - } - LOG(Helper::LogLevel::LL_Info, "Finish generating BKT.\n"); + for (; i < p_queryResults->GetResultNum(); ++i) + { + auto res = p_queryResults->GetResult(i); + if (res->VID == -1) + break; + if (m_vectorTranslateMap.R() != 0) + res->VID = static_cast(*(m_vectorTranslateMap[res->VID])); + else + { + res->VID = -1; + res->Dist = MaxDist; + } + if (res->VID == MaxSize) + { + res->VID = -1; + res->Dist = MaxDist; + } + } - LOG(Helper::LogLevel::LL_Info, "Start selecting nodes...Select Head Dynamically...\n"); - SelectHeadDynamically(bkt, data.R(), selected); + p_queryResults->Reverse(); + m_extraSearcher->SearchIndex(workSpace.get(), *p_queryResults, m_index, p_stats); + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + p_queryResults->SortResult(); + return ErrorCode::Success; +} - if (selected.empty()) { - LOG(Helper::LogLevel::LL_Error, "Can't select any vector as head with current settings\n"); - return false; - } - } +template +ErrorCode Index::SearchDiskIndexIterative(QueryResult &p_headQuery, QueryResult &p_query, + ExtraWorkSpace *extraWorkspace) const +{ + if (extraWorkspace->m_loadPosting) + 
{ + COMMON::QueryResultSet *p__headQueryResults = (COMMON::QueryResultSet *)&p_headQuery; + // std::shared_ptr workSpace = m_workSpacePool->Rent(); + // workSpace->m_deduper.clear(); + extraWorkspace->m_postingIDs.clear(); - LOG(Helper::LogLevel::LL_Info, - "Seleted Nodes: %u, about %.2lf%% of total.\n", - static_cast(selected.size()), - selected.size() * 100.0 / data.R()); + // float limitDist = p_queryResults->GetResult(0)->Dist * m_options.m_maxDistRatio; - if (!m_options.m_noOutput) + for (int i = 0; i < p__headQueryResults->GetResultNum(); ++i) + { + auto res = p__headQueryResults->GetResult(i); + // break or continue + if (res->VID == -1) + break; + if (m_extraSearcher->CheckValidPosting(res->VID)) { - std::sort(selected.begin(), selected.end()); - - std::shared_ptr output = SPTAG::f_createIO(), outputIDs = SPTAG::f_createIO(); - if (output == nullptr || outputIDs == nullptr || - !output->Initialize((m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str(), std::ios::binary | std::ios::out) || - !outputIDs->Initialize((m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Failed to create output file:%s %s\n", - (m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str(), - (m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str()); - return false; - } - - SizeType val = static_cast(selected.size()); - if (output->WriteBinary(sizeof(val), reinterpret_cast(&val)) != sizeof(val)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - return false; - } - DimensionType dt = data.C(); - if (output->WriteBinary(sizeof(dt), reinterpret_cast(&dt)) != sizeof(dt)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - return false; - } - - for (int i = 0; i < selected.size(); i++) - { - uint64_t vid = static_cast(selected[i]); - if (outputIDs->WriteBinary(sizeof(vid), 
reinterpret_cast(&vid)) != sizeof(vid)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - return false; - } + extraWorkspace->m_postingIDs.emplace_back(res->VID); + } - if (output->WriteBinary(sizeof(InternalDataType) * data.C(), (char*)(data[vid])) != sizeof(InternalDataType) * data.C()) { - LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); - return false; - } - } + if (m_vectorTranslateMap.R() != 0) + res->VID = static_cast(*(m_vectorTranslateMap[res->VID])); + else + { + res->VID = -1; + res->Dist = MaxDist; + } + if (res->VID == MaxSize) + { + res->VID = -1; + res->Dist = MaxDist; } - auto t3 = std::chrono::high_resolution_clock::now(); - double elapsedSeconds = std::chrono::duration_cast(t3 - t1).count(); - LOG(Helper::LogLevel::LL_Info, "Total used time: %.2lf minutes (about %.2lf hours).\n", elapsedSeconds / 60.0, elapsedSeconds / 3600.0); - return true; } + extraWorkspace->m_loadedPostingNum += (int)(extraWorkspace->m_postingIDs.size()); + } - template - ErrorCode Index::BuildIndexInternal(std::shared_ptr& p_reader) { - if (!m_options.m_indexDirectory.empty()) { - if (!direxists(m_options.m_indexDirectory.c_str())) - { - mkdir(m_options.m_indexDirectory.c_str()); - } - } + ErrorCode ret = m_extraSearcher->SearchIterativeNext(extraWorkspace, p_headQuery, p_query, m_index, this); + if (ret == ErrorCode::VectorNotFound && extraWorkspace->m_loadedPostingNum >= m_options.m_searchInternalResultNum) + extraWorkspace->m_relaxedMono = true; + return ret; +} - LOG(Helper::LogLevel::LL_Info, "Begin Select Head...\n"); - auto t1 = std::chrono::high_resolution_clock::now(); - if (m_options.m_selectHead) { - omp_set_num_threads(m_options.m_iSelectHeadNumberOfThreads); - bool success = false; - if (m_pQuantizer) - { - success = SelectHeadInternal(p_reader); - } - else - { - success = SelectHeadInternal(p_reader); - } - if (!success) { - LOG(Helper::LogLevel::LL_Error, "SelectHead Failed!\n"); - return ErrorCode::Fail; - } - } - auto 
t2 = std::chrono::high_resolution_clock::now(); - double selectHeadTime = std::chrono::duration_cast(t2 - t1).count(); - LOG(Helper::LogLevel::LL_Info, "select head time: %.2lfs\n", selectHeadTime); +template std::unique_ptr Index::RentWorkSpace(int batch, std::function p_filterFunc, int p_maxCheck) const +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "RentWorkSpace NOT SUPPORT FOR SPANN"); + return nullptr; +} - LOG(Helper::LogLevel::LL_Info, "Begin Build Head...\n"); - if (m_options.m_buildHead) { - auto valueType = m_pQuantizer ? SPTAG::VectorValueType::UInt8 : m_options.m_valueType; - auto dims = m_pQuantizer ? m_pQuantizer->GetNumSubvectors() : m_options.m_dim; +template +ErrorCode Index::DebugSearchDiskIndex(QueryResult &p_query, int p_subInternalResultNum, int p_internalResultNum, + SearchStats *p_stats, std::set *truth, + std::map> *found) const +{ + if (nullptr == m_extraSearcher) + return ErrorCode::EmptyIndex; - m_index = SPTAG::VectorIndex::CreateInstance(m_options.m_indexAlgoType, valueType); - m_index->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(m_options.m_distCalcMethod)); - m_index->SetQuantizer(m_pQuantizer); - for (const auto& iter : m_headParameters) - { - m_index->SetParameter(iter.first.c_str(), iter.second.c_str()); - } + COMMON::QueryResultSet newResults(*((COMMON::QueryResultSet *)&p_query)); + for (int i = 0; i < newResults.GetResultNum(); ++i) + { + auto res = newResults.GetResult(i); + if (res->VID == -1) + break; + + auto global_VID = -1; + if (m_vectorTranslateMap.R() != 0) + global_VID = static_cast(*(m_vectorTranslateMap[res->VID])); + if (truth && truth->count(global_VID)) + (*found)[res->VID].insert(global_VID); + res->VID = global_VID; + } + newResults.Reverse(); - std::shared_ptr vectorOptions(new Helper::ReaderOptions(valueType, dims, VectorFileType::DEFAULT)); - auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); - if (ErrorCode::Success != 
vectorReader->LoadFile(m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile)) - { - LOG(Helper::LogLevel::LL_Error, "Failed to read head vector file.\n"); - return ErrorCode::Fail; - } - if (m_index->BuildIndex(vectorReader->GetVectorSet(), nullptr, false, true) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to build head index.\n"); - return ErrorCode::Fail; - } - m_index->SetQuantizerFileName(m_options.m_quantizerFilePath.substr(m_options.m_quantizerFilePath.find_last_of("/\\") + 1)); - if (m_index->SaveIndex(m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to save head index.\n"); - return ErrorCode::Fail; - } - m_index.reset(); - if (LoadIndex(m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder, m_index) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Cannot load head index from %s!\n", (m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder).c_str()); - } - } - auto t3 = std::chrono::high_resolution_clock::now(); - double buildHeadTime = std::chrono::duration_cast(t3 - t2).count(); - LOG(Helper::LogLevel::LL_Info, "select head time: %.2lfs build head time: %.2lfs\n", selectHeadTime, buildHeadTime); + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else + { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } + workSpace->m_deduper.clear(); - LOG(Helper::LogLevel::LL_Info, "Begin Build SSDIndex...\n"); - if (m_options.m_enableSSD) { - omp_set_num_threads(m_options.m_iSSDNumberOfThreads); + int partitions = (p_internalResultNum + p_subInternalResultNum - 1) / p_subInternalResultNum; + float limitDist = p_query.GetResult(0)->Dist * m_options.m_maxDistRatio; + for (SizeType p = 0; p < partitions; p++) + { + int subInternalResultNum = min(p_subInternalResultNum, 
p_internalResultNum - p_subInternalResultNum * p); - if (m_index == nullptr && LoadIndex(m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder, m_index) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Cannot load head index from %s!\n", (m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder).c_str()); - return ErrorCode::Fail; - } - m_index->SetQuantizer(m_pQuantizer); - if (!CheckHeadIndexType()) return ErrorCode::Fail; + workSpace->m_postingIDs.clear(); - m_index->SetParameter("NumberOfThreads", std::to_string(m_options.m_iSSDNumberOfThreads)); - m_index->SetParameter("MaxCheck", std::to_string(m_options.m_maxCheck)); - m_index->SetParameter("HashTableExponent", std::to_string(m_options.m_hashExp)); - m_index->UpdateIndex(); + for (int i = p * p_subInternalResultNum; i < p * p_subInternalResultNum + subInternalResultNum; i++) + { + auto res = p_query.GetResult(i); + if (res->VID == -1 || (limitDist > 0.1 && res->Dist > limitDist)) + break; + if (!m_extraSearcher->CheckValidPosting(res->VID)) + continue; + workSpace->m_postingIDs.emplace_back(res->VID); + } - if (m_pQuantizer) - { - m_extraSearcher.reset(new ExtraFullGraphSearcher()); - } - else { - m_extraSearcher.reset(new ExtraFullGraphSearcher()); - } + m_extraSearcher->SearchIndex(workSpace.get(), newResults, m_index, p_stats, truth, found); + } + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + newResults.SortResult(); + std::copy(newResults.GetResults(), newResults.GetResults() + newResults.GetResultNum(), p_query.GetResults()); + return ErrorCode::Success; +} +#pragma endregion - if (m_options.m_buildSsdIndex) { - if (!m_options.m_excludehead) { - LOG(Helper::LogLevel::LL_Info, "Include all vectors into SSD index...\n"); - std::shared_ptr ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize((m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, 
"Failed to open headIDFile file:%s for overwrite\n", (m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str()); - return ErrorCode::Fail; - } - std::uint64_t vid = (std::uint64_t)MaxSize; - for (int i = 0; i < m_index->GetNumSamples(); i++) { - IOBINARY(ptr, WriteBinary, sizeof(std::uint64_t), (char*)(&vid)); - } - } +template +ErrorCode Index::GetPostingDebug(SizeType vid, std::vector &VIDs, std::shared_ptr &vecs) +{ + VIDs.clear(); + if (!m_extraSearcher) + return ErrorCode::EmptyIndex; + if (!m_extraSearcher->CheckValidPosting(vid)) + return ErrorCode::Fail; + + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else + { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } + workSpace->m_deduper.clear(); - if (!m_extraSearcher->BuildIndex(p_reader, m_index, m_options)) { - LOG(Helper::LogLevel::LL_Error, "BuildSSDIndex Failed!\n"); - return ErrorCode::Fail; - } - } - if (!m_extraSearcher->LoadIndex(m_options)) { - LOG(Helper::LogLevel::LL_Error, "Cannot Load SSDIndex!\n"); - return ErrorCode::Fail; - } + auto out = m_extraSearcher->GetPostingDebug(workSpace.get(), m_index, vid, VIDs, vecs); + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + return out; +} - m_vectorTranslateMap.reset(new std::uint64_t[m_index->GetNumSamples()], std::default_delete()); - std::shared_ptr ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize((m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed to open headIDFile file:%s\n", (m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str()); - return ErrorCode::Fail; - } - IOBINARY(ptr, ReadBinary, sizeof(std::uint64_t) * m_index->GetNumSamples(), (char*)(m_vectorTranslateMap.get())); +template void Index::SelectHeadAdjustOptions(int p_vectorCount) 
+{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Adjust Parameters...\n"); + + if (m_options.m_headVectorCount != 0) + m_options.m_ratio = m_options.m_headVectorCount * 1.0 / p_vectorCount; + int headCnt = static_cast(std::round(m_options.m_ratio * p_vectorCount)); + if (headCnt == 0) + { + for (double minCnt = 1; headCnt == 0; minCnt += 0.2) + { + m_options.m_ratio = minCnt / p_vectorCount; + headCnt = static_cast(std::round(m_options.m_ratio * p_vectorCount)); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "Setting requires to select none vectors as head, adjusted it to %d vectors\n", headCnt); + } + + if (m_options.m_iBKTKmeansK > headCnt) + { + m_options.m_iBKTKmeansK = headCnt; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Setting of cluster number is less than head count, adjust it to %d\n", + headCnt); + } + + if (m_options.m_selectThreshold == 0) + { + m_options.m_selectThreshold = min(p_vectorCount - 1, static_cast(1 / m_options.m_ratio)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set SelectThreshold to %d\n", m_options.m_selectThreshold); + } + + if (m_options.m_splitThreshold == 0) + { + m_options.m_splitThreshold = min(p_vectorCount - 1, static_cast(m_options.m_selectThreshold * 2)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set SplitThreshold to %d\n", m_options.m_splitThreshold); + } + + if (m_options.m_splitFactor == 0) + { + m_options.m_splitFactor = min(p_vectorCount - 1, static_cast(std::round(1 / m_options.m_ratio) + 0.5)); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set SplitFactor to %d\n", m_options.m_splitFactor); + } +} + +template +int Index::SelectHeadDynamicallyInternal(const std::shared_ptr p_tree, int p_nodeID, + const Options &p_opts, std::vector &p_selected) +{ + typedef std::pair CSPair; + std::vector children; + int childrenSize = 1; + const auto &node = (*p_tree)[p_nodeID]; + if (node.childStart >= 0) + { + children.reserve(node.childEnd - node.childStart); + for (int i = node.childStart; i < node.childEnd; ++i) + { + int cs = 
SelectHeadDynamicallyInternal(p_tree, i, p_opts, p_selected); + if (cs > 0) + { + children.emplace_back(i, cs); + childrenSize += cs; } - auto t4 = std::chrono::high_resolution_clock::now(); - double buildSSDTime = std::chrono::duration_cast(t4 - t3).count(); - LOG(Helper::LogLevel::LL_Info, "select head time: %.2lfs build head time: %.2lfs build ssd time: %.2lfs\n", selectHeadTime, buildHeadTime, buildSSDTime); + } + } - if (m_options.m_deleteHeadVectors) { - if (fileexists((m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str()) && - remove((m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str()) != 0) { - LOG(Helper::LogLevel::LL_Warning, "Head vector file can't be removed.\n"); - } + if (childrenSize >= p_opts.m_selectThreshold) + { + if (node.centerid < (*p_tree)[0].centerid) + { + p_selected.push_back(node.centerid); + } + + if (childrenSize > p_opts.m_splitThreshold) + { + std::sort(children.begin(), children.end(), + [](const CSPair &a, const CSPair &b) { return a.second > b.second; }); + + size_t selectCnt = static_cast(std::ceil(childrenSize * 1.0 / p_opts.m_splitFactor) + 0.5); + // if (selectCnt > 1) selectCnt -= 1; + for (size_t i = 0; i < selectCnt && i < children.size(); ++i) + { + p_selected.push_back((*p_tree)[children[i].first].centerid); } + } + + return 0; + } - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression)); - m_bReady = true; - return ErrorCode::Success; + return childrenSize; +} + +template +void Index::SelectHeadDynamically(const std::shared_ptr p_tree, int p_vectorCount, + std::vector &p_selected) +{ + p_selected.clear(); + p_selected.reserve(p_vectorCount); + + if (static_cast(std::round(m_options.m_ratio * p_vectorCount)) >= 
p_vectorCount) + { + for (int i = 0; i < p_vectorCount; ++i) + { + p_selected.push_back(i); } - template - ErrorCode Index::BuildIndex(bool p_normalized) + + return; + } + Options opts = m_options; + + int selectThreshold = m_options.m_selectThreshold; + int splitThreshold = m_options.m_splitThreshold; + + double minDiff = 100; + for (int select = 2; select <= m_options.m_selectThreshold; ++select) + { + opts.m_selectThreshold = select; + opts.m_splitThreshold = m_options.m_splitThreshold; + + int l = m_options.m_splitFactor; + int r = m_options.m_splitThreshold; + + while (l < r - 1) { - SPTAG::VectorValueType valueType = m_pQuantizer ? SPTAG::VectorValueType::UInt8 : m_options.m_valueType; - SizeType dim = m_pQuantizer ? m_pQuantizer->GetNumSubvectors() : m_options.m_dim; - std::shared_ptr vectorOptions(new Helper::ReaderOptions(valueType, dim, m_options.m_vectorType, m_options.m_vectorDelimiter, m_options.m_iSSDNumberOfThreads, p_normalized)); - auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); - if (m_options.m_vectorPath.empty()) + opts.m_splitThreshold = (l + r) / 2; + p_selected.clear(); + + SelectHeadDynamicallyInternal(p_tree, 0, opts, p_selected); + std::sort(p_selected.begin(), p_selected.end()); + p_selected.erase(std::unique(p_selected.begin(), p_selected.end()), p_selected.end()); + + double diff = static_cast(p_selected.size()) / p_vectorCount - m_options.m_ratio; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Select Threshold: %d, Split Threshold: %d, diff: %.2lf%%.\n", + opts.m_selectThreshold, opts.m_splitThreshold, diff * 100.0); + + if (minDiff > fabs(diff)) { - LOG(Helper::LogLevel::LL_Info, "Vector file is empty. 
Skipping loading.\n"); + minDiff = fabs(diff); + + selectThreshold = opts.m_selectThreshold; + splitThreshold = opts.m_splitThreshold; } - else { - if (ErrorCode::Success != vectorReader->LoadFile(m_options.m_vectorPath)) - { - LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); - return ErrorCode::Fail; - } - m_options.m_vectorSize = vectorReader->GetVectorSet()->Count(); + + if (diff > 0) + { + l = (l + r) / 2; + } + else + { + r = (l + r) / 2; } - - return BuildIndexInternal(vectorReader); } + } - template - ErrorCode Index::BuildIndex(const void* p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized, bool p_shareOwnership) + opts.m_selectThreshold = selectThreshold; + opts.m_splitThreshold = splitThreshold; + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Final Select Threshold: %d, Split Threshold: %d.\n", + opts.m_selectThreshold, opts.m_splitThreshold); + + p_selected.clear(); + SelectHeadDynamicallyInternal(p_tree, 0, opts, p_selected); + std::sort(p_selected.begin(), p_selected.end()); + p_selected.erase(std::unique(p_selected.begin(), p_selected.end()), p_selected.end()); +} + +template +template +bool Index::SelectHeadInternal(std::shared_ptr &p_reader) +{ + std::shared_ptr vectorset = p_reader->GetVectorSet(); + if (m_options.m_distCalcMethod == DistCalcMethod::Cosine && !p_reader->IsNormalized()) + vectorset->Normalize(m_options.m_iSelectHeadNumberOfThreads); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin initial data (%d,%d)...\n", vectorset->Count(), + vectorset->Dimension()); + + COMMON::Dataset data(vectorset->Count(), vectorset->Dimension(), vectorset->Count(), + vectorset->Count() + 1, (InternalDataType *)vectorset->GetData()); + + auto t1 = std::chrono::high_resolution_clock::now(); + SelectHeadAdjustOptions(data.R()); + std::vector selected; + if (data.R() == 1) + { + selected.push_back(0); + } + else if (Helper::StrUtils::StrEqualIgnoreCase(m_options.m_selectType.c_str(), "Random")) + { + std::mt19937 rg; + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start generating Random head.\n"); + selected.resize(data.R()); + for (int i = 0; i < data.R(); i++) + selected[i] = i; + std::shuffle(selected.begin(), selected.end(), rg); + int headCnt = static_cast(std::round(m_options.m_ratio * data.R())); + selected.resize(headCnt); + } + else if (Helper::StrUtils::StrEqualIgnoreCase(m_options.m_selectType.c_str(), "BKT")) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start generating BKT.\n"); + std::shared_ptr bkt = std::make_shared(); + bkt->m_iBKTKmeansK = m_options.m_iBKTKmeansK; + bkt->m_iBKTLeafSize = m_options.m_iBKTLeafSize; + bkt->m_iSamples = m_options.m_iSamples; + bkt->m_iTreeNumber = m_options.m_iTreeNumber; + bkt->m_fBalanceFactor = m_options.m_fBalanceFactor; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start invoking BuildTrees.\n"); + SPTAGLIB_LOG( + Helper::LogLevel::LL_Info, + "BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d.\n", + bkt->m_iBKTKmeansK, bkt->m_iBKTLeafSize, bkt->m_iSamples, bkt->m_fBalanceFactor, bkt->m_iTreeNumber, + m_options.m_iSelectHeadNumberOfThreads); + + bkt->BuildTrees(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads, + nullptr, nullptr, true); + auto t2 = std::chrono::high_resolution_clock::now(); + double elapsedSeconds = std::chrono::duration_cast(t2 - t1).count(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "End invoking BuildTrees.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Invoking BuildTrees used time: %.2lf minutes (about %.2lf hours).\n", + elapsedSeconds / 60.0, elapsedSeconds / 3600.0); + + if (m_options.m_saveBKT) + { + std::stringstream bktFileNameBuilder; + bktFileNameBuilder << m_options.m_vectorPath << ".bkt." 
<< m_options.m_iBKTKmeansK << "_" + << m_options.m_iBKTLeafSize << "_" << m_options.m_iTreeNumber << "_" + << m_options.m_iSamples << "_" << static_cast(m_options.m_distCalcMethod) << ".bin"; + bkt->SaveTrees(bktFileNameBuilder.str()); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Finish generating BKT.\n"); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start selecting nodes...Select Head Dynamically...\n"); + SelectHeadDynamically(bkt, data.R(), selected); + + if (selected.empty()) { - if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) return ErrorCode::EmptyData; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Can't select any vector as head with current settings\n"); + return false; + } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Seleted Nodes: %u, about %.2lf%% of total.\n", + static_cast(selected.size()), selected.size() * 100.0 / data.R()); - std::shared_ptr vectorSet; - if (p_shareOwnership) { - vectorSet.reset(new BasicVectorSet(ByteArray((std::uint8_t*)p_data, sizeof(T) * p_vectorNum * p_dimension, false), - GetEnumValueType(), p_dimension, p_vectorNum)); + if (!m_options.m_noOutput) + { + std::sort(selected.begin(), selected.end()); + + std::shared_ptr output = SPTAG::f_createIO(), outputIDs = SPTAG::f_createIO(); + if (output == nullptr || outputIDs == nullptr || + !output->Initialize((m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str(), + std::ios::binary | std::ios::out) || + !outputIDs->Initialize((m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str(), + std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to create output file:%s %s\n", + (m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str(), + (m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str()); + return false; + } + + SizeType val = static_cast(selected.size()); + if (output->WriteBinary(sizeof(val), reinterpret_cast(&val)) != sizeof(val)) + { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return false; + } + if (outputIDs->WriteBinary(sizeof(val), reinterpret_cast(&val)) != sizeof(val)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write outputID file!\n"); + return false; + } + DimensionType dt = data.C(); + if (output->WriteBinary(sizeof(dt), reinterpret_cast(&dt)) != sizeof(dt)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return false; + } + dt = 1; + if (outputIDs->WriteBinary(sizeof(dt), reinterpret_cast(&dt)) != sizeof(dt)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write outputID file!\n"); + return false; + } + for (int i = 0; i < selected.size(); i++) + { + uint64_t vid = static_cast(selected[i]); + if (outputIDs->WriteBinary(sizeof(vid), reinterpret_cast(&vid)) != sizeof(vid)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return false; } - else { - ByteArray arr = ByteArray::Alloc(sizeof(T) * p_vectorNum * p_dimension); - memcpy(arr.Data(), p_data, sizeof(T) * p_vectorNum * p_dimension); - vectorSet.reset(new BasicVectorSet(arr, GetEnumValueType(), p_dimension, p_vectorNum)); + + if (output->WriteBinary(sizeof(InternalDataType) * data.C(), (char *)(data[vid])) != + sizeof(InternalDataType) * data.C()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write output file!\n"); + return false; } + } + } + auto t3 = std::chrono::high_resolution_clock::now(); + double elapsedSeconds = std::chrono::duration_cast(t3 - t1).count(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Total used time: %.2lf minutes (about %.2lf hours).\n", + elapsedSeconds / 60.0, elapsedSeconds / 3600.0); + return true; +} + +template ErrorCode Index::BuildIndexInternal(std::shared_ptr &p_reader) +{ + if (!(m_options.m_indexDirectory.empty()) && !(direxists(m_options.m_indexDirectory.c_str()))) + { + mkdir(m_options.m_indexDirectory.c_str()); + } + if 
(!(m_options.m_persistentBufferPath.empty()) && !(direxists(m_options.m_persistentBufferPath.c_str()))) + { + mkdir(m_options.m_persistentBufferPath.c_str()); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Select Head...\n"); + auto t1 = std::chrono::high_resolution_clock::now(); + if (m_options.m_selectHead) + { + bool success = false; + if (m_pQuantizer) + { + success = SelectHeadInternal(p_reader); + } + else + { + success = SelectHeadInternal(p_reader); + } + if (!success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SelectHead Failed!\n"); + return ErrorCode::Fail; + } + } + auto t2 = std::chrono::high_resolution_clock::now(); + double selectHeadTime = std::chrono::duration_cast(t2 - t1).count(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "select head time: %.2lfs\n", selectHeadTime); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Build Head...\n"); + if (m_options.m_buildHead) + { + auto valueType = m_pQuantizer ? SPTAG::VectorValueType::UInt8 : m_options.m_valueType; + auto dims = m_pQuantizer ? 
m_pQuantizer->GetNumSubvectors() : m_options.m_dim; - if (m_options.m_distCalcMethod == DistCalcMethod::Cosine && !p_normalized) { - vectorSet->Normalize(m_options.m_iSSDNumberOfThreads); + m_index = SPTAG::VectorIndex::CreateInstance(m_options.m_indexAlgoType, valueType); + m_index->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(m_options.m_distCalcMethod)); + m_index->SetQuantizer(m_pQuantizer); + for (const auto &iter : m_headParameters) + { + m_index->SetParameter(iter.first.c_str(), iter.second.c_str()); + } + + std::shared_ptr vectorOptions( + new Helper::ReaderOptions(valueType, dims, VectorFileType::DEFAULT)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); + if (ErrorCode::Success != + vectorReader->LoadFile(m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read head vector file.\n"); + return ErrorCode::Fail; + } + { + auto headvectorset = vectorReader->GetVectorSet(); + if (m_index->BuildIndex(headvectorset, nullptr, false, true, true) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to build head index.\n"); + return ErrorCode::Fail; } - SPTAG::VectorValueType valueType = m_pQuantizer ? 
SPTAG::VectorValueType::UInt8 : m_options.m_valueType; - std::shared_ptr vectorReader(new Helper::MemoryVectorReader(std::make_shared(valueType, p_dimension, VectorFileType::DEFAULT, m_options.m_vectorDelimiter, m_options.m_iSSDNumberOfThreads, true), - vectorSet)); - - m_options.m_valueType = GetEnumValueType(); - m_options.m_dim = p_dimension; - m_options.m_vectorSize = p_vectorNum; - return BuildIndexInternal(vectorReader); + if (!m_options.m_quantizerFilePath.empty()) + m_index->SetQuantizerFileName( + m_options.m_quantizerFilePath.substr(m_options.m_quantizerFilePath.find_last_of("/\\") + 1)); + if (m_index->SaveIndex(m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder) != + ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to save head index.\n"); + return ErrorCode::Fail; + } + } + m_index.reset(); + if (LoadIndex(m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder, m_index) != + ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot load head index from %s!\n", + (m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder).c_str()); } + } + auto t3 = std::chrono::high_resolution_clock::now(); + double buildHeadTime = std::chrono::duration_cast(t3 - t2).count(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "select head time: %.2lfs build head time: %.2lfs\n", selectHeadTime, + buildHeadTime); - template - ErrorCode - Index::UpdateIndex() + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Build SSDIndex...\n"); + if (m_options.m_enableSSD) + { + if (m_index == nullptr && LoadIndex(m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder, + m_index) != ErrorCode::Success) { - omp_set_num_threads(m_options.m_iSSDNumberOfThreads); - m_index->UpdateIndex(); - m_workSpacePool.reset(new COMMON::WorkSpacePool()); - m_workSpacePool->Init(m_options.m_iSSDNumberOfThreads, m_options.m_maxCheck, m_options.m_hashExp, m_options.m_searchInternalResultNum, 
max(m_options.m_postingPageLimit, m_options.m_searchPostingPageLimit + 1) << PageSizeEx, int(m_options.m_enableDataCompression)); - return ErrorCode::Success; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot load head index from %s!\n", + (m_options.m_indexDirectory + FolderSep + m_options.m_headIndexFolder).c_str()); + return ErrorCode::Fail; } + m_index->SetQuantizer(m_pQuantizer); + if (!CheckHeadIndexType()) + return ErrorCode::Fail; - template - ErrorCode - Index::SetParameter(const char* p_param, const char* p_value, const char* p_section) + m_index->SetParameter("NumberOfThreads", std::to_string(m_options.m_iSSDNumberOfThreads)); + m_index->SetParameter("MaxCheck", std::to_string(m_options.m_maxCheck)); + m_index->SetParameter("HashTableExponent", std::to_string(m_options.m_hashExp)); + + m_index->UpdateIndex(); + + if (m_options.m_storage == Storage::STATIC) + { + if (m_pQuantizer) + m_extraSearcher.reset(new ExtraStaticSearcher()); + else + m_extraSearcher.reset(new ExtraStaticSearcher()); + } + else { - if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_section, "BuildHead") && !SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "isExecute")) { - if (m_index != nullptr) return m_index->SetParameter(p_param, p_value); - else m_headParameters[p_param] = p_value; + if (m_pQuantizer) + m_extraSearcher.reset(new ExtraDynamicSearcher(m_options)); + else + m_extraSearcher.reset(new ExtraDynamicSearcher(m_options)); + } + + { + std::shared_ptr ptr = SPTAG::f_createIO(); + if (ptr == nullptr || + !ptr->Initialize((m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str(), + std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to open headIDFile file:%s\n", + (m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile).c_str()); + return ErrorCode::Fail; } - else { - m_options.SetParameter(p_section, p_param, p_value); + m_vectorTranslateMap.Load(ptr, m_index->m_iDataBlockSize, m_index->m_iDataCapacity); + 
} + + if (m_options.m_buildSsdIndex) + { + if (m_options.m_storage != Storage::STATIC && !m_extraSearcher->Available()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Extrasearcher is not available and failed to initialize.\n"); + return ErrorCode::Fail; } - if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "DistCalcMethod")) { - if (m_pQuantizer) - { - m_fComputeDistance = m_pQuantizer->DistanceCalcSelector(m_options.m_distCalcMethod); - m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() : 1; + if (!m_extraSearcher->BuildIndex(p_reader, m_index, m_options, m_versionMap, m_vectorTranslateMap)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "BuildSSDIndex Failed!\n"); + return ErrorCode::Fail; + } + + if (!m_options.m_excludehead) + { + std::uint64_t vid = (std::uint64_t)MaxSize; + for (int i = 0; i < m_vectorTranslateMap.R(); i++) { + *(m_vectorTranslateMap[i]) = vid; } - else + m_vectorTranslateMap.Save(m_options.m_indexDirectory + FolderSep + m_options.m_headIDFile); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Include all vectors into SSD index...\n"); + } + } + + if (!m_extraSearcher->LoadIndex(m_options, m_versionMap, m_vectorTranslateMap, m_index)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot Load SSDIndex!\n"); + if (m_options.m_buildSsdIndex) + { + return ErrorCode::Fail; + } + else + { + m_extraSearcher.reset(); + } + } + + if (m_extraSearcher != nullptr) + { + if ((m_options.m_storage != Storage::STATIC) && m_options.m_preReassign) + { + if (m_extraSearcher->RefineIndex(m_index) != ErrorCode::Success) + return ErrorCode::Fail; + } + } + } + + auto t4 = std::chrono::high_resolution_clock::now(); + double buildSSDTime = std::chrono::duration_cast(t4 - t3).count(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "select head time: %.2lfs build head time: %.2lfs build ssd time: %.2lfs\n", + selectHeadTime, buildHeadTime, buildSSDTime); + + if (m_options.m_deleteHeadVectors) + { 
+ if (fileexists((m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str()) && + remove((m_options.m_indexDirectory + FolderSep + m_options.m_headVectorFile).c_str()) != 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Head vector file can't be removed.\n"); + } + } + + m_bReady = true; + return ErrorCode::Success; +} +template ErrorCode Index::BuildIndex(bool p_normalized) +{ + SPTAG::VectorValueType valueType = m_pQuantizer ? SPTAG::VectorValueType::UInt8 : m_options.m_valueType; + SizeType dim = m_pQuantizer ? m_pQuantizer->GetNumSubvectors() : m_options.m_dim; + std::shared_ptr vectorOptions( + new Helper::ReaderOptions(valueType, dim, m_options.m_vectorType, m_options.m_vectorDelimiter, + m_options.m_iSSDNumberOfThreads, p_normalized)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); + if (m_options.m_vectorPath.empty()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Vector file is empty. Skipping loading.\n"); + } + else + { + if (ErrorCode::Success != vectorReader->LoadFile(m_options.m_vectorPath)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + return ErrorCode::Fail; + } + m_options.m_vectorSize = vectorReader->GetVectorSet()->Count(); + } + return BuildIndexInternal(vectorReader); +} + +template +ErrorCode Index::BuildIndex(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, bool p_normalized, + bool p_shareOwnership) +{ + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; + + std::shared_ptr vectorSet; + if (p_shareOwnership) + { + vectorSet.reset( + new BasicVectorSet(ByteArray((std::uint8_t *)p_data, sizeof(T) * p_vectorNum * p_dimension, false), + GetEnumValueType(), p_dimension, p_vectorNum)); + } + else + { + ByteArray arr = ByteArray::Alloc(sizeof(T) * p_vectorNum * p_dimension); + memcpy(arr.Data(), p_data, sizeof(T) * p_vectorNum * p_dimension); + vectorSet.reset(new BasicVectorSet(arr, 
GetEnumValueType(), p_dimension, p_vectorNum)); + } + + if (m_options.m_distCalcMethod == DistCalcMethod::Cosine && !p_normalized) + { + vectorSet->Normalize(m_options.m_iSSDNumberOfThreads); + } + SPTAG::VectorValueType valueType = m_pQuantizer ? SPTAG::VectorValueType::UInt8 : m_options.m_valueType; + std::shared_ptr vectorReader(new Helper::MemoryVectorReader( + std::make_shared(valueType, p_dimension, VectorFileType::DEFAULT, + m_options.m_vectorDelimiter, m_options.m_iSSDNumberOfThreads, true), + vectorSet)); + + m_options.m_valueType = GetEnumValueType(); + m_options.m_dim = p_dimension; + m_options.m_vectorSize = p_vectorNum; + return BuildIndexInternal(vectorReader); +} + +template ErrorCode Index::UpdateIndex() +{ + m_index->SetParameter("NumberOfThreads", std::to_string(m_options.m_iSSDNumberOfThreads)); + // m_index->SetParameter("MaxCheck", std::to_string(m_options.m_maxCheck)); + // m_index->SetParameter("HashTableExponent", std::to_string(m_options.m_hashExp)); + m_index->UpdateIndex(); + return ErrorCode::Success; +} + +template +ErrorCode Index::RefineIndex(const std::vector> &p_indexStreams, + IAbortOperation *p_abort, std::vector *p_mapping) +{ + if (m_index == nullptr || m_versionMap.Count() == 0) + return ErrorCode::EmptyIndex; + + while (!AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + + std::lock_guard lock(m_dataAddLock); + std::unique_lock uniquelock(m_dataDeleteLock); + + std::vector headOldtoNew; + ErrorCode ret; + + if ((ret = m_index->RefineIndex(p_indexStreams, nullptr, &headOldtoNew)) != ErrorCode::Success) + return ret; + + std::vector OldtoNew; + std::vector NewtoOld; + SizeType newR = m_versionMap.Count(); + if (p_mapping == nullptr) p_mapping = &OldtoNew; + p_mapping->resize(newR); + + for (SizeType i = 0; i < newR; i++) + { + if (!m_versionMap.Deleted(i)) + { + NewtoOld.push_back(i); + (*p_mapping)[i] = i; + } + else + { + while (m_versionMap.Deleted(newR - 1) && newR > i) + newR--; + if 
(newR == i) + break; + NewtoOld.push_back(newR - 1); + (*p_mapping)[newR - 1] = i; + newR--; + } + } + + COMMON::Dataset new_vectorTranslateMap(m_index->GetNumSamples() - m_index->GetNumDeleted(), 1, + m_index->m_iDataBlockSize, m_index->m_iDataCapacity); + for (int i = 0; i < m_vectorTranslateMap.R(); i++) + { + if (m_index->ContainSample(i)) + { + auto oldID = *(m_vectorTranslateMap[i]); + if (oldID == MaxSize) + { + // Special case: including head vectors in postings means map all IDs to MaxSize + *(new_vectorTranslateMap[headOldtoNew[i]]) = oldID; + } + else if (oldID >= p_mapping->size()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPANNIndex::RefineIndex: Vector %d with old ID %llu cannot be mapped! p_mapping size: %llu.\n", i, oldID, p_mapping->size()); + } + else { + if (m_versionMap.Deleted(oldID)) { - m_fComputeDistance = COMMON::DistanceCalcSelector(m_options.m_distCalcMethod); - m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) ? COMMON::Utils::GetBase() * COMMON::Utils::GetBase() : 1; + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "SPANNIndex::RefineIndex: Vector %d with old ID %llu is deleted in disk index, but still in head index!\n", i, oldID); } + *(new_vectorTranslateMap[headOldtoNew[i]]) = (*p_mapping)[oldID]; } - return ErrorCode::Success; } + } + new_vectorTranslateMap.Save(p_indexStreams[m_index->GetIndexFiles()->size()]); + + COMMON::VersionLabel new_versionMap; + new_versionMap.Initialize(newR, m_index->m_iDataBlockSize, m_index->m_iDataCapacity); + for (SizeType i = 0; i < newR; i++) + new_versionMap.SetVersion(i, m_versionMap.GetVersion(NewtoOld[i])); + new_versionMap.Save(m_options.m_indexDirectory + FolderSep + m_options.m_deleteIDFile); + + if ((ret = m_extraSearcher->RefineIndex(m_index, false, &headOldtoNew, p_mapping)) != ErrorCode::Success) + return ret; + + if (nullptr != m_pMetadata) + { + if (p_indexStreams.size() < GetIndexFiles()->size() + 2) + return ErrorCode::LackOfInputs; + if ((ret = 
m_pMetadata->RefineMetadata(NewtoOld, p_indexStreams[GetIndexFiles()->size()], + p_indexStreams[GetIndexFiles()->size()+1])) != + ErrorCode::Success) + return ret; + } + return ret; +} + +template ErrorCode Index::SetParameter(const char *p_param, const char *p_value, const char *p_section) +{ + if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_section, "BuildHead") && + !SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "isExecute")) + { + if (m_index != nullptr) + return m_index->SetParameter(p_param, p_value); + else + m_headParameters[p_param] = p_value; + } + else + { + m_options.SetParameter(p_section, p_param, p_value); + } + if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "DistCalcMethod")) + { + if (m_pQuantizer) + { + m_fComputeDistance = m_pQuantizer->DistanceCalcSelector(m_options.m_distCalcMethod); + m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) + ? m_pQuantizer->GetBase() * m_pQuantizer->GetBase() + : 1; + } + else + { + m_fComputeDistance = COMMON::DistanceCalcSelector(m_options.m_distCalcMethod); + m_iBaseSquare = (m_options.m_distCalcMethod == DistCalcMethod::Cosine) + ? 
COMMON::Utils::GetBase() * COMMON::Utils::GetBase() + : 1; + } + } + return ErrorCode::Success; +} + +template std::string Index::GetParameter(const char *p_param, const char *p_section) const +{ + if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_section, "BuildHead") && + !SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "isExecute")) + { + if (m_index != nullptr) + return m_index->GetParameter(p_param); + else + { + auto iter = m_headParameters.find(p_param); + if (iter != m_headParameters.end()) + return iter->second; + return "Undefined!"; + } + } + else + { + return m_options.GetParameter(p_section, p_param); + } +} + +// Add insert entry to persistent buffer +template +ErrorCode Index::AddIndex(const void *p_data, SizeType p_vectorNum, DimensionType p_dimension, + std::shared_ptr p_metadataSet, bool p_withMetaIndex, bool p_normalized) +{ + if ((m_options.m_storage == Storage::STATIC) || m_extraSearcher == nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Only Support KV Extra Update\n"); + return ErrorCode::Fail; + } + + if (p_data == nullptr || p_vectorNum == 0 || p_dimension == 0) + return ErrorCode::EmptyData; + if (p_dimension != GetFeatureDim()) + return ErrorCode::DimensionSizeMismatch; + SizeType begin, end; + { + std::lock_guard lock(m_dataAddLock); + + begin = m_versionMap.GetVectorNum(); + end = begin + p_vectorNum; - template - std::string - Index::GetParameter(const char* p_param, const char* p_section) const + if (begin == 0) { - if (SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_section, "BuildHead") && !SPTAG::Helper::StrUtils::StrEqualIgnoreCase(p_param, "isExecute")) { - if (m_index != nullptr) return m_index->GetParameter(p_param); - else { - auto iter = m_headParameters.find(p_param); - if (iter != m_headParameters.end()) return iter->second; - return "Undefined!"; + return ErrorCode::EmptyIndex; + } + + if (m_versionMap.AddBatch(p_vectorNum) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "MemoryOverFlow: 
VID: %d, Map Size:%d\n", begin, + m_versionMap.BufferSize()); + return ErrorCode::MemoryOverFlow; + } + + if (m_pMetadata != nullptr) + { + if (p_metadataSet != nullptr) + { + m_pMetadata->AddBatch(*p_metadataSet); + if (HasMetaMapping()) + { + for (SizeType i = begin; i < end; i++) + { + ByteArray meta = m_pMetadata->GetMetadata(i); + std::string metastr((char *)meta.Data(), meta.Length()); + UpdateMetaMapping(metastr, i); + } } } - else { - return m_options.GetParameter(p_section, p_param); + else + { + for (SizeType i = begin; i < end; i++) + m_pMetadata->Add(ByteArray::c_empty); + } + } + } + + std::shared_ptr vectorSet; + if (m_options.m_distCalcMethod == DistCalcMethod::Cosine && !p_normalized) + { + ByteArray arr = ByteArray::Alloc(sizeof(T) * p_vectorNum * p_dimension); + memcpy(arr.Data(), p_data, sizeof(T) * p_vectorNum * p_dimension); + vectorSet.reset(new BasicVectorSet(arr, GetEnumValueType(), p_dimension, p_vectorNum)); + int base = COMMON::Utils::GetBase(); + for (SizeType i = 0; i < p_vectorNum; i++) + { + COMMON::Utils::Normalize((T *)(vectorSet->GetVector(i)), p_dimension, base); + } + } + else + { + vectorSet.reset( + new BasicVectorSet(ByteArray((std::uint8_t *)p_data, sizeof(T) * p_vectorNum * p_dimension, false), + GetEnumValueType(), p_dimension, p_vectorNum)); + } + + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else + { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } + workSpace->m_deduper.clear(); + workSpace->m_postingIDs.clear(); + return m_extraSearcher->AddIndex(workSpace.get(), vectorSet, m_index, begin); +} + +template +ErrorCode Index::Check() +{ + std::atomic ret = ErrorCode::Success; + //if ((ret = m_index->Check()) != ErrorCode::Success) + // return ret; + + std::vector mythreads; + mythreads.reserve(m_options.m_iSSDNumberOfThreads); + std::atomic_size_t sent(0); + std::vector 
checked(m_extraSearcher->GetNumBlocks(), false); + for (int tid = 0; tid < m_options.m_iSSDNumberOfThreads; tid++) + { + mythreads.emplace_back([&, tid]() { + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else + { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); } + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < m_index->GetNumSamples()) + { + if (m_index->ContainSample(i)) + { + if (m_extraSearcher->CheckPosting(i, &checked, workSpace.get()) != ErrorCode::Success) + { + ret = ErrorCode::Fail; + return; + } + + auto translatedID = *(m_vectorTranslateMap[i]); + if (translatedID >= m_versionMap.Count() && translatedID != MaxSize) + { + ret = ErrorCode::Fail; + return; + } + } + } + else + { + return; + } + } + m_workSpaceFactory->ReturnWorkSpace(std::move(workSpace)); + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + return ret.load(); +} + +template ErrorCode Index::DeleteIndex(const SizeType &p_id) +{ + std::shared_lock sharedlock(m_dataDeleteLock); + // if (m_versionMap.Delete(p_id)) return ErrorCode::Success; + // return ErrorCode::VectorNotFound; + return m_extraSearcher->DeleteIndex(p_id); +} + +template ErrorCode Index::DeleteIndex(const void *p_vectors, SizeType p_vectorNum) +{ + // TODO: Support batch delete + DimensionType p_dimension = GetFeatureDim(); + std::shared_ptr vectorSet; + if (m_options.m_distCalcMethod == DistCalcMethod::Cosine) + { + ByteArray arr = ByteArray::Alloc(sizeof(T) * p_vectorNum * p_dimension); + memcpy(arr.Data(), p_vectors, sizeof(T) * p_vectorNum * p_dimension); + vectorSet.reset(new BasicVectorSet(arr, GetEnumValueType(), p_dimension, p_vectorNum)); + int base = COMMON::Utils::GetBase(); + for (SizeType i = 0; i < p_vectorNum; i++) + { + COMMON::Utils::Normalize((T *)(vectorSet->GetVector(i)), p_dimension, base); } } + else + { + 
vectorSet.reset(new BasicVectorSet(ByteArray((std::uint8_t *)p_vectors, sizeof(T) * p_vectorNum * p_dimension, false), + GetEnumValueType(), p_dimension, p_vectorNum)); + } + + auto workSpace = m_workSpaceFactory->GetWorkSpace(); + if (!workSpace) + { + workSpace.reset(new ExtraWorkSpace()); + m_extraSearcher->InitWorkSpace(workSpace.get(), false); + } + else + { + m_extraSearcher->InitWorkSpace(workSpace.get(), true); + } + workSpace->m_deduper.clear(); + workSpace->m_postingIDs.clear(); + + SizeType p_id = m_extraSearcher->SearchVector(workSpace.get(), vectorSet, m_index); + if (p_id == -1) + return ErrorCode::VectorNotFound; + return DeleteIndex(p_id); } +} // namespace SPANN +} // namespace SPTAG -#define DefineVectorValueType(Name, Type) \ -template class SPTAG::SPANN::Index; \ +#define DefineVectorValueType(Name, Type) template class SPTAG::SPANN::Index; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - - diff --git a/AnnService/src/Core/SPANNResultIterator.cpp b/AnnService/src/Core/SPANNResultIterator.cpp new file mode 100644 index 000000000..e2b09f0af --- /dev/null +++ b/AnnService/src/Core/SPANNResultIterator.cpp @@ -0,0 +1,68 @@ +/* +#include "inc/Core/SPANNResultIterator.h" + +namespace SPTAG +{ + namespace SPANN + { + template + SPANNResultIterator::SPANNResultIterator(const Index* p_index, const void* p_target, + std::shared_ptr headWorkspace, + std::shared_ptr extraWorkspace, + int batch) + :m_index(p_index), + m_target(p_target), + m_headWorkspace(headWorkspace), + m_extraWorkspace(extraWorkspace), + m_batch(batch) + { + m_headQueryResult = std::make_unique(p_target, batch, false); + m_queryResult = std::make_unique(p_target, 1, true); + m_isFirstResult = true; + } + + template + SPANNResultIterator::~SPANNResultIterator() + { + if (m_index != nullptr && m_headWorkspace != nullptr && m_extraWorkspace != nullptr) { + m_index->SearchIndexIterativeEnd(m_headWorkspace, m_extraWorkspace); + } + m_headQueryResult = nullptr; + 
m_queryResult = nullptr; + } + + template + bool SPANNResultIterator::Next(BasicResult& result) + { + m_queryResult->Reset(); + m_index->SearchIndexIterativeNext(*m_headQueryResult, *m_queryResult, m_headWorkspace, m_extraWorkspace, +m_isFirstResult); m_isFirstResult = false; if (m_queryResult->GetResult(0) == nullptr || +m_queryResult->GetResult(0)->VID < 0) + { + return false; + } + result.VID = m_queryResult->GetResult(0)->VID; + result.Dist = m_queryResult->GetResult(0)->Dist; + result.Meta = m_queryResult->GetResult(0)->Meta; + return true; + } + + // Add end into destructor. + template + void SPANNResultIterator::Close() + { + if (m_headWorkspace != nullptr && m_extraWorkspace != nullptr) { + m_index->SearchIndexIterativeEnd(m_headWorkspace, m_extraWorkspace); + m_headWorkspace = nullptr; + m_extraWorkspace = nullptr; + } + } + + template + QueryResult* SPANNResultIterator::GetQuery() const + { + return m_queryResult.get(); + } + } // namespace SPANN +} // namespace SPTAG +*/ diff --git a/AnnService/src/Core/VectorIndex.cpp b/AnnService/src/Core/VectorIndex.cpp index f46523156..e329b4051 100644 --- a/AnnService/src/Core/VectorIndex.cpp +++ b/AnnService/src/Core/VectorIndex.cpp @@ -1,963 +1,1255 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "inc/Core/VectorIndex.h" -#include "inc/Helper/CommonHelper.h" -#include "inc/Helper/StringConvert.h" -#include "inc/Helper/SimpleIniReader.h" -#include "inc/Helper/ConcurrentSet.h" - -#include "inc/Core/BKT/Index.h" -#include "inc/Core/KDT/Index.h" -#include "inc/Core/SPANN/Index.h" - -typedef typename SPTAG::Helper::Concurrent::ConcurrentMap MetadataMap; - -using namespace SPTAG; - -#ifdef DEBUG -std::shared_ptr SPTAG::g_pLogger(new Helper::SimpleLogger(Helper::LogLevel::LL_Debug)); -#else -std::shared_ptr SPTAG::g_pLogger(new Helper::SimpleLogger(Helper::LogLevel::LL_Info)); -#endif - -std::mt19937 SPTAG::rg; - -std::shared_ptr(*SPTAG::f_createIO)() = []() -> std::shared_ptr { return std::shared_ptr(new Helper::SimpleFileIO()); }; - -namespace SPTAG { - - bool copyfile(const char* oldpath, const char* newpath) { - auto input = f_createIO(), output = f_createIO(); - if (input == nullptr || !input->Initialize(oldpath, std::ios::binary | std::ios::in) || - output == nullptr || !output->Initialize(newpath, std::ios::binary | std::ios::out)) - { - LOG(Helper::LogLevel::LL_Error, "Unable to open files: %s %s\n", oldpath, newpath); - return false; - } - - const std::size_t bufferSize = 1 << 30; - std::unique_ptr bufferHolder(new char[bufferSize]); - - std::uint64_t readSize; - while ((readSize = input->ReadBinary(bufferSize, bufferHolder.get()))) { - if (output->WriteBinary(readSize, bufferHolder.get()) != readSize) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", newpath); - return false; - } - } - input->ShutDown(); output->ShutDown(); - return true; - } - -#ifndef _MSC_VER - void listdir(std::string path, std::vector& files) { - if (auto dirptr = opendir(path.substr(0, path.length() - 1).c_str())) { - while (auto f = readdir(dirptr)) { - if (!f->d_name || f->d_name[0] == '.') continue; - std::string tmp = path.substr(0, path.length() - 1); - tmp += std::string(f->d_name); - if (f->d_type == DT_DIR) { - listdir(tmp + FolderSep + "*", 
files); - } - else { - files.push_back(tmp); - } - } - closedir(dirptr); - } - } -#else - void listdir(std::string path, std::vector& files) { - WIN32_FIND_DATA fd; - HANDLE hFile = FindFirstFile(path.c_str(), &fd); - if (hFile != INVALID_HANDLE_VALUE) { - do { - std::string tmp = path.substr(0, path.length() - 1); - tmp += std::string(fd.cFileName); - if (fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) { - if (fd.cFileName[0] != '.') { - listdir(tmp + FolderSep + "*", files); - } - } - else { - files.push_back(tmp); - } - } while (FindNextFile(hFile, &fd)); - FindClose(hFile); - } - } -#endif -} - -VectorIndex::VectorIndex() -{ -} - - -VectorIndex::~VectorIndex() -{ -} - - -std::string -VectorIndex::GetParameter(const std::string& p_param, const std::string& p_section) const -{ - return GetParameter(p_param.c_str(), p_section.c_str()); -} - - -ErrorCode -VectorIndex::SetParameter(const std::string& p_param, const std::string& p_value, const std::string& p_section) -{ - return SetParameter(p_param.c_str(), p_value.c_str(), p_section.c_str()); -} - - -void -VectorIndex::SetMetadata(MetadataSet* p_new) { - m_pMetadata.reset(p_new); -} - - -MetadataSet* -VectorIndex::GetMetadata() const { - return m_pMetadata.get(); -} - - -ByteArray -VectorIndex::GetMetadata(SizeType p_vectorID) const { - if (nullptr != m_pMetadata) - { - return m_pMetadata->GetMetadata(p_vectorID); - } - return ByteArray::c_empty; -} - - -std::shared_ptr> VectorIndex::CalculateBufferSize() const -{ - std::shared_ptr> ret = BufferSize(); - - if (m_pMetadata != nullptr) - { - auto metasize = m_pMetadata->BufferSize(); - ret->push_back(metasize.first); - ret->push_back(metasize.second); - } - - if (m_pQuantizer) - { - ret->push_back(m_pQuantizer->BufferSize()); - } - return std::move(ret); -} - - -ErrorCode -VectorIndex::LoadIndexConfig(Helper::IniReader& p_reader) -{ - std::string metadataSection("MetaData"); - if (p_reader.DoesSectionExist(metadataSection)) - { - m_sMetadataFile = 
p_reader.GetParameter(metadataSection, "MetaDataFilePath", std::string()); - m_sMetadataIndexFile = p_reader.GetParameter(metadataSection, "MetaDataIndexPath", std::string()); - } - - std::string quantizerSection("Quantizer"); - if (p_reader.DoesSectionExist(quantizerSection)) - { - m_sQuantizerFile = p_reader.GetParameter(quantizerSection, "QuantizerFilePath", std::string()); - } - return LoadConfig(p_reader); -} - - -ErrorCode -VectorIndex::SaveIndexConfig(std::shared_ptr p_configOut) -{ - if (nullptr != m_pMetadata) - { - IOSTRING(p_configOut, WriteString, "[MetaData]\n"); - IOSTRING(p_configOut, WriteString, ("MetaDataFilePath=" + m_sMetadataFile + "\n").c_str()); - IOSTRING(p_configOut, WriteString, ("MetaDataIndexPath=" + m_sMetadataIndexFile + "\n").c_str()); - if (nullptr != m_pMetaToVec) IOSTRING(p_configOut, WriteString, "MetaDataToVectorIndex=true\n"); - IOSTRING(p_configOut, WriteString, "\n"); - } - - if (m_pQuantizer) - { - IOSTRING(p_configOut, WriteString, "[Quantizer]\n"); - IOSTRING(p_configOut, WriteString, ("QuantizerFilePath=" + m_sQuantizerFile + "\n").c_str()); - IOSTRING(p_configOut, WriteString, "\n"); - } - - IOSTRING(p_configOut, WriteString, "[Index]\n"); - IOSTRING(p_configOut, WriteString, ("IndexAlgoType=" + Helper::Convert::ConvertToString(GetIndexAlgoType()) + "\n").c_str()); - IOSTRING(p_configOut, WriteString, ("ValueType=" + Helper::Convert::ConvertToString(GetVectorValueType()) + "\n").c_str()); - IOSTRING(p_configOut, WriteString, "\n"); - - return SaveConfig(p_configOut); -} - - -SizeType -VectorIndex::GetMetaMapping(std::string& meta) const -{ - MetadataMap* ptr = static_cast(m_pMetaToVec.get()); - auto iter = ptr->find(meta); - if (iter != ptr->end()) return iter->second; - return -1; -} - - -void -VectorIndex::UpdateMetaMapping(const std::string& meta, SizeType i) -{ - MetadataMap* ptr = static_cast(m_pMetaToVec.get()); - auto iter = ptr->find(meta); - if (iter != ptr->end()) DeleteIndex(iter->second);; - (*ptr)[meta] = i; 
-} - - -void -VectorIndex::BuildMetaMapping(bool p_checkDeleted) -{ - MetadataMap* ptr = new MetadataMap(m_iDataBlockSize); - for (SizeType i = 0; i < m_pMetadata->Count(); i++) { - if (!p_checkDeleted || ContainSample(i)) { - ByteArray meta = m_pMetadata->GetMetadata(i); - (*ptr)[std::string((char*)meta.Data(), meta.Length())] = i; - } - } - m_pMetaToVec.reset(ptr, std::default_delete()); -} - - -ErrorCode -VectorIndex::SaveIndex(std::string& p_config, const std::vector& p_indexBlobs) -{ - if (!m_bReady || GetNumSamples() - GetNumDeleted() == 0) return ErrorCode::EmptyIndex; - - ErrorCode ret = ErrorCode::Success; - { - std::shared_ptr p_configStream(new Helper::SimpleBufferIO()); - if (p_configStream == nullptr || !p_configStream->Initialize(nullptr, std::ios::out)) return ErrorCode::EmptyDiskIO; - if ((ret = SaveIndexConfig(p_configStream)) != ErrorCode::Success) return ret; - p_config.resize(p_configStream->TellP()); - IOBINARY(p_configStream, ReadBinary, p_config.size(), (char*)p_config.c_str(), 0); - } - - std::vector> p_indexStreams; - for (size_t i = 0; i < p_indexBlobs.size(); i++) - { - std::shared_ptr ptr(new Helper::SimpleBufferIO()); - if (ptr == nullptr || !ptr->Initialize((char*)p_indexBlobs[i].Data(), std::ios::binary | std::ios::out, p_indexBlobs[i].Length())) return ErrorCode::EmptyDiskIO; - p_indexStreams.push_back(std::move(ptr)); - } - - size_t metaStart = BufferSize()->size(); - if (NeedRefine()) - { - ret = RefineIndex(p_indexStreams, nullptr); - } - else - { - if (m_pMetadata != nullptr && p_indexStreams.size() >= metaStart + 2) - { - - ret = m_pMetadata->SaveMetadata(p_indexStreams[metaStart], p_indexStreams[metaStart + 1]); - } - if (ErrorCode::Success == ret) ret = SaveIndexData(p_indexStreams); - } - if (m_pMetadata != nullptr) metaStart += 2; - - if (ErrorCode::Success == ret && m_pQuantizer && p_indexStreams.size() > metaStart) { - ret = m_pQuantizer->SaveQuantizer(p_indexStreams[metaStart]); - } - return ret; -} - - -ErrorCode 
-VectorIndex::SaveIndex(const std::string& p_folderPath) -{ - if (!m_bReady || GetNumSamples() - GetNumDeleted() == 0) return ErrorCode::EmptyIndex; - - std::string folderPath(p_folderPath); - if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) - { - folderPath += FolderSep; - } - if (!direxists(folderPath.c_str())) - { - mkdir(folderPath.c_str()); - } - - if (GetIndexAlgoType() == IndexAlgoType::SPANN && GetParameter("IndexDirectory", "Base") != p_folderPath) { - std::vector files; - std::string oldFolder = GetParameter("IndexDirectory", "Base"); - if (!oldFolder.empty() && *(oldFolder.rbegin()) != FolderSep) oldFolder += FolderSep; - listdir((oldFolder + "*").c_str(), files); - for (auto file : files) { - size_t firstSep = oldFolder.length(), lastSep = file.find_last_of(FolderSep); - std::string newFolder = folderPath + ((lastSep > firstSep)? file.substr(firstSep, lastSep - firstSep) : ""), filename = file.substr(lastSep + 1); - if (!direxists(newFolder.c_str())) mkdir(newFolder.c_str()); - LOG(Helper::LogLevel::LL_Info, "Copy file %s to %s...\n", file.c_str(), (newFolder + FolderSep + filename).c_str()); - if (!copyfile(file.c_str(), (newFolder + FolderSep + filename).c_str())) - return ErrorCode::DiskIOFail; - } - SetParameter("IndexDirectory", p_folderPath, "Base"); - } - - ErrorCode ret = ErrorCode::Success; - { - auto configFile = SPTAG::f_createIO(); - if (configFile == nullptr || !configFile->Initialize((folderPath + "indexloader.ini").c_str(), std::ios::out)) return ErrorCode::FailedCreateFile; - if ((ret = SaveIndexConfig(configFile)) != ErrorCode::Success) return ret; - } - - std::shared_ptr> indexfiles = GetIndexFiles(); - if (nullptr != m_pMetadata) { - indexfiles->push_back(m_sMetadataFile); - indexfiles->push_back(m_sMetadataIndexFile); - } - if (m_pQuantizer) { - indexfiles->push_back(m_sQuantizerFile); - } - std::vector> handles; - for (std::string& f : *indexfiles) { - std::string newfile = folderPath + f; - if 
(!direxists(newfile.substr(0, newfile.find_last_of(FolderSep)).c_str())) mkdir(newfile.substr(0, newfile.find_last_of(FolderSep)).c_str()); - - auto ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize(newfile.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; - handles.push_back(std::move(ptr)); - } - - size_t metaStart = GetIndexFiles()->size(); - if (NeedRefine()) - { - ret = RefineIndex(handles, nullptr); - } - else - { - if (m_pMetadata != nullptr) ret = m_pMetadata->SaveMetadata(handles[metaStart], handles[metaStart + 1]); - if (ErrorCode::Success == ret) ret = SaveIndexData(handles); - } - if (m_pMetadata != nullptr) metaStart += 2; - - if (ErrorCode::Success == ret && m_pQuantizer) { - ret = m_pQuantizer->SaveQuantizer(handles[metaStart]); - } - return ret; -} - - -ErrorCode -VectorIndex::SaveIndexToFile(const std::string& p_file, IAbortOperation* p_abort) -{ - if (!m_bReady || GetNumSamples() - GetNumDeleted() == 0) return ErrorCode::EmptyIndex; - - auto fp = SPTAG::f_createIO(); - if (fp == nullptr || !fp->Initialize(p_file.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; - - auto mp = std::shared_ptr(new Helper::SimpleBufferIO()); - if (mp == nullptr || !mp->Initialize(nullptr, std::ios::binary | std::ios::out)) return ErrorCode::FailedCreateFile; - ErrorCode ret = ErrorCode::Success; - if ((ret = SaveIndexConfig(mp)) != ErrorCode::Success) return ret; - - std::uint64_t configSize = mp->TellP(); - mp->ShutDown(); - - IOBINARY(fp, WriteBinary, sizeof(configSize), (char*)&configSize); - if ((ret = SaveIndexConfig(fp)) != ErrorCode::Success) return ret; - - if (p_abort != nullptr && p_abort->ShouldAbort()) ret = ErrorCode::ExternalAbort; - else { - std::uint64_t blobs = CalculateBufferSize()->size(); - IOBINARY(fp, WriteBinary, sizeof(blobs), (char*)&blobs); - std::vector> p_indexStreams(blobs, fp); - - if (NeedRefine()) - { - ret = RefineIndex(p_indexStreams, p_abort); - } - else - 
{ - ret = SaveIndexData(p_indexStreams); - - if (p_abort != nullptr && p_abort->ShouldAbort()) ret = ErrorCode::ExternalAbort; - - if (ErrorCode::Success == ret && m_pMetadata != nullptr) ret = m_pMetadata->SaveMetadata(fp, fp); - } - if (ErrorCode::Success == ret && m_pQuantizer) { - ret = m_pQuantizer->SaveQuantizer(fp); - } - } - fp->ShutDown(); - - if (ret != ErrorCode::Success) std::remove(p_file.c_str()); - return ret; -} - - -ErrorCode -VectorIndex::BuildIndex(std::shared_ptr p_vectorSet, - std::shared_ptr p_metadataSet, bool p_withMetaIndex, bool p_normalized, bool p_shareOwnership) -{ - LOG(Helper::LogLevel::LL_Info, "Begin build index...\n"); - - bool valueMatches = p_vectorSet->GetValueType() == GetVectorValueType(); - bool quantizerMatches = ((bool)m_pQuantizer) && (p_vectorSet->GetValueType() == SPTAG::VectorValueType::UInt8); - if (nullptr == p_vectorSet || !(valueMatches || quantizerMatches)) - { - return ErrorCode::Fail; - } - m_pMetadata = std::move(p_metadataSet); - if (p_withMetaIndex && m_pMetadata != nullptr) - { - LOG(Helper::LogLevel::LL_Info, "Build meta mapping...\n"); - BuildMetaMapping(false); - } - BuildIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension(), p_normalized, p_shareOwnership); - return ErrorCode::Success; -} - - -ErrorCode -VectorIndex::SearchIndex(const void* p_vector, int p_vectorCount, int p_neighborCount, bool p_withMeta, BasicResult* p_results) const { - size_t vectorSize = GetValueTypeSize(GetVectorValueType()) * GetFeatureDim(); -#pragma omp parallel for schedule(dynamic,10) - for (int i = 0; i < p_vectorCount; i++) { - QueryResult res((char*)p_vector + i * vectorSize, p_neighborCount, p_withMeta, p_results + i * p_neighborCount); - SearchIndex(res); - } - return ErrorCode::Success; -} - - -ErrorCode -VectorIndex::AddIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet, bool p_withMetaIndex, bool p_normalized) { - if (nullptr == p_vectorSet || p_vectorSet->GetValueType() != 
GetVectorValueType()) - { - return ErrorCode::Fail; - } - - return AddIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension(), p_metadataSet, p_withMetaIndex, p_normalized); -} - - -ErrorCode -VectorIndex::DeleteIndex(ByteArray p_meta) { - if (m_pMetaToVec == nullptr) return ErrorCode::VectorNotFound; - - std::string meta((char*)p_meta.Data(), p_meta.Length()); - SizeType vid = GetMetaMapping(meta); - if (vid >= 0) return DeleteIndex(vid); - return ErrorCode::VectorNotFound; -} - - -ErrorCode -VectorIndex::MergeIndex(VectorIndex* p_addindex, int p_threadnum, IAbortOperation* p_abort) -{ - ErrorCode ret = ErrorCode::Success; - if (p_addindex->m_pMetadata != nullptr) { -#pragma omp parallel for num_threads(p_threadnum) schedule(dynamic,128) - for (SizeType i = 0; i < p_addindex->GetNumSamples(); i++) - { - if (ret == ErrorCode::ExternalAbort) continue; - - if (p_addindex->ContainSample(i)) - { - ByteArray meta = p_addindex->GetMetadata(i); - std::uint64_t offsets[2] = { 0, meta.Length() }; - std::shared_ptr p_metaSet(new MemMetadataSet(meta, ByteArray((std::uint8_t*)offsets, sizeof(offsets), false), 1)); - AddIndex(p_addindex->GetSample(i), 1, p_addindex->GetFeatureDim(), p_metaSet); - } - - if (p_abort != nullptr && p_abort->ShouldAbort()) - { - ret = ErrorCode::ExternalAbort; - } - } - } - else { -#pragma omp parallel for num_threads(p_threadnum) schedule(dynamic,128) - for (SizeType i = 0; i < p_addindex->GetNumSamples(); i++) - { - if (ret == ErrorCode::ExternalAbort) continue; - - if (p_addindex->ContainSample(i)) - { - AddIndex(p_addindex->GetSample(i), 1, p_addindex->GetFeatureDim(), nullptr); - } - - if (p_abort != nullptr && p_abort->ShouldAbort()) - { - ret = ErrorCode::ExternalAbort; - } - } - } - return ret; -} - - -const void* VectorIndex::GetSample(ByteArray p_meta, bool& deleteFlag) -{ - if (m_pMetaToVec == nullptr) return nullptr; - - std::string meta((char*)p_meta.Data(), p_meta.Length()); - SizeType vid = GetMetaMapping(meta); - 
if (vid >= 0 && vid < GetNumSamples()) { - deleteFlag = !ContainSample(vid); - return GetSample(vid); - } - return nullptr; -} - - -ErrorCode -VectorIndex::LoadQuantizer(std::string p_quantizerFile) -{ - auto ptr = SPTAG::f_createIO(); - if (!ptr->Initialize(p_quantizerFile.c_str(), std::ios::binary | std::ios::in)) - { - LOG(Helper::LogLevel::LL_Error, "Failed to read quantizer file.\n"); - return ErrorCode::FailedOpenFile; - } - SetQuantizer(SPTAG::COMMON::IQuantizer::LoadIQuantizer(ptr)); - if (!m_pQuantizer) - { - LOG(Helper::LogLevel::LL_Error, "Failed to load quantizer.\n"); - return ErrorCode::FailedParseValue; - } - return ErrorCode::Success; -} - - -std::shared_ptr -VectorIndex::CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype) -{ - if (IndexAlgoType::Undefined == p_algo || VectorValueType::Undefined == p_valuetype) - { - return nullptr; - } - - if (p_algo == IndexAlgoType::BKT) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new BKT::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } - } - else if (p_algo == IndexAlgoType::KDT) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new KDT::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } - } - else if (p_algo == IndexAlgoType::SPANN) { - switch (p_valuetype) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - return std::shared_ptr(new SPANN::Index); \ - -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - - default: break; - } - } - return nullptr; -} - - -ErrorCode -VectorIndex::LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr& p_vectorIndex) -{ - std::string folderPath(p_loaderFilePath); - if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) folderPath 
+= FolderSep; - - Helper::IniReader iniReader; - { - auto fp = SPTAG::f_createIO(); - if (fp == nullptr || !fp->Initialize((folderPath + "indexloader.ini").c_str(), std::ios::in)) return ErrorCode::FailedOpenFile; - if (ErrorCode::Success != iniReader.LoadIni(fp)) return ErrorCode::FailedParseValue; - } - - IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); - VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); - if ((p_vectorIndex = CreateInstance(algoType, valueType)) == nullptr) return ErrorCode::FailedParseValue; - - ErrorCode ret = ErrorCode::Success; - if ((ret = p_vectorIndex->LoadIndexConfig(iniReader)) != ErrorCode::Success) return ret; - - std::shared_ptr> indexfiles = p_vectorIndex->GetIndexFiles(); - if (iniReader.DoesSectionExist("MetaData")) { - indexfiles->push_back(p_vectorIndex->m_sMetadataFile); - indexfiles->push_back(p_vectorIndex->m_sMetadataIndexFile); - } - if (iniReader.DoesSectionExist("Quantizer")) { - indexfiles->push_back(p_vectorIndex->m_sQuantizerFile); - } - std::vector> handles; - for (std::string& f : *indexfiles) { - auto ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize((folderPath + f).c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Cannot open file %s!\n", (folderPath + f).c_str()); - ptr = nullptr; - } - handles.push_back(std::move(ptr)); - } - - if ((ret = p_vectorIndex->LoadIndexData(handles)) != ErrorCode::Success) return ret; - - size_t metaStart = p_vectorIndex->GetIndexFiles()->size(); - if (iniReader.DoesSectionExist("MetaData")) - { - p_vectorIndex->SetMetadata(new MemMetadataSet(handles[metaStart], handles[metaStart + 1], - p_vectorIndex->m_iDataBlockSize, p_vectorIndex->m_iDataCapacity, p_vectorIndex->m_iMetaRecordSize)); - - if (!(p_vectorIndex->GetMetadata()->Available())) - { - LOG(Helper::LogLevel::LL_Error, "Error: Failed to load metadata.\n"); - return 
ErrorCode::Fail; - } - - if (iniReader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") - { - p_vectorIndex->BuildMetaMapping(); - } - metaStart += 2; - } - if (iniReader.DoesSectionExist("Quantizer")) { - p_vectorIndex->SetQuantizer(SPTAG::COMMON::IQuantizer::LoadIQuantizer(handles[metaStart])); - if (!p_vectorIndex->m_pQuantizer) return ErrorCode::FailedParseValue; - } - p_vectorIndex->m_bReady = true; - return ErrorCode::Success; -} - - -ErrorCode -VectorIndex::LoadIndexFromFile(const std::string& p_file, std::shared_ptr& p_vectorIndex) -{ - auto fp = SPTAG::f_createIO(); - if (fp == nullptr || !fp->Initialize(p_file.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; - - SPTAG::Helper::IniReader iniReader; - { - std::uint64_t configSize; - IOBINARY(fp, ReadBinary, sizeof(configSize), (char*)&configSize); - std::vector config(configSize + 1, '\0'); - IOBINARY(fp, ReadBinary, configSize, config.data()); - - std::shared_ptr bufferhandle(new Helper::SimpleBufferIO()); - if (bufferhandle == nullptr || !bufferhandle->Initialize(config.data(), std::ios::in, configSize)) return ErrorCode::EmptyDiskIO; - if (SPTAG::ErrorCode::Success != iniReader.LoadIni(bufferhandle)) return ErrorCode::FailedParseValue; - } - - IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); - VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); - - ErrorCode ret = ErrorCode::Success; - - if ((p_vectorIndex = CreateInstance(algoType, valueType)) == nullptr) return ErrorCode::FailedParseValue; - - if ((ret = p_vectorIndex->LoadIndexConfig(iniReader)) != ErrorCode::Success) return ret; - - std::uint64_t blobs; - IOBINARY(fp, ReadBinary, sizeof(blobs), (char*)&blobs); - - std::vector> p_indexStreams(blobs, fp); - if ((ret = p_vectorIndex->LoadIndexData(p_indexStreams)) != ErrorCode::Success) return ret; - - if (iniReader.DoesSectionExist("MetaData")) - { 
- p_vectorIndex->SetMetadata(new MemMetadataSet(fp, fp, p_vectorIndex->m_iDataBlockSize, p_vectorIndex->m_iDataCapacity, p_vectorIndex->m_iMetaRecordSize)); - - if (!(p_vectorIndex->GetMetadata()->Available())) - { - LOG(Helper::LogLevel::LL_Error, "Error: Failed to load metadata.\n"); - return ErrorCode::Fail; - } - - if (iniReader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") - { - p_vectorIndex->BuildMetaMapping(); - } - } - - if (iniReader.DoesSectionExist("Quantizer")) - { - p_vectorIndex->SetQuantizer(SPTAG::COMMON::IQuantizer::LoadIQuantizer(fp)); - if (!p_vectorIndex->m_pQuantizer) return ErrorCode::FailedParseValue; - } - - p_vectorIndex->m_bReady = true; - return ErrorCode::Success; -} - - -ErrorCode -VectorIndex::LoadIndex(const std::string& p_config, const std::vector& p_indexBlobs, std::shared_ptr& p_vectorIndex) -{ - SPTAG::Helper::IniReader iniReader; - std::shared_ptr fp(new Helper::SimpleBufferIO()); - if (fp == nullptr || !fp->Initialize(p_config.c_str(), std::ios::in, p_config.size())) return ErrorCode::EmptyDiskIO; - if (SPTAG::ErrorCode::Success != iniReader.LoadIni(fp)) return ErrorCode::FailedParseValue; - - IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); - VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); - - ErrorCode ret = ErrorCode::Success; - - if ((p_vectorIndex = CreateInstance(algoType, valueType)) == nullptr) return ErrorCode::FailedParseValue; - if (!iniReader.GetParameter("Base", "QuantizerFilePath", std::string()).empty()) - { - p_vectorIndex->SetQuantizer(COMMON::IQuantizer::LoadIQuantizer(p_indexBlobs[4])); - if (!p_vectorIndex->m_pQuantizer) return ErrorCode::FailedParseValue; - } - - if ((p_vectorIndex->LoadIndexConfig(iniReader)) != ErrorCode::Success) return ret; - - if ((ret = p_vectorIndex->LoadIndexDataFromMemory(p_indexBlobs)) != ErrorCode::Success) return ret; - - size_t metaStart = 
p_vectorIndex->BufferSize()->size(); - if (iniReader.DoesSectionExist("MetaData") && p_indexBlobs.size() >= metaStart + 2) - { - ByteArray pMetaIndex = p_indexBlobs[metaStart + 1]; - p_vectorIndex->SetMetadata(new MemMetadataSet(p_indexBlobs[metaStart], - ByteArray(pMetaIndex.Data() + sizeof(SizeType), pMetaIndex.Length() - sizeof(SizeType), false), - *((SizeType*)pMetaIndex.Data()), - p_vectorIndex->m_iDataBlockSize, p_vectorIndex->m_iDataCapacity, p_vectorIndex->m_iMetaRecordSize)); - - if (!(p_vectorIndex->GetMetadata()->Available())) - { - LOG(Helper::LogLevel::LL_Error, "Error: Failed to load metadata.\n"); - return ErrorCode::Fail; - } - - if (iniReader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") - { - p_vectorIndex->BuildMetaMapping(); - } - metaStart += 2; - } - - p_vectorIndex->m_bReady = true; - return ErrorCode::Success; -} - - -std::uint64_t VectorIndex::EstimatedVectorCount(std::uint64_t p_memory, DimensionType p_dimension, VectorValueType p_valuetype, SizeType p_vectorsInBlock, SizeType p_maxmeta, IndexAlgoType p_algo, int p_treeNumber, int p_neighborhoodSize) -{ - size_t treeNodeSize; - if (p_algo == IndexAlgoType::BKT) { - treeNodeSize = sizeof(SizeType) * 3; - } - else if (p_algo == IndexAlgoType::KDT) { - treeNodeSize = sizeof(SizeType) * 2 + sizeof(DimensionType) + sizeof(float); - } - else { - return 0; - } - std::uint64_t unit = GetValueTypeSize(p_valuetype) * p_dimension + p_maxmeta + sizeof(std::uint64_t) + sizeof(SizeType) * p_neighborhoodSize + 1 + treeNodeSize * p_treeNumber; - return ((p_memory / unit) / p_vectorsInBlock) * p_vectorsInBlock; -} - - -std::uint64_t VectorIndex::EstimatedMemoryUsage(std::uint64_t p_vectorCount, DimensionType p_dimension, VectorValueType p_valuetype, SizeType p_vectorsInBlock, SizeType p_maxmeta, IndexAlgoType p_algo, int p_treeNumber, int p_neighborhoodSize) -{ - p_vectorCount = ((p_vectorCount + p_vectorsInBlock - 1) / p_vectorsInBlock) * p_vectorsInBlock; - size_t 
treeNodeSize; - if (p_algo == IndexAlgoType::BKT) { - treeNodeSize = sizeof(SizeType) * 3; - } - else if (p_algo == IndexAlgoType::KDT) { - treeNodeSize = sizeof(SizeType) * 2 + sizeof(DimensionType) + sizeof(float); - } - else { - return 0; - } - std::uint64_t ret = GetValueTypeSize(p_valuetype) * p_dimension * p_vectorCount; //Vector Size - ret += p_maxmeta * p_vectorCount; // MetaData Size - ret += sizeof(std::uint64_t) * p_vectorCount; // MetaIndex Size - ret += sizeof(SizeType) * p_neighborhoodSize * p_vectorCount; // Graph Size - ret += p_vectorCount; // DeletedFlag Size - ret += treeNodeSize * p_treeNumber * p_vectorCount; // Tree Size - return ret; -} - - - -#if defined(GPU) - -#include "inc/Core/Common/cuda/TailNeighbors.hxx" - -void VectorIndex::SortSelections(std::vector* selections) { - LOG(Helper::LogLevel::LL_Debug, "Starting sort of final input on GPU\n"); - GPU_SortSelections(selections); -} - - - -void VectorIndex::ApproximateRNG(std::shared_ptr& fullVectors, std::unordered_set& exceptIDS, int candidateNum, Edge* selections, int replicaCount, int numThreads, int numTrees, int leafSize, float RNGFactor, int numGPUs) -{ - - LOG(Helper::LogLevel::LL_Info, "Starting GPU SSD Index build stage...\n"); - - int metric = (GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine); - - if(m_pQuantizer) { - getTailNeighborsTPT((uint8_t*)fullVectors->GetData(), fullVectors->Count(), this, exceptIDS, fullVectors->Dimension(), replicaCount, numThreads, numTrees, leafSize, metric, numGPUs, selections); - } - else if(GetVectorValueType() != VectorValueType::Float) { - typedef int32_t SUMTYPE; - switch (GetVectorValueType()) - { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - getTailNeighborsTPT((Type*)fullVectors->GetData(), fullVectors->Count(), this, exceptIDS, fullVectors->Dimension(), replicaCount, numThreads, numTrees, leafSize, metric, numGPUs, selections); \ - break; - -#include "inc/Core/DefinitionList.h" -#undef 
DefineVectorValueType - - default: break; - } - } - else { - getTailNeighborsTPT((float*)fullVectors->GetData(), fullVectors->Count(), this, exceptIDS, fullVectors->Dimension(), replicaCount, numThreads, numTrees, leafSize, metric, numGPUs, selections); - } - -} -#else - -void VectorIndex::SortSelections(std::vector* selections) { - EdgeCompare edgeComparer; - std::sort(selections->begin(), selections->end(), edgeComparer); -} - -void VectorIndex::ApproximateRNG(std::shared_ptr& fullVectors, std::unordered_set& exceptIDS, int candidateNum, Edge* selections, int replicaCount, int numThreads, int numTrees, int leafSize, float RNGFactor, int numGPUs) -{ - std::vector threads; - threads.reserve(numThreads); - - std::atomic_int nextFullID(0); - std::atomic_size_t rngFailedCountTotal(0); - - for (int tid = 0; tid < numThreads; ++tid) - { - threads.emplace_back([&, tid]() - { - QueryResult resultSet(NULL, candidateNum, false); - - size_t rngFailedCount = 0; - - while (true) - { - int fullID = nextFullID.fetch_add(1); - if (fullID >= fullVectors->Count()) - { - break; - } - - if (exceptIDS.count(fullID) > 0) - { - continue; - } - - void* reconstructed_vector = nullptr; - if (m_pQuantizer) - { - reconstructed_vector = ALIGN_ALLOC(m_pQuantizer->ReconstructSize()); - m_pQuantizer->ReconstructVector((const uint8_t*)fullVectors->GetVector(fullID), reconstructed_vector); - switch (m_pQuantizer->GetReconstructType()) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - (*((COMMON::QueryResultSet*)&resultSet)).SetTarget(reinterpret_cast(reconstructed_vector), m_pQuantizer); \ - break; -#include "inc/Core/DefinitionList.h" -#undef DefineVectorValueType - default: - LOG(Helper::LogLevel::LL_Error, "Unable to get quantizer reconstruct type %s", Helper::Convert::ConvertToString(m_pQuantizer->GetReconstructType())); - } - } - else - { - resultSet.SetTarget(fullVectors->GetVector(fullID)); - } - resultSet.Reset(); - - SearchIndex(resultSet); - - size_t 
selectionOffset = static_cast(fullID)* replicaCount; - - BasicResult* queryResults = resultSet.GetResults(); - int currReplicaCount = 0; - for (int i = 0; i < candidateNum && currReplicaCount < replicaCount; ++i) - { - if (queryResults[i].VID == -1) - { - break; - } - - // RNG Check. - bool rngAccpeted = true; - for (int j = 0; j < currReplicaCount; ++j) - { - float nnDist = ComputeDistance(GetSample(queryResults[i].VID), GetSample(selections[selectionOffset+j].node)); - - if (RNGFactor * nnDist < queryResults[i].Dist) - { - rngAccpeted = false; - break; - } - } - - if (!rngAccpeted) - { - ++rngFailedCount; - continue; - } - - selections[selectionOffset + currReplicaCount].node = queryResults[i].VID; - selections[selectionOffset + currReplicaCount].distance = queryResults[i].Dist; - ++currReplicaCount; - } - - if (reconstructed_vector) - { - ALIGN_FREE(reconstructed_vector); - } - } - rngFailedCountTotal += rngFailedCount; - }); - } - - for (int tid = 0; tid < numThreads; ++tid) - { - threads[tid].join(); - } - LOG(Helper::LogLevel::LL_Info, "Searching replicas ended. RNG failed count: %llu\n", static_cast(rngFailedCountTotal.load())); -} -#endif +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Helper/ConcurrentSet.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/StringConvert.h" + +#include "inc/Core/BKT/Index.h" +#include "inc/Core/KDT/Index.h" +#include "inc/Core/SPANN/Index.h" + +#include + +typedef typename SPTAG::Helper::Concurrent::ConcurrentMap MetadataMap; + +namespace fs = std::filesystem; +using namespace SPTAG; + +Helper::LoggerHolder &SPTAG::GetLoggerHolder() +{ +#ifdef DEBUG + auto logLevel = Helper::LogLevel::LL_Debug; +#else + auto logLevel = Helper::LogLevel::LL_Info; +#endif +#ifdef _WINDOWS_ + if (auto exeHandle = GetModuleHandleW(nullptr)) + { + if (auto SPTAG_GetLoggerLevel = + reinterpret_cast(GetProcAddress(exeHandle, "SPTAG_GetLoggerLevel"))) + { + logLevel = SPTAG_GetLoggerLevel(); + } + } +#endif // _WINDOWS_ + static Helper::LoggerHolder s_pLoggerHolder(std::make_shared(logLevel)); + return s_pLoggerHolder; +} + +std::shared_ptr SPTAG::GetLogger() +{ + return GetLoggerHolder().GetLogger(); +} + +void SPTAG::SetLogger(std::shared_ptr p_logger) +{ + GetLoggerHolder().SetLogger(p_logger); +} + +std::shared_ptr (*SPTAG::f_createIO)() = []() -> std::shared_ptr { + return std::shared_ptr(new Helper::SimpleFileIO()); +}; + +namespace SPTAG +{ + +bool copyfile(const char *oldpath, const char *newpath) +{ + auto input = f_createIO(), output = f_createIO(); + if (input == nullptr || !input->Initialize(oldpath, std::ios::binary | std::ios::in) || output == nullptr || + !output->Initialize(newpath, std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to open files: %s %s\n", oldpath, newpath); + return false; + } + + const std::size_t bufferSize = 1 << 30; + std::unique_ptr bufferHolder(new char[bufferSize]); + + std::uint64_t readSize = input->ReadBinary(bufferSize, bufferHolder.get()); + while (readSize != 0) + { + if (output->WriteBinary(readSize, bufferHolder.get()) != readSize) + { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", newpath); + return false; + } + readSize = input->ReadBinary(bufferSize, bufferHolder.get()); + } + input->ShutDown(); + output->ShutDown(); + return true; +} + +bool copydirectory(const fs::path &sourceDir, const fs::path &destinationDir) +{ + try + { + // Ensure the destination directory exists. + // create_directories will create parent directories if they don't exist. + if (!fs::exists(destinationDir)) + { + if (!fs::create_directories(destinationDir)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Could not create destination directory %s\n", + destinationDir.string().c_str()); + return false; + } + } + else if (!fs::is_directory(destinationDir)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Destination path %s exists but is not a directory.\n", + destinationDir.string().c_str()); + return false; + } + + // Iterate through all entries in the source directory + for (const auto &entry : fs::directory_iterator(sourceDir)) + { + fs::path currentPath = entry.path(); + fs::path destinationPath = destinationDir / currentPath.filename(); // Append filename to destination path + + if (fs::is_directory(currentPath)) + { + // If it's a subdirectory, recursively copy its contents + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Copying directory: %s to %s\n", currentPath.string().c_str(), + destinationPath.string().c_str()); + if (!copydirectory(currentPath, destinationPath)) + { + return false; // Propagate error from recursive call + } + } + else + { + // If it's a file, copy it directly + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Copying file: %s to %s\n", currentPath.string().c_str(), + destinationPath.string().c_str()); + // Use copy_options::overwrite_existing to replace files if they exist in destination + fs::copy(currentPath, destinationPath, fs::copy_options::overwrite_existing); + } + } + } + catch (const fs::filesystem_error &e) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, 
"Filesystem error: %s\n", e.what()); + return false; + } + catch (const std::exception &e) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "General error: %s\n", e.what()); + return false; + } + return true; +} + +#ifndef _MSC_VER +void listdir(std::string path, std::vector &files) +{ + if (auto dirptr = opendir(path.substr(0, path.length() - 1).c_str())) + { + while (auto f = readdir(dirptr)) + { + if (!f->d_name || f->d_name[0] == '.') + continue; + std::string tmp = path.substr(0, path.length() - 1); + tmp += std::string(f->d_name); + if (f->d_type == DT_DIR) + { + listdir(tmp + FolderSep + "*", files); + } + else + { + files.push_back(tmp); + } + } + closedir(dirptr); + } +} +#else +void listdir(std::string path, std::vector &files) +{ + WIN32_FIND_DATA fd; + HANDLE hFile = FindFirstFile(path.c_str(), &fd); + if (hFile != INVALID_HANDLE_VALUE) + { + do + { + std::string tmp = path.substr(0, path.length() - 1); + tmp += std::string(fd.cFileName); + if (fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) + { + if (fd.cFileName[0] != '.') + { + listdir(tmp + FolderSep + "*", files); + } + } + else + { + files.push_back(tmp); + } + } while (FindNextFile(hFile, &fd)); + FindClose(hFile); + } +} +#endif +} // namespace SPTAG + +VectorIndex::VectorIndex() +{ +} + +VectorIndex::~VectorIndex() +{ +} + +std::string VectorIndex::GetParameter(const std::string &p_param, const std::string &p_section) const +{ + return GetParameter(p_param.c_str(), p_section.c_str()); +} + +ErrorCode VectorIndex::SetParameter(const std::string &p_param, const std::string &p_value, + const std::string &p_section) +{ + return SetParameter(p_param.c_str(), p_value.c_str(), p_section.c_str()); +} + +void VectorIndex::SetMetadata(MetadataSet *p_new) +{ + m_pMetadata.reset(p_new); +} + +MetadataSet *VectorIndex::GetMetadata() const +{ + return m_pMetadata.get(); +} + +ByteArray VectorIndex::GetMetadata(SizeType p_vectorID) const +{ + if (nullptr != m_pMetadata) + { + return 
m_pMetadata->GetMetadata(p_vectorID); + } + return ByteArray::c_empty; +} + +std::shared_ptr> VectorIndex::CalculateBufferSize() const +{ + std::shared_ptr> ret = BufferSize(); + + if (m_pMetadata != nullptr) + { + auto metasize = m_pMetadata->BufferSize(); + ret->push_back(metasize.first); + ret->push_back(metasize.second); + } + + if (m_pQuantizer) + { + ret->push_back(m_pQuantizer->BufferSize()); + } + return std::move(ret); +} + +ErrorCode VectorIndex::LoadIndexConfig(Helper::IniReader &p_reader) +{ + std::string metadataSection("MetaData"); + if (p_reader.DoesSectionExist(metadataSection)) + { + m_sMetadataFile = p_reader.GetParameter(metadataSection, "MetaDataFilePath", std::string()); + m_sMetadataIndexFile = p_reader.GetParameter(metadataSection, "MetaDataIndexPath", std::string()); + } + + std::string quantizerSection("Quantizer"); + if (p_reader.DoesSectionExist(quantizerSection)) + { + m_sQuantizerFile = p_reader.GetParameter(quantizerSection, "QuantizerFilePath", std::string()); + } + return LoadConfig(p_reader); +} + +ErrorCode VectorIndex::SaveIndexConfig(std::shared_ptr p_configOut) +{ + if (nullptr != m_pMetadata) + { + IOSTRING(p_configOut, WriteString, "[MetaData]\n"); + IOSTRING(p_configOut, WriteString, ("MetaDataFilePath=" + m_sMetadataFile + "\n").c_str()); + IOSTRING(p_configOut, WriteString, ("MetaDataIndexPath=" + m_sMetadataIndexFile + "\n").c_str()); + if (nullptr != m_pMetaToVec) + IOSTRING(p_configOut, WriteString, "MetaDataToVectorIndex=true\n"); + IOSTRING(p_configOut, WriteString, "\n"); + } + + if (m_pQuantizer) + { + IOSTRING(p_configOut, WriteString, "[Quantizer]\n"); + IOSTRING(p_configOut, WriteString, ("QuantizerFilePath=" + m_sQuantizerFile + "\n").c_str()); + IOSTRING(p_configOut, WriteString, "\n"); + } + + IOSTRING(p_configOut, WriteString, "[Index]\n"); + IOSTRING(p_configOut, WriteString, + ("IndexAlgoType=" + Helper::Convert::ConvertToString(GetIndexAlgoType()) + "\n").c_str()); + IOSTRING(p_configOut, WriteString, + 
("ValueType=" + Helper::Convert::ConvertToString(GetVectorValueType()) + "\n").c_str()); + IOSTRING(p_configOut, WriteString, "\n"); + + return SaveConfig(p_configOut); +} + +SizeType VectorIndex::GetMetaMapping(std::string &meta) const +{ + MetadataMap *ptr = static_cast(m_pMetaToVec.get()); + auto iter = ptr->find(meta); + if (iter != ptr->end()) + return iter->second; + return -1; +} + +void VectorIndex::UpdateMetaMapping(const std::string &meta, SizeType i) +{ + MetadataMap *ptr = static_cast(m_pMetaToVec.get()); + auto iter = ptr->find(meta); + if (iter != ptr->end()) + DeleteIndex(iter->second); + ; + (*ptr)[meta] = i; +} + +void VectorIndex::BuildMetaMapping(bool p_checkDeleted) +{ + MetadataMap *ptr = new MetadataMap(m_iDataBlockSize); + for (SizeType i = 0; i < m_pMetadata->Count(); i++) + { + if (!p_checkDeleted || ContainSample(i)) + { + ByteArray meta = m_pMetadata->GetMetadata(i); + (*ptr)[std::string((char *)meta.Data(), meta.Length())] = i; + } + } + m_pMetaToVec.reset(ptr, std::default_delete()); +} + +ErrorCode VectorIndex::SaveIndex(std::string &p_config, const std::vector &p_indexBlobs) +{ + if (!m_bReady || GetNumSamples() - GetNumDeleted() == 0) + return ErrorCode::EmptyIndex; + + ErrorCode ret = ErrorCode::Success; + { + std::shared_ptr p_configStream(new Helper::SimpleBufferIO()); + auto bufsize = 2 << 20; + std::vector buf(bufsize); // Allocate 1 MB scratch space + if (p_configStream == nullptr || !p_configStream->Initialize(buf.data(), std::ios::out, bufsize)) + return ErrorCode::EmptyDiskIO; + if ((ret = SaveIndexConfig(p_configStream)) != ErrorCode::Success) + return ret; + p_config.resize(p_configStream->TellP()); + IOBINARY(p_configStream, ReadBinary, p_config.size(), (char *)p_config.c_str(), 0); + } + + std::vector> p_indexStreams; + for (size_t i = 0; i < p_indexBlobs.size(); i++) + { + std::shared_ptr ptr(new Helper::SimpleBufferIO()); + if (ptr == nullptr || !ptr->Initialize((char *)p_indexBlobs[i].Data(), std::ios::binary | 
std::ios::out, + p_indexBlobs[i].Length())) + return ErrorCode::EmptyDiskIO; + p_indexStreams.push_back(std::move(ptr)); + } + + size_t metaStart = BufferSize()->size(); + if (NeedRefine()) + { + ret = RefineIndex(p_indexStreams, nullptr, nullptr); + } + else + { + if (m_pMetadata != nullptr && p_indexStreams.size() >= metaStart + 2) + { + + ret = m_pMetadata->SaveMetadata(p_indexStreams[metaStart], p_indexStreams[metaStart + 1]); + } + if (ErrorCode::Success == ret) + ret = SaveIndexData(p_indexStreams); + } + if (m_pMetadata != nullptr) + metaStart += 2; + + if (ErrorCode::Success == ret && m_pQuantizer && p_indexStreams.size() > metaStart) + { + ret = m_pQuantizer->SaveQuantizer(p_indexStreams[metaStart]); + } + return ret; +} + +ErrorCode VectorIndex::SaveIndex(const std::string &p_folderPath) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "SaveIndex(%s) begin...ready:%d GetNumSamples():%d GetNumDeleted():%d\n", + p_folderPath.c_str(), m_bReady, GetNumSamples(), GetNumDeleted()); + if (!m_bReady || GetNumSamples() - GetNumDeleted() == 0) + return ErrorCode::EmptyIndex; + + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + if (!direxists(folderPath.c_str())) + { + mkdir(folderPath.c_str()); + } + + if (GetIndexAlgoType() == IndexAlgoType::SPANN && GetParameter("IndexDirectory", "Base") != p_folderPath) + { + std::vector files; + std::string oldFolder = GetParameter("IndexDirectory", "Base"); + if (!oldFolder.empty() && *(oldFolder.rbegin()) != FolderSep) + oldFolder += FolderSep; + listdir((oldFolder + "*").c_str(), files); + for (auto file : files) + { + size_t firstSep = oldFolder.length(), lastSep = file.find_last_of(FolderSep); + std::string newFolder = + folderPath + ((lastSep > firstSep) ? 
file.substr(firstSep, lastSep - firstSep) : ""), + filename = file.substr(lastSep + 1); + if (!direxists(newFolder.c_str())) + mkdir(newFolder.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Copy file %s to %s...\n", file.c_str(), + (newFolder + FolderSep + filename).c_str()); + if (!copyfile(file.c_str(), (newFolder + FolderSep + filename).c_str())) + return ErrorCode::DiskIOFail; + } + SetParameter("IndexDirectory", p_folderPath, "Base"); + } + + ErrorCode ret = ErrorCode::Success; + { + auto configFile = SPTAG::f_createIO(); + if (configFile == nullptr || !configFile->Initialize((folderPath + "indexloader.ini").c_str(), std::ios::out)) + return ErrorCode::FailedCreateFile; + if ((ret = SaveIndexConfig(configFile)) != ErrorCode::Success) + return ret; + } + + std::shared_ptr> indexfiles = GetIndexFiles(); + if (nullptr != m_pMetadata) + { + indexfiles->push_back(m_sMetadataFile); + indexfiles->push_back(m_sMetadataIndexFile); + } + if (m_pQuantizer) + { + indexfiles->push_back(m_sQuantizerFile); + } + std::vector> handles; + for (std::string &f : *indexfiles) + { + std::string newfile = folderPath + f; + if (!direxists(newfile.substr(0, newfile.find_last_of(FolderSep)).c_str())) + mkdir(newfile.substr(0, newfile.find_last_of(FolderSep)).c_str()); + + auto ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize(newfile.c_str(), std::ios::binary | std::ios::out)) + return ErrorCode::FailedCreateFile; + handles.push_back(std::move(ptr)); + } + + size_t metaStart = GetIndexFiles()->size(); + if (NeedRefine()) + { + ret = RefineIndex(handles, nullptr, nullptr); + } + else + { + if (m_pMetadata != nullptr) + ret = m_pMetadata->SaveMetadata(handles[metaStart], handles[metaStart + 1]); + if (ErrorCode::Success == ret) + ret = SaveIndexData(handles); + } + if (m_pMetadata != nullptr) + metaStart += 2; + + if (ErrorCode::Success == ret && m_pQuantizer) + { + ret = m_pQuantizer->SaveQuantizer(handles[metaStart]); + } + return ret; +} + +ErrorCode 
VectorIndex::SaveIndexToFile(const std::string &p_file, IAbortOperation *p_abort) +{ + if (!m_bReady || GetNumSamples() - GetNumDeleted() == 0) + return ErrorCode::EmptyIndex; + + auto fp = SPTAG::f_createIO(); + if (fp == nullptr || !fp->Initialize(p_file.c_str(), std::ios::binary | std::ios::out)) + return ErrorCode::FailedCreateFile; + + auto mp = std::shared_ptr(new Helper::SimpleBufferIO()); + auto bufsize = 2 << 20; + std::vector buf(bufsize); // Allocate 1 MB scratch space + if (mp == nullptr || !mp->Initialize(buf.data(), std::ios::binary | std::ios::out, bufsize)) + return ErrorCode::FailedCreateFile; + ErrorCode ret = ErrorCode::Success; + if ((ret = SaveIndexConfig(mp)) != ErrorCode::Success) + return ret; + + std::uint64_t configSize = mp->TellP(); + mp->ShutDown(); + + IOBINARY(fp, WriteBinary, sizeof(configSize), (char *)&configSize); + if ((ret = SaveIndexConfig(fp)) != ErrorCode::Success) + return ret; + + if (p_abort != nullptr && p_abort->ShouldAbort()) + ret = ErrorCode::ExternalAbort; + else + { + std::uint64_t blobs = CalculateBufferSize()->size(); + IOBINARY(fp, WriteBinary, sizeof(blobs), (char *)&blobs); + std::vector> p_indexStreams(blobs, fp); + + if (NeedRefine()) + { + ret = RefineIndex(p_indexStreams, p_abort, nullptr); + } + else + { + ret = SaveIndexData(p_indexStreams); + + if (p_abort != nullptr && p_abort->ShouldAbort()) + ret = ErrorCode::ExternalAbort; + + if (ErrorCode::Success == ret && m_pMetadata != nullptr) + ret = m_pMetadata->SaveMetadata(fp, fp); + } + if (ErrorCode::Success == ret && m_pQuantizer) + { + ret = m_pQuantizer->SaveQuantizer(fp); + } + } + fp->ShutDown(); + + if (ret != ErrorCode::Success) + std::remove(p_file.c_str()); + return ret; +} + +ErrorCode VectorIndex::BuildIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet, + bool p_withMetaIndex, bool p_normalized, bool p_shareOwnership) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin build index...\n"); + + bool valueMatches = 
p_vectorSet->GetValueType() == GetVectorValueType(); + bool quantizerMatches = ((bool)m_pQuantizer) && (p_vectorSet->GetValueType() == SPTAG::VectorValueType::UInt8); + if (nullptr == p_vectorSet || !(valueMatches || quantizerMatches)) + { + return ErrorCode::Fail; + } + m_pMetadata = std::move(p_metadataSet); + if (p_withMetaIndex && m_pMetadata != nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build meta mapping...\n"); + BuildMetaMapping(false); + } + ErrorCode ret = BuildIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension(), p_normalized, p_shareOwnership); + return ret; +} + +ErrorCode VectorIndex::SearchIndex(const void *p_vector, int p_vectorCount, int p_neighborCount, bool p_withMeta, + BasicResult *p_results) const +{ + size_t vectorSize = GetValueTypeSize(GetVectorValueType()) * GetFeatureDim(); + std::vector mythreads; + int maxthreads = std::thread::hardware_concurrency(); + mythreads.reserve(maxthreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < maxthreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < p_vectorCount) + { + QueryResult res((char *)p_vector + i * vectorSize, p_neighborCount, p_withMeta, + p_results + i * p_neighborCount); + SearchIndex(res); + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + return ErrorCode::Success; +} + +ErrorCode VectorIndex::AddIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet, + bool p_withMetaIndex, bool p_normalized) +{ + if (nullptr == p_vectorSet || p_vectorSet->GetValueType() != GetVectorValueType()) + { + return ErrorCode::Fail; + } + + return AddIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension(), p_metadataSet, + p_withMetaIndex, p_normalized); +} + +ErrorCode VectorIndex::DeleteIndex(ByteArray p_meta) +{ + if (m_pMetaToVec == nullptr) + return ErrorCode::VectorNotFound; + + std::string 
meta((char *)p_meta.Data(), p_meta.Length()); + SizeType vid = GetMetaMapping(meta); + if (vid >= 0) + return DeleteIndex(vid); + return ErrorCode::VectorNotFound; +} + +ErrorCode VectorIndex::MergeIndex(VectorIndex *p_addindex, int p_threadnum, IAbortOperation *p_abort) +{ + ErrorCode ret = ErrorCode::Success; + if (p_addindex->m_pMetadata != nullptr) + { + std::vector mythreads; + mythreads.reserve(p_threadnum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < p_threadnum; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < p_addindex->GetNumSamples()) + { + if (ret == ErrorCode::ExternalAbort) + continue; + + if (p_addindex->ContainSample(i)) + { + ByteArray meta = p_addindex->GetMetadata(i); + std::uint64_t offsets[2] = {0, meta.Length()}; + std::shared_ptr p_metaSet(new MemMetadataSet( + meta, ByteArray((std::uint8_t *)offsets, sizeof(offsets), false), 1)); + AddIndex(p_addindex->GetSample(i), 1, p_addindex->GetFeatureDim(), p_metaSet); + } + + if (p_abort != nullptr && p_abort->ShouldAbort()) + { + ret = ErrorCode::ExternalAbort; + } + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + } + else + { + std::vector mythreads; + mythreads.reserve(p_threadnum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < p_threadnum; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < p_addindex->GetNumSamples()) + { + if (ret == ErrorCode::ExternalAbort) + continue; + + if (p_addindex->ContainSample(i)) + { + AddIndex(p_addindex->GetSample(i), 1, p_addindex->GetFeatureDim(), nullptr); + } + + if (p_abort != nullptr && p_abort->ShouldAbort()) + { + ret = ErrorCode::ExternalAbort; + } + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + } + return ret; +} + +const void *VectorIndex::GetSample(ByteArray p_meta, 
bool &deleteFlag) +{ + if (m_pMetaToVec == nullptr) + return nullptr; + + std::string meta((char *)p_meta.Data(), p_meta.Length()); + SizeType vid = GetMetaMapping(meta); + if (vid >= 0 && vid < GetNumSamples()) + { + deleteFlag = !ContainSample(vid); + return GetSample(vid); + } + return nullptr; +} + +ErrorCode VectorIndex::LoadQuantizer(std::string p_quantizerFile) +{ + auto ptr = SPTAG::f_createIO(); + if (!ptr->Initialize(p_quantizerFile.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read quantizer file.\n"); + return ErrorCode::FailedOpenFile; + } + SetQuantizer(SPTAG::COMMON::IQuantizer::LoadIQuantizer(ptr)); + if (!m_pQuantizer) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to load quantizer.\n"); + return ErrorCode::FailedParseValue; + } + return ErrorCode::Success; +} + +std::shared_ptr VectorIndex::CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype) +{ + if (IndexAlgoType::Undefined == p_algo || VectorValueType::Undefined == p_valuetype) + { + return nullptr; + } + + if (p_algo == IndexAlgoType::BKT) + { + switch (p_valuetype) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return std::shared_ptr(new BKT::Index); + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: + break; + } + } + else if (p_algo == IndexAlgoType::KDT) + { + switch (p_valuetype) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return std::shared_ptr(new KDT::Index); + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: + break; + } + } + else if (p_algo == IndexAlgoType::SPANN) + { + switch (p_valuetype) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return std::shared_ptr(new SPANN::Index); + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: + break; + } + } + return nullptr; +} + +ErrorCode 
VectorIndex::LoadIndex(const std::string &p_loaderFilePath, std::shared_ptr &p_vectorIndex) +{ + std::string folderPath(p_loaderFilePath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + folderPath += FolderSep; + + Helper::IniReader iniReader; + { + auto fp = SPTAG::f_createIO(); + if (fp == nullptr || !fp->Initialize((folderPath + "indexloader.ini").c_str(), std::ios::in)) + return ErrorCode::FailedOpenFile; + if (ErrorCode::Success != iniReader.LoadIni(fp)) + return ErrorCode::FailedParseValue; + } + + IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); + VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); + if ((p_vectorIndex = CreateInstance(algoType, valueType)) == nullptr) + return ErrorCode::FailedParseValue; + + ErrorCode ret = ErrorCode::Success; + if ((ret = p_vectorIndex->LoadIndexConfig(iniReader)) != ErrorCode::Success) + return ret; + + std::shared_ptr> indexfiles = p_vectorIndex->GetIndexFiles(); + if (iniReader.DoesSectionExist("MetaData")) + { + indexfiles->push_back(p_vectorIndex->m_sMetadataFile); + indexfiles->push_back(p_vectorIndex->m_sMetadataIndexFile); + } + if (iniReader.DoesSectionExist("Quantizer")) + { + indexfiles->push_back(p_vectorIndex->m_sQuantizerFile); + } + std::vector> handles; + for (std::string &f : *indexfiles) + { + auto ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize((folderPath + f).c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open file %s!\n", (folderPath + f).c_str()); + ptr = nullptr; + } + handles.push_back(std::move(ptr)); + } + + if (algoType == IndexAlgoType::SPANN) + { + p_vectorIndex->SetParameter("IndexDirectory", p_loaderFilePath, "Base"); + } + + if ((ret = p_vectorIndex->LoadIndexData(handles)) != ErrorCode::Success) + return ret; + + size_t metaStart = p_vectorIndex->GetIndexFiles()->size(); + if 
(iniReader.DoesSectionExist("MetaData")) + { + p_vectorIndex->SetMetadata(new MemMetadataSet(handles[metaStart], handles[metaStart + 1], + p_vectorIndex->m_iDataBlockSize, p_vectorIndex->m_iDataCapacity, + p_vectorIndex->m_iMetaRecordSize)); + + if (!(p_vectorIndex->GetMetadata()->Available())) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Failed to load metadata.\n"); + return ErrorCode::Fail; + } + + if (iniReader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") + { + p_vectorIndex->BuildMetaMapping(); + } + metaStart += 2; + } + if (iniReader.DoesSectionExist("Quantizer")) + { + p_vectorIndex->SetQuantizer(SPTAG::COMMON::IQuantizer::LoadIQuantizer(handles[metaStart])); + if (!p_vectorIndex->m_pQuantizer) + return ErrorCode::FailedParseValue; + } + p_vectorIndex->m_bReady = true; + return ErrorCode::Success; +} + +ErrorCode VectorIndex::LoadIndexFromFile(const std::string &p_file, std::shared_ptr &p_vectorIndex) +{ + auto fp = SPTAG::f_createIO(); + if (fp == nullptr || !fp->Initialize(p_file.c_str(), std::ios::binary | std::ios::in)) + return ErrorCode::FailedOpenFile; + + SPTAG::Helper::IniReader iniReader; + { + std::uint64_t configSize; + IOBINARY(fp, ReadBinary, sizeof(configSize), (char *)&configSize); + std::vector config(configSize + 1, '\0'); + IOBINARY(fp, ReadBinary, configSize, config.data()); + + std::shared_ptr bufferhandle(new Helper::SimpleBufferIO()); + if (bufferhandle == nullptr || !bufferhandle->Initialize(config.data(), std::ios::in, configSize)) + return ErrorCode::EmptyDiskIO; + if (SPTAG::ErrorCode::Success != iniReader.LoadIni(bufferhandle)) + return ErrorCode::FailedParseValue; + } + + IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); + VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); + + ErrorCode ret = ErrorCode::Success; + + if ((p_vectorIndex = CreateInstance(algoType, valueType)) == nullptr) + 
return ErrorCode::FailedParseValue; + + if ((ret = p_vectorIndex->LoadIndexConfig(iniReader)) != ErrorCode::Success) + return ret; + + std::uint64_t blobs; + IOBINARY(fp, ReadBinary, sizeof(blobs), (char *)&blobs); + + std::vector> p_indexStreams(blobs, fp); + if ((ret = p_vectorIndex->LoadIndexData(p_indexStreams)) != ErrorCode::Success) + return ret; + + if (iniReader.DoesSectionExist("MetaData")) + { + p_vectorIndex->SetMetadata(new MemMetadataSet( + fp, fp, p_vectorIndex->m_iDataBlockSize, p_vectorIndex->m_iDataCapacity, p_vectorIndex->m_iMetaRecordSize)); + + if (!(p_vectorIndex->GetMetadata()->Available())) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Failed to load metadata.\n"); + return ErrorCode::Fail; + } + + if (iniReader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") + { + p_vectorIndex->BuildMetaMapping(); + } + } + + if (iniReader.DoesSectionExist("Quantizer")) + { + p_vectorIndex->SetQuantizer(SPTAG::COMMON::IQuantizer::LoadIQuantizer(fp)); + if (!p_vectorIndex->m_pQuantizer) + return ErrorCode::FailedParseValue; + } + + p_vectorIndex->m_bReady = true; + return ErrorCode::Success; +} + +ErrorCode VectorIndex::LoadIndex(const std::string &p_config, const std::vector &p_indexBlobs, + std::shared_ptr &p_vectorIndex) +{ + SPTAG::Helper::IniReader iniReader; + std::shared_ptr fp(new Helper::SimpleBufferIO()); + if (fp == nullptr || !fp->Initialize(p_config.c_str(), std::ios::in, p_config.size())) + return ErrorCode::EmptyDiskIO; + if (SPTAG::ErrorCode::Success != iniReader.LoadIni(fp)) + return ErrorCode::FailedParseValue; + + IndexAlgoType algoType = iniReader.GetParameter("Index", "IndexAlgoType", IndexAlgoType::Undefined); + VectorValueType valueType = iniReader.GetParameter("Index", "ValueType", VectorValueType::Undefined); + + ErrorCode ret = ErrorCode::Success; + + if ((p_vectorIndex = CreateInstance(algoType, valueType)) == nullptr) + return ErrorCode::FailedParseValue; + if (!iniReader.GetParameter("Base", 
"QuantizerFilePath", std::string()).empty()) + { + p_vectorIndex->SetQuantizer(COMMON::IQuantizer::LoadIQuantizer(p_indexBlobs[4])); + if (!p_vectorIndex->m_pQuantizer) + return ErrorCode::FailedParseValue; + } + + if ((p_vectorIndex->LoadIndexConfig(iniReader)) != ErrorCode::Success) + return ret; + + if ((ret = p_vectorIndex->LoadIndexDataFromMemory(p_indexBlobs)) != ErrorCode::Success) + return ret; + + size_t metaStart = p_vectorIndex->BufferSize()->size(); + if (iniReader.DoesSectionExist("MetaData") && p_indexBlobs.size() >= metaStart + 2) + { + ByteArray pMetaIndex = p_indexBlobs[metaStart + 1]; + p_vectorIndex->SetMetadata(new MemMetadataSet( + p_indexBlobs[metaStart], + ByteArray(pMetaIndex.Data() + sizeof(SizeType), pMetaIndex.Length() - sizeof(SizeType), false), + *((SizeType *)pMetaIndex.Data()), p_vectorIndex->m_iDataBlockSize, p_vectorIndex->m_iDataCapacity, + p_vectorIndex->m_iMetaRecordSize)); + + if (!(p_vectorIndex->GetMetadata()->Available())) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: Failed to load metadata.\n"); + return ErrorCode::Fail; + } + + if (iniReader.GetParameter("MetaData", "MetaDataToVectorIndex", std::string()) == "true") + { + p_vectorIndex->BuildMetaMapping(); + } + metaStart += 2; + } + + p_vectorIndex->m_bReady = true; + return ErrorCode::Success; +} + +std::shared_ptr VectorIndex::Clone(std::string p_clone) +{ + ErrorCode ret = ErrorCode::Success; + if (GetIndexAlgoType() != IndexAlgoType::SPANN) + ret = SaveIndex(p_clone); + else + { + std::string indexFolder = GetParameter("IndexDirectory", "Base"); + //ret = SaveIndex(indexFolder); + if (!copydirectory(indexFolder, p_clone)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to copy index directory contents to %s!\n", + p_clone.c_str()); + return nullptr; + } + } + if (ret != ErrorCode::Success) + return nullptr; + + std::shared_ptr clone; + auto status = VectorIndex::LoadIndex(p_clone, clone); + if (status != ErrorCode::Success) + { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to load index from %s!\n", p_clone.c_str()); + return nullptr; + } + return clone; +} + +std::uint64_t VectorIndex::EstimatedVectorCount(std::uint64_t p_memory, DimensionType p_dimension, + VectorValueType p_valuetype, SizeType p_vectorsInBlock, + SizeType p_maxmeta, IndexAlgoType p_algo, int p_treeNumber, + int p_neighborhoodSize) +{ + size_t treeNodeSize; + if (p_algo == IndexAlgoType::BKT) + { + treeNodeSize = sizeof(SizeType) * 3; + } + else if (p_algo == IndexAlgoType::KDT) + { + treeNodeSize = sizeof(SizeType) * 2 + sizeof(DimensionType) + sizeof(float); + } + else + { + return 0; + } + std::uint64_t unit = GetValueTypeSize(p_valuetype) * p_dimension + p_maxmeta + sizeof(std::uint64_t) + + sizeof(SizeType) * p_neighborhoodSize + 1 + treeNodeSize * p_treeNumber; + return ((p_memory / unit) / p_vectorsInBlock) * p_vectorsInBlock; +} + +std::uint64_t VectorIndex::EstimatedMemoryUsage(std::uint64_t p_vectorCount, DimensionType p_dimension, + VectorValueType p_valuetype, SizeType p_vectorsInBlock, + SizeType p_maxmeta, IndexAlgoType p_algo, int p_treeNumber, + int p_neighborhoodSize) +{ + p_vectorCount = ((p_vectorCount + p_vectorsInBlock - 1) / p_vectorsInBlock) * p_vectorsInBlock; + size_t treeNodeSize; + if (p_algo == IndexAlgoType::BKT) + { + treeNodeSize = sizeof(SizeType) * 3; + } + else if (p_algo == IndexAlgoType::KDT) + { + treeNodeSize = sizeof(SizeType) * 2 + sizeof(DimensionType) + sizeof(float); + } + else + { + return 0; + } + std::uint64_t ret = GetValueTypeSize(p_valuetype) * p_dimension * p_vectorCount; // Vector Size + ret += p_maxmeta * p_vectorCount; // MetaData Size + ret += sizeof(std::uint64_t) * p_vectorCount; // MetaIndex Size + ret += sizeof(SizeType) * p_neighborhoodSize * p_vectorCount; // Graph Size + ret += p_vectorCount; // DeletedFlag Size + ret += treeNodeSize * p_treeNumber * p_vectorCount; // Tree Size + return ret; +} + +#if defined(GPU) + +#include 
"inc/Core/Common/cuda/TailNeighbors.hxx" + +void VectorIndex::SortSelections(std::vector *selections) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, "Starting sort of final input on GPU\n"); + GPU_SortSelections(selections); +} + +void VectorIndex::ApproximateRNG(std::shared_ptr &fullVectors, std::unordered_set &exceptIDS, + int candidateNum, Edge *selections, int replicaCount, int numThreads, int numTrees, + int leafSize, float RNGFactor, int numGPUs) +{ + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Starting GPU SSD Index build stage...\n"); + + int metric = (GetDistCalcMethod() == SPTAG::DistCalcMethod::Cosine); + + if (m_pQuantizer) + { + getTailNeighborsTPT((uint8_t *)fullVectors->GetData(), fullVectors->Count(), this, exceptIDS, + fullVectors->Dimension(), replicaCount, numThreads, numTrees, leafSize, + metric, numGPUs, selections); + } + else if (GetVectorValueType() != VectorValueType::Float) + { + typedef int32_t SUMTYPE; + switch (GetVectorValueType()) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + getTailNeighborsTPT((Type *)fullVectors->GetData(), fullVectors->Count(), this, exceptIDS, \ + fullVectors->Dimension(), replicaCount, numThreads, numTrees, leafSize, \ + metric, numGPUs, selections); \ + break; + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: + break; + } + } + else + { + getTailNeighborsTPT((float *)fullVectors->GetData(), fullVectors->Count(), this, exceptIDS, + fullVectors->Dimension(), replicaCount, numThreads, numTrees, leafSize, + metric, numGPUs, selections); + } +} +#else + +void VectorIndex::SortSelections(std::vector *selections) +{ + EdgeCompare edgeComparer; + std::sort(selections->begin(), selections->end(), edgeComparer); +} + +void VectorIndex::ApproximateRNG(std::shared_ptr &fullVectors, std::unordered_set &exceptIDS, + int candidateNum, Edge *selections, int replicaCount, int numThreads, int numTrees, + int leafSize, float RNGFactor, int numGPUs) +{ + 
std::vector threads; + threads.reserve(numThreads); + + std::atomic_int nextFullID(0); + std::atomic_size_t rngFailedCountTotal(0); + + for (int tid = 0; tid < numThreads; ++tid) + { + threads.emplace_back([&, tid]() { + QueryResult resultSet(NULL, candidateNum, false); + + size_t rngFailedCount = 0; + + while (true) + { + int fullID = nextFullID.fetch_add(1); + if (fullID >= fullVectors->Count()) + { + break; + } + + if (exceptIDS.count(fullID) > 0) + { + continue; + } + + void *reconstructed_vector = nullptr; + if (m_pQuantizer) + { + reconstructed_vector = ALIGN_ALLOC(m_pQuantizer->ReconstructSize()); + m_pQuantizer->ReconstructVector((const uint8_t *)fullVectors->GetVector(fullID), + reconstructed_vector); + switch (m_pQuantizer->GetReconstructType()) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + (*((COMMON::QueryResultSet *)&resultSet)) \ + .SetTarget(reinterpret_cast(reconstructed_vector), m_pQuantizer); \ + break; +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + default: + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, "Unable to get quantizer reconstruct type %s", + Helper::Convert::ConvertToString(m_pQuantizer->GetReconstructType())); + } + } + else + { + resultSet.SetTarget(fullVectors->GetVector(fullID)); + } + resultSet.Reset(); + + SearchIndex(resultSet); + + size_t selectionOffset = static_cast(fullID) * replicaCount; + + BasicResult *queryResults = resultSet.GetResults(); + int currReplicaCount = 0; + for (int i = 0; i < candidateNum && currReplicaCount < replicaCount; ++i) + { + if (queryResults[i].VID == -1) + { + break; + } + + // RNG Check. 
+ bool rngAccpeted = true; + for (int j = 0; j < currReplicaCount; ++j) + { + float nnDist = ComputeDistance(GetSample(queryResults[i].VID), + GetSample(selections[selectionOffset + j].node)); + + if (RNGFactor * nnDist < queryResults[i].Dist) + { + rngAccpeted = false; + break; + } + } + + if (!rngAccpeted) + { + ++rngFailedCount; + continue; + } + + selections[selectionOffset + currReplicaCount].node = queryResults[i].VID; + selections[selectionOffset + currReplicaCount].distance = queryResults[i].Dist; + ++currReplicaCount; + } + + if (reconstructed_vector) + { + ALIGN_FREE(reconstructed_vector); + } + } + rngFailedCountTotal += rngFailedCount; + }); + } + + for (int tid = 0; tid < numThreads; ++tid) + { + threads[tid].join(); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Searching replicas ended. RNG failed count: %llu\n", + static_cast(rngFailedCountTotal.load())); +} +#endif diff --git a/AnnService/src/Core/VectorSet.cpp b/AnnService/src/Core/VectorSet.cpp index 68cea06a4..e8b84c090 100644 --- a/AnnService/src/Core/VectorSet.cpp +++ b/AnnService/src/Core/VectorSet.cpp @@ -3,6 +3,7 @@ #include "inc/Core/VectorSet.h" #include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/VectorIndex.h" using namespace SPTAG; @@ -10,91 +11,72 @@ VectorSet::VectorSet() { } - VectorSet::~VectorSet() { } - -BasicVectorSet::BasicVectorSet(const ByteArray& p_bytesArray, - VectorValueType p_valueType, - DimensionType p_dimension, +BasicVectorSet::BasicVectorSet(const ByteArray &p_bytesArray, VectorValueType p_valueType, DimensionType p_dimension, SizeType p_vectorCount) - : m_data(p_bytesArray), - m_valueType(p_valueType), - m_dimension(p_dimension), - m_vectorCount(p_vectorCount), + : m_data(p_bytesArray), m_valueType(p_valueType), m_dimension(p_dimension), m_vectorCount(p_vectorCount), m_perVectorDataSize(static_cast(p_dimension * GetValueTypeSize(p_valueType))) { } - BasicVectorSet::~BasicVectorSet() { } - -VectorValueType -BasicVectorSet::GetValueType() const 
+VectorValueType BasicVectorSet::GetValueType() const { return m_valueType; } - -void* -BasicVectorSet::GetVector(SizeType p_vectorID) const +void *BasicVectorSet::GetVector(SizeType p_vectorID) const { if (p_vectorID < 0 || p_vectorID >= m_vectorCount) { return nullptr; } - return reinterpret_cast(m_data.Data() + ((size_t)p_vectorID) * m_perVectorDataSize); + return reinterpret_cast(m_data.Data() + ((size_t)p_vectorID) * m_perVectorDataSize); } -void* -BasicVectorSet::GetData() const +void *BasicVectorSet::GetData() const { - return reinterpret_cast(m_data.Data()); + return reinterpret_cast(m_data.Data()); } -DimensionType -BasicVectorSet::Dimension() const +DimensionType BasicVectorSet::Dimension() const { return m_dimension; } - -SizeType -BasicVectorSet::Count() const +SizeType BasicVectorSet::Count() const { return m_vectorCount; } - -bool -BasicVectorSet::Available() const +bool BasicVectorSet::Available() const { return m_data.Data() != nullptr; } - -ErrorCode -BasicVectorSet::Save(const std::string& p_vectorFile) const +ErrorCode BasicVectorSet::Save(const std::string &p_vectorFile) const { auto fp = SPTAG::f_createIO(); - if (fp == nullptr || !fp->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::out)) return ErrorCode::FailedOpenFile; + if (fp == nullptr || !fp->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::out)) + return ErrorCode::FailedOpenFile; - IOBINARY(fp, WriteBinary, sizeof(SizeType), (char*)&m_vectorCount); - IOBINARY(fp, WriteBinary, sizeof(DimensionType), (char*)&m_dimension); - IOBINARY(fp, WriteBinary, m_data.Length(), (char*)m_data.Data()); + IOBINARY(fp, WriteBinary, sizeof(SizeType), (char *)&m_vectorCount); + IOBINARY(fp, WriteBinary, sizeof(DimensionType), (char *)&m_dimension); + IOBINARY(fp, WriteBinary, m_data.Length(), (char *)m_data.Data()); return ErrorCode::Success; } -ErrorCode BasicVectorSet::AppendSave(const std::string& p_vectorFile) const +ErrorCode BasicVectorSet::AppendSave(const std::string 
&p_vectorFile) const { auto append = fileexists(p_vectorFile.c_str()); - + SizeType count; SizeType dim; @@ -102,10 +84,14 @@ ErrorCode BasicVectorSet::AppendSave(const std::string& p_vectorFile) const if (append) { auto fp_read = SPTAG::f_createIO(); - if (fp_read == nullptr || !fp_read->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::in)) return ErrorCode::FailedOpenFile; - IOBINARY(fp_read, ReadBinary, sizeof(SizeType), (char*)&count); - IOBINARY(fp_read, ReadBinary, sizeof(DimensionType), (char*)&dim); - if (dim != m_dimension) { return ErrorCode::DimensionSizeMismatch; } + if (fp_read == nullptr || !fp_read->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::in)) + return ErrorCode::FailedOpenFile; + IOBINARY(fp_read, ReadBinary, sizeof(SizeType), (char *)&count); + IOBINARY(fp_read, ReadBinary, sizeof(DimensionType), (char *)&dim); + if (dim != m_dimension) + { + return ErrorCode::DimensionSizeMismatch; + } count += m_vectorCount; } else @@ -117,37 +103,39 @@ ErrorCode BasicVectorSet::AppendSave(const std::string& p_vectorFile) const // Update count header { auto fp_write = SPTAG::f_createIO(); - if (fp_write == nullptr || !fp_write->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::out | (append ? std::ios::in : 0))) return ErrorCode::FailedOpenFile; - IOBINARY(fp_write, WriteBinary, sizeof(SizeType), (char*)&count); - IOBINARY(fp_write, WriteBinary, sizeof(DimensionType), (char*)&dim); + if (fp_write == nullptr || + !fp_write->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::out | (append ? 
std::ios::in : 0))) + return ErrorCode::FailedOpenFile; + IOBINARY(fp_write, WriteBinary, sizeof(SizeType), (char *)&count); + IOBINARY(fp_write, WriteBinary, sizeof(DimensionType), (char *)&dim); } - // Write new vectors to end of file { auto fp_append = SPTAG::f_createIO(); - if (fp_append == nullptr || !fp_append->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::out | std::ios::app)) return ErrorCode::FailedOpenFile; - IOBINARY(fp_append, WriteBinary, m_data.Length(), (char*)m_data.Data()); + if (fp_append == nullptr || + !fp_append->Initialize(p_vectorFile.c_str(), std::ios::binary | std::ios::out | std::ios::app)) + return ErrorCode::FailedOpenFile; + IOBINARY(fp_append, WriteBinary, m_data.Length(), (char *)m_data.Data()); } return ErrorCode::Success; } -SizeType BasicVectorSet::PerVectorDataSize() const +SizeType BasicVectorSet::PerVectorDataSize() const { return (SizeType)m_perVectorDataSize; } - -void -BasicVectorSet::Normalize(int p_threads) +void BasicVectorSet::Normalize(int p_threads) { switch (m_valueType) { -#define DefineVectorValueType(Name, Type) \ -case SPTAG::VectorValueType::Name: \ -SPTAG::COMMON::Utils::BatchNormalize(reinterpret_cast(m_data.Data()), m_vectorCount, m_dimension, SPTAG::COMMON::Utils::GetBase(), p_threads); \ -break; \ +#define DefineVectorValueType(Name, Type) \ + case SPTAG::VectorValueType::Name: \ + SPTAG::COMMON::Utils::BatchNormalize(reinterpret_cast(m_data.Data()), m_vectorCount, \ + m_dimension, SPTAG::COMMON::Utils::GetBase(), p_threads); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType diff --git a/AnnService/src/Helper/ArgumentsParser.cpp b/AnnService/src/Helper/ArgumentsParser.cpp index 9d1fcfbc7..64a5f3da6 100644 --- a/AnnService/src/Helper/ArgumentsParser.cpp +++ b/AnnService/src/Helper/ArgumentsParser.cpp @@ -5,38 +5,32 @@ using namespace SPTAG::Helper; - ArgumentsParser::IArgument::IArgument() { } - ArgumentsParser::IArgument::~IArgument() { } - 
ArgumentsParser::ArgumentsParser() { } - ArgumentsParser::~ArgumentsParser() { } - -bool -ArgumentsParser::Parse(int p_argc, char** p_args) +bool ArgumentsParser::Parse(int p_argc, char **p_args) { while (p_argc > 0) { int last = p_argc; - for (auto& option : m_arguments) + for (auto &option : m_arguments) { if (!option->ParseValue(p_argc, p_args)) { - LOG(Helper::LogLevel::LL_Empty, "Failed to parse args around \"%s\"\n", *p_args); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "Failed to parse args around \"%s\"\n", *p_args); PrintHelp(); return false; } @@ -50,20 +44,20 @@ ArgumentsParser::Parse(int p_argc, char** p_args) } bool isValid = true; - for (auto& option : m_arguments) + for (auto &option : m_arguments) { if (option->IsRequiredButNotSet()) { - LOG(Helper::LogLevel::LL_Empty, "Required option not set:\n "); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "Required option not set:\n "); option->PrintDescription(); - LOG(Helper::LogLevel::LL_Empty, "\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "\n"); isValid = false; } } if (!isValid) { - LOG(Helper::LogLevel::LL_Empty, "\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "\n"); PrintHelp(); return false; } @@ -71,16 +65,14 @@ ArgumentsParser::Parse(int p_argc, char** p_args) return true; } - -void -ArgumentsParser::PrintHelp() +void ArgumentsParser::PrintHelp() { - LOG(Helper::LogLevel::LL_Empty, "Usage: "); - for (auto& option : m_arguments) + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "Usage: "); + for (auto &option : m_arguments) { - LOG(Helper::LogLevel::LL_Empty, "\n "); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "\n "); option->PrintDescription(); } - LOG(Helper::LogLevel::LL_Empty, "\n\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Empty, "\n\n"); } diff --git a/AnnService/src/Helper/AsyncFileReader.cpp b/AnnService/src/Helper/AsyncFileReader.cpp index cb9b5ee6c..062940365 100644 --- a/AnnService/src/Helper/AsyncFileReader.cpp +++ b/AnnService/src/Helper/AsyncFileReader.cpp @@ -3,124 +3,250 @@ #include 
"inc/Helper/AsyncFileReader.h" -namespace SPTAG { - namespace Helper { +namespace SPTAG +{ +namespace Helper +{ #ifndef _MSC_VER - struct timespec AIOTimeout {0, 30000}; - void BatchReadFileAsync(std::vector>& handlers, AsyncReadRequest* readRequests, int num) +void SetThreadAffinity(int threadID, std::thread &thread, NumaStrategy socketStrategy, OrderStrategy idStrategy) +{ +#ifdef NUMA + int numGroups = numa_num_task_nodes(); + int numCpus = numa_num_task_cpus() / numGroups; + + int group = threadID / numCpus; + int cpuid = threadID % numCpus; + if (socketStrategy == NumaStrategy::SCATTER) + { + group = threadID % numGroups; + cpuid = (threadID / numGroups) % numCpus; + } + + struct bitmask *cpumask = numa_allocate_cpumask(); + if (!numa_node_to_cpus(group, cpumask)) + { + unsigned int nodecpu = 0; + for (unsigned int i = 0; i < cpumask->size; i++) { - std::vector myiocbs(num); - std::vector> iocbs(handlers.size()); - std::vector submitted(handlers.size(), 0); - std::vector done(handlers.size(), 0); - int totalToSubmit = 0, channel = 0; - - memset(myiocbs.data(), 0, num * sizeof(struct iocb)); - for (int i = 0; i < num; i++) { - AsyncReadRequest* readRequest = &(readRequests[i]); - - channel = readRequest->m_status & 0xffff; - int fileid = (readRequest->m_status >> 16); - - struct iocb* myiocb = &(myiocbs[totalToSubmit++]); - myiocb->aio_data = reinterpret_cast(readRequest); - myiocb->aio_lio_opcode = IOCB_CMD_PREAD; - myiocb->aio_fildes = ((AsyncFileIO*)(handlers[fileid].get()))->GetFileHandler(); - myiocb->aio_buf = (std::uint64_t)(readRequest->m_buffer); - myiocb->aio_nbytes = readRequest->m_readSize; - myiocb->aio_offset = static_cast(readRequest->m_offset); - - iocbs[fileid].emplace_back(myiocb); - } - std::vector events(totalToSubmit); - int totalDone = 0, totalSubmitted = 0, totalQueued = 0; - while (totalDone < totalToSubmit) { - if (totalSubmitted < totalToSubmit) { - for (int i = 0; i < handlers.size(); i++) { - if (submitted[i] < iocbs[i].size()) { - 
AsyncFileIO* handler = (AsyncFileIO*)(handlers[i].get()); - int s = syscall(__NR_io_submit, handler->GetIOCP(channel), iocbs[i].size() - submitted[i], iocbs[i].data() + submitted[i]); - if (s > 0) { - submitted[i] += s; - totalSubmitted += s; - } - else { - LOG(Helper::LogLevel::LL_Error, "fid:%d channel %d, to submit:%d, submitted:%s\n", i, channel, iocbs[i].size() - submitted[i], strerror(-s)); - } - } + if (numa_bitmask_isbitset(cpumask, i)) + { + if (cpuid == nodecpu) + { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(i, &cpuset); + int rc = pthread_setaffinity_np(thread.native_handle(), sizeof(cpu_set_t), &cpuset); + if (rc != 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Error calling pthread_setaffinity_np for thread %d: %d\n", threadID, rc); } + break; } + nodecpu++; + } + } + } +#else + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(threadID, &cpuset); + int rc = pthread_setaffinity_np(thread.native_handle(), sizeof(cpu_set_t), &cpuset); + if (rc != 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error calling pthread_setaffinity_np for thread %d: %d\n", threadID, + rc); + } +#endif +} + +struct timespec AIOTimeout +{ + 0, 30000 +}; +bool BatchReadFileAsync(std::vector> &handlers, AsyncReadRequest *readRequests, int num) +{ + std::vector myiocbs(num); + std::vector> iocbs(handlers.size()); + std::vector submitted(handlers.size(), 0); + std::vector done(handlers.size(), 0); + int totalToSubmit = 0, iocp = 0; + + memset(myiocbs.data(), 0, num * sizeof(struct iocb)); + for (int i = 0; i < num; i++) + { + AsyncReadRequest *readRequest = &(readRequests[i]); + + iocp = readRequest->m_status & 0xffff; + int fileid = (readRequest->m_status >> 16); - for (int i = totalQueued; i < totalDone; i++) { - AsyncReadRequest* req = reinterpret_cast((events[i].data)); - if (nullptr != req) + struct iocb *myiocb = &(myiocbs[totalToSubmit++]); + myiocb->aio_data = reinterpret_cast(readRequest); + myiocb->aio_lio_opcode = IOCB_CMD_PREAD; + myiocb->aio_fildes = 
((AsyncFileIO *)(handlers[fileid].get()))->GetFileHandler(); + myiocb->aio_buf = (std::uint64_t)(readRequest->m_buffer); + myiocb->aio_nbytes = readRequest->m_readSize; + myiocb->aio_offset = static_cast(readRequest->m_offset); + + iocbs[fileid].emplace_back(myiocb); + } + + std::vector events(totalToSubmit); + int totalDone = 0, totalSubmitted = 0, totalQueued = 0; + while (totalDone < totalToSubmit) + { + if (totalSubmitted < totalToSubmit) + { + for (int i = 0; i < handlers.size(); i++) + { + if (submitted[i] < iocbs[i].size()) + { + AsyncFileIO *handler = (AsyncFileIO *)(handlers[i].get()); + int s = syscall(__NR_io_submit, handler->GetIOCP(iocp), iocbs[i].size() - submitted[i], + iocbs[i].data() + submitted[i]); + if (s > 0) { - req->m_callback(req); + submitted[i] += s; + totalSubmitted += s; } - } - totalQueued = totalDone; - - for (int i = 0; i < handlers.size(); i++) { - if (done[i] < submitted[i]) { - int wait = submitted[i] - done[i]; - AsyncFileIO* handler = (AsyncFileIO*)(handlers[i].get()); - auto d = syscall(__NR_io_getevents, handler->GetIOCP(channel), wait, wait, events.data() + totalDone, &AIOTimeout); - done[i] += d; - totalDone += d; + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "fid:%d channel %d, to submit:%d, submitted:%s\n", i, + iocp, iocbs[i].size() - submitted[i], strerror(-s)); + return false; } } } + } - for (int i = totalQueued; i < totalDone; i++) { - AsyncReadRequest* req = reinterpret_cast((events[i].data)); - if (nullptr != req) - { - req->m_callback(req); - } + for (int i = totalQueued; i < totalDone; i++) + { + AsyncReadRequest *req = reinterpret_cast((events[i].data)); + if (nullptr != req) + { + req->m_callback(true); } } -#else - using BatchOp = bool (DiskIO::*)(AsyncReadRequest*, uint32_t); - void CallOnAppropriateBatch(std::vector>& handlers, AsyncReadRequest* readRequests, int num, BatchOp f) + totalQueued = totalDone; + + for (int i = 0; i < handlers.size(); i++) { - if (handlers.size() == 1) { - 
(handlers[0].get()->*f)(readRequests, num); - } - else { - int currFileId = 0, currReqStart = 0; - for (int i = 0; i < num; i++) { - AsyncReadRequest* readRequest = &(readRequests[i]); - - int fileid = (readRequest->m_status >> 16); - if (fileid != currFileId) { - (handlers[currFileId].get()->*f)(readRequests + currReqStart, i - currReqStart); - currFileId = fileid; - currReqStart = i; - } - } - if (currReqStart < num) { - (handlers[currFileId].get()->*f)(readRequests + currReqStart, num - currReqStart); + if (done[i] < submitted[i]) + { + int wait = submitted[i] - done[i]; + AsyncFileIO *handler = (AsyncFileIO *)(handlers[i].get()); + auto d = syscall(__NR_io_getevents, handler->GetIOCP(iocp), wait, wait, events.data() + totalDone, + &AIOTimeout); + if (d > 0) + { + done[i] += d; + totalDone += d; } } } + } - void BatchReadFileAsync(std::vector>& handlers, AsyncReadRequest* readRequests, int num) + for (int i = totalQueued; i < totalDone; i++) + { + AsyncReadRequest *req = reinterpret_cast((events[i].data)); + if (nullptr != req) { - - CallOnAppropriateBatch(handlers, readRequests, num, &DiskIO::BatchReadFile); + req->m_callback(true); + } + } + return true; +} +#else +ULONGLONG GetCpuMasks(WORD group, DWORD numCpus) +{ + ULONGLONG masks = 0, mask = 1; + for (DWORD i = 0; i < numCpus; ++i) + { + masks |= mask; + mask <<= 1; + } - for (int i = 0; i < num; i++) { - AsyncReadRequest* readRequest = &(readRequests[i]); - - if (readRequest->m_success && readRequest->m_callback) - { - readRequest->m_callback(readRequest); - readRequest->m_callback = nullptr; - } - } + return masks; +} + +void SetThreadAffinity(int threadID, std::thread &thread, NumaStrategy socketStrategy, OrderStrategy idStrategy) +{ + WORD numGroups = GetActiveProcessorGroupCount(); + DWORD numCpus = GetActiveProcessorCount(0); + + GROUP_AFFINITY ga; + memset(&ga, 0, sizeof(ga)); + PROCESSOR_NUMBER pn; + memset(&pn, 0, sizeof(pn)); + + WORD group = (WORD)(threadID / numCpus); + pn.Number = 
(BYTE)(threadID % numCpus); + if (socketStrategy == NumaStrategy::SCATTER) + { + group = (WORD)(threadID % numGroups); + pn.Number = (BYTE)((threadID / numGroups) % numCpus); + } + + ga.Group = group; + ga.Mask = GetCpuMasks(group, numCpus); + BOOL res = SetThreadGroupAffinity(GetCurrentThread(), &ga, NULL); + if (!res) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, + "Failed SetThreadGroupAffinity for group %d and mask %I64x for thread %d.\n", ga.Group, ga.Mask, + threadID); + return; + } + pn.Group = group; + if (idStrategy == OrderStrategy::DESC) + { + pn.Number = (BYTE)(numCpus - 1 - pn.Number); + } + res = SetThreadIdealProcessorEx(GetCurrentThread(), &pn, NULL); + if (!res) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to set ideal processor for thread %d.\n", threadID); + return; + } + + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "numGroup:%d numCPUs:%d threadID:%d group:%d cpuid:%d\n", + // (int)(numGroups), (int)numCpus, threadID, (int)(group), (int)(pn.Number)); + YieldProcessor(); +} - CallOnAppropriateBatch(handlers, readRequests, num, &DiskIO::BatchCleanRequests); +bool BatchReadFileAsync(std::vector> &handlers, AsyncReadRequest *readRequests, int num) +{ + if (handlers.size() == 1) + { + if (handlers[0]->BatchReadFile(readRequests, num, MaxTimeout) < num) + return false; + } + else + { + int currFileId = 0, currReqStart = 0; + for (int i = 0; i < num; i++) + { + AsyncReadRequest *readRequest = &(readRequests[i]); + + int fileid = (readRequest->m_status >> 16); + if (fileid != currFileId) + { + if (handlers[currFileId]->BatchReadFile(readRequests + currReqStart, i - currReqStart, MaxTimeout) < + i - currReqStart) + return false; + currFileId = fileid; + currReqStart = i; + } } -#endif + if (currReqStart < num) + { + if (handlers[currFileId]->BatchReadFile(readRequests + currReqStart, num - currReqStart, MaxTimeout) < + num - currReqStart) + return false; + } + return true; } } +#endif +} // namespace Helper +} // namespace SPTAG diff --git 
a/AnnService/src/Helper/Base64Encode.cpp b/AnnService/src/Helper/Base64Encode.cpp index 5992fa5a3..c06c394e0 100644 --- a/AnnService/src/Helper/Base64Encode.cpp +++ b/AnnService/src/Helper/Base64Encode.cpp @@ -10,21 +10,17 @@ namespace { namespace Local { -const char c_encTable[] = -{ - 'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z', - 'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z', - '0','1','2','3','4','5','6','7','8','9','+','/' -}; +const char c_encTable[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; - -const std::uint8_t c_decTable[] = -{ +const std::uint8_t c_decTable[] = { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0x00 - 0x0f 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0x10 - 0x1f 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 62, 64, 64, 64, 63, // 0x20 - 0x2f 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 64, 64, 64, 64, 64, 64, // 0x30 - 0x3f - 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // 0x40 - 0x4f + 64, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // 0x40 - 0x4f 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 64, 64, 64, 64, 64, // 0x50 - 0x5f 64, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, // 0x60 - 0x6f 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 64, 64, 64, 64, 64, // 0x70 - 0x7f @@ -38,14 +34,11 @@ const std::uint8_t c_decTable[] = 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, // 0xf0 - 0xff }; - const char c_paddingChar = '='; -} -} - +} // namespace Local +} // namespace -bool -Base64::Encode(const std::uint8_t* p_in, std::size_t p_inLen, char* p_out, std::size_t& 
p_outLen) +bool Base64::Encode(const std::uint8_t *p_in, std::size_t p_inLen, char *p_out, std::size_t &p_outLen) { using namespace Local; @@ -87,11 +80,9 @@ Base64::Encode(const std::uint8_t* p_in, std::size_t p_inLen, char* p_out, std:: return true; } - -bool -Base64::Encode(const std::uint8_t* p_in, std::size_t p_inLen, std::ostream& p_out, std::size_t& p_outLen) +bool Base64::Encode(const std::uint8_t *p_in, std::size_t p_inLen, std::ostream &p_out, std::size_t &p_outLen) { - using namespace Local; + using namespace Local; p_outLen = 0; while (p_inLen >= 3) @@ -133,9 +124,7 @@ Base64::Encode(const std::uint8_t* p_in, std::size_t p_inLen, std::ostream& p_ou return true; } - -bool -Base64::Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std::size_t& p_outLen) +bool Base64::Decode(const char *p_in, std::size_t p_inLen, std::uint8_t *p_out, std::size_t &p_outLen) { using namespace Local; @@ -193,7 +182,6 @@ Base64::Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std:: return false; } - p_out[0] = (u0 << 2) | (u1 >> 4); ++p_outLen; if (c_paddingChar == p_in[2]) @@ -224,17 +212,12 @@ Base64::Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std:: return true; } - -std::size_t -Base64::CapacityForEncode(std::size_t p_inLen) +std::size_t Base64::CapacityForEncode(std::size_t p_inLen) { return ((p_inLen + 2) / 3) * 4; } - -std::size_t -Base64::CapacityForDecode(std::size_t p_inLen) +std::size_t Base64::CapacityForDecode(std::size_t p_inLen) { return (p_inLen / 4) * 3 + ((p_inLen % 4) * 2) / 3; } - diff --git a/AnnService/src/Helper/CommonHelper.cpp b/AnnService/src/Helper/CommonHelper.cpp index ed5d662c7..034c4a6ec 100644 --- a/AnnService/src/Helper/CommonHelper.cpp +++ b/AnnService/src/Helper/CommonHelper.cpp @@ -9,10 +9,9 @@ using namespace SPTAG; using namespace SPTAG::Helper; -void -StrUtils::ToLowerInPlace(std::string& p_str) +void StrUtils::ToLowerInPlace(std::string &p_str) { - for (char& ch : p_str) + for 
(char &ch : p_str) { if (std::isupper(ch)) { @@ -21,9 +20,7 @@ StrUtils::ToLowerInPlace(std::string& p_str) } } - -std::vector -StrUtils::SplitString(const std::string& p_str, const std::string& p_separator) +std::vector StrUtils::SplitString(const std::string &p_str, const std::string &p_separator) { std::vector ret; @@ -47,11 +44,8 @@ StrUtils::SplitString(const std::string& p_str, const std::string& p_separator) return ret; } - -std::pair -StrUtils::FindTrimmedSegment(const char* p_begin, - const char* p_end, - const std::function& p_isSkippedChar) +std::pair StrUtils::FindTrimmedSegment(const char *p_begin, const char *p_end, + const std::function &p_isSkippedChar) { while (p_begin < p_end) { @@ -76,9 +70,7 @@ StrUtils::FindTrimmedSegment(const char* p_begin, return std::make_pair(p_begin, p_end); } - -bool -StrUtils::StartsWith(const char* p_str, const char* p_prefix) +bool StrUtils::StartsWith(const char *p_str, const char *p_prefix) { if (nullptr == p_prefix) { @@ -103,9 +95,7 @@ StrUtils::StartsWith(const char* p_str, const char* p_prefix) return '\0' == *p_prefix; } - -bool -StrUtils::StrEqualIgnoreCase(const char* p_left, const char* p_right) +bool StrUtils::StrEqualIgnoreCase(const char *p_left, const char *p_right) { if (p_left == p_right) { @@ -117,8 +107,7 @@ StrUtils::StrEqualIgnoreCase(const char* p_left, const char* p_right) return false; } - auto tryConv = [](char p_ch) -> char - { + auto tryConv = [](char p_ch) -> char { if ('a' <= p_ch && p_ch <= 'z') { return p_ch - 32; @@ -141,15 +130,15 @@ StrUtils::StrEqualIgnoreCase(const char* p_left, const char* p_right) return *p_left == *p_right; } - -std::string -StrUtils::ReplaceAll(const std::string& orig, const std::string& from, const std::string& to) +std::string StrUtils::ReplaceAll(const std::string &orig, const std::string &from, const std::string &to) { std::string ret = orig; - if (from.empty()) return ret; + if (from.empty()) + return ret; size_t pos = 0; - while ((pos = ret.find(from, pos)) 
!= std::string::npos) { + while ((pos = ret.find(from, pos)) != std::string::npos) + { ret.replace(pos, from.length(), to); pos += to.length(); } diff --git a/AnnService/src/Helper/Concurrent.cpp b/AnnService/src/Helper/Concurrent.cpp index cbb1bdb64..0278ab06c 100644 --- a/AnnService/src/Helper/Concurrent.cpp +++ b/AnnService/src/Helper/Concurrent.cpp @@ -6,20 +6,14 @@ using namespace SPTAG; using namespace SPTAG::Helper::Concurrent; -WaitSignal::WaitSignal() - : m_isWaiting(false), - m_unfinished(0) +WaitSignal::WaitSignal() : m_isWaiting(false), m_unfinished(0) { } - -WaitSignal::WaitSignal(std::uint32_t p_unfinished) - : m_isWaiting(false), - m_unfinished(p_unfinished) +WaitSignal::WaitSignal(std::uint32_t p_unfinished) : m_isWaiting(false), m_unfinished(p_unfinished) { } - WaitSignal::~WaitSignal() { std::lock_guard guard(m_mutex); @@ -29,9 +23,7 @@ WaitSignal::~WaitSignal() } } - -void -WaitSignal::Reset(std::uint32_t p_unfinished) +void WaitSignal::Reset(std::uint32_t p_unfinished) { std::lock_guard guard(m_mutex); if (m_isWaiting) @@ -43,9 +35,7 @@ WaitSignal::Reset(std::uint32_t p_unfinished) m_unfinished = p_unfinished; } - -void -WaitSignal::Wait() +void WaitSignal::Wait() { std::unique_lock lock(m_mutex); if (m_unfinished > 0) @@ -55,9 +45,7 @@ WaitSignal::Wait() } } - -void -WaitSignal::FinishOne() +void WaitSignal::FinishOne() { if (1 == m_unfinished.fetch_sub(1)) { diff --git a/AnnService/src/Helper/DynamicNeighbors.cpp b/AnnService/src/Helper/DynamicNeighbors.cpp index 7c0d29202..b4cf63052 100644 --- a/AnnService/src/Helper/DynamicNeighbors.cpp +++ b/AnnService/src/Helper/DynamicNeighbors.cpp @@ -1,22 +1,18 @@ #include "inc/Helper/DynamicNeighbors.h" #include "inc/Core/Common.h" +#include "inc/Core/VectorIndex.h" using namespace SPTAG::Helper; -DynamicNeighbors::DynamicNeighbors(const int* p_data, const int p_length) - : c_data(p_data), - c_length(p_length) +DynamicNeighbors::DynamicNeighbors(const int *p_data, const int p_length) : c_data(p_data), 
c_length(p_length) { } - -DynamicNeighbors:: ~DynamicNeighbors() +DynamicNeighbors::~DynamicNeighbors() { } - -int -DynamicNeighbors::operator[](const int p_id) const +int DynamicNeighbors::operator[](const int p_id) const { if (p_id < c_length && p_id >= 0) { @@ -26,54 +22,52 @@ DynamicNeighbors::operator[](const int p_id) const return -1; } - -int -DynamicNeighbors::Size() const +int DynamicNeighbors::Size() const { return c_length; } - -DynamicNeighborsSet::DynamicNeighborsSet(const char* p_filePath) +DynamicNeighborsSet::DynamicNeighborsSet(const char *p_filePath) { auto fp = f_createIO(); - if (fp == nullptr || !fp->Initialize(p_filePath, std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed open graph file: %s\n", p_filePath); + if (fp == nullptr || !fp->Initialize(p_filePath, std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open graph file: %s\n", p_filePath); throw std::runtime_error("Opening graph file failed"); } - if (fp->ReadBinary(sizeof(m_vectorCount), (char*)&m_vectorCount) != sizeof(m_vectorCount)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read DynamicNeighborsSet!\n"); + if (fp->ReadBinary(sizeof(m_vectorCount), (char *)&m_vectorCount) != sizeof(m_vectorCount)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read DynamicNeighborsSet!\n"); throw std::runtime_error("reading DynamicNeighborsSet failed"); } m_neighborOffset.reset(new int[m_vectorCount + 1]); m_neighborOffset[0] = 0; - if (fp->ReadBinary(m_vectorCount * sizeof(int), (char*)(m_neighborOffset.get() + 1)) != m_vectorCount * sizeof(int)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read DynamicNeighborsSet!\n"); + if (fp->ReadBinary(m_vectorCount * sizeof(int), (char *)(m_neighborOffset.get() + 1)) != + m_vectorCount * sizeof(int)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read DynamicNeighborsSet!\n"); throw std::runtime_error("reading DynamicNeighborsSet failed"); } size_t graphSize = 
static_cast(m_neighborOffset[m_vectorCount]); - LOG(Helper::LogLevel::LL_Error, "Vector count: %d, Graph size: %zu\n", m_vectorCount, graphSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Vector count: %d, Graph size: %zu\n", m_vectorCount, graphSize); m_data.reset(new int[graphSize]); - auto readSize = fp->ReadBinary(graphSize * sizeof(int), (char*)(m_data.get())); - if (readSize != graphSize * sizeof(int)) { - LOG(Helper::LogLevel::LL_Error, - "Failed read graph: size not match, expected %zu, actually %zu\n", - static_cast(graphSize * sizeof(int)), - static_cast(readSize)); + auto readSize = fp->ReadBinary(graphSize * sizeof(int), (char *)(m_data.get())); + if (readSize != graphSize * sizeof(int)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed read graph: size not match, expected %zu, actually %zu\n", + static_cast(graphSize * sizeof(int)), static_cast(readSize)); throw std::runtime_error("Graph size doesn't match expected"); } } - DynamicNeighborsSet::~DynamicNeighborsSet() { } - DynamicNeighbors DynamicNeighborsSet::operator[](const int p_id) const { if (p_id >= m_vectorCount) @@ -82,5 +76,5 @@ DynamicNeighbors DynamicNeighborsSet::operator[](const int p_id) const } return DynamicNeighbors(m_data.get() + static_cast(m_neighborOffset[p_id]), - m_neighborOffset[p_id + 1] - m_neighborOffset[p_id]); + m_neighborOffset[p_id + 1] - m_neighborOffset[p_id]); } diff --git a/AnnService/src/Helper/SimpleIniReader.cpp b/AnnService/src/Helper/SimpleIniReader.cpp index cbc2eb551..d52d36860 100644 --- a/AnnService/src/Helper/SimpleIniReader.cpp +++ b/AnnService/src/Helper/SimpleIniReader.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "inc/Helper/SimpleIniReader.h" +#include "inc/Core/VectorIndex.h" #include "inc/Helper/CommonHelper.h" #include @@ -12,17 +13,14 @@ using namespace SPTAG::Helper; const IniReader::ParameterValueMap IniReader::c_emptyParameters; - IniReader::IniReader() { } - IniReader::~IniReader() { } - ErrorCode IniReader::LoadIni(std::shared_ptr p_input) { std::uint64_t c_bufferSize = 1 << 16; @@ -37,14 +35,12 @@ ErrorCode IniReader::LoadIni(std::shared_ptr p_input) m_parameters.emplace(currSection, currParamMap); } - auto isSpace = [](char p_ch) -> bool - { - return std::isspace(p_ch) != 0; - }; + auto isSpace = [](char p_ch) -> bool { return std::isspace(p_ch) != 0; }; while (true) { - if (!p_input->ReadString(c_bufferSize, line)) break; + if (!p_input->ReadString(c_bufferSize, line)) + break; std::uint64_t len = 0; while (len < c_bufferSize && line[len] != '\0') @@ -97,7 +93,7 @@ ErrorCode IniReader::LoadIni(std::shared_ptr p_input) else { // Parameter Value Pair. - const char* equalSignLoc = nonSpaceSeg.first; + const char *equalSignLoc = nonSpaceSeg.first; while (equalSignLoc < nonSpaceSeg.second && '=' != *equalSignLoc) { ++equalSignLoc; @@ -132,27 +128,22 @@ ErrorCode IniReader::LoadIni(std::shared_ptr p_input) return ErrorCode::Success; } - -ErrorCode -IniReader::LoadIniFile(const std::string& p_iniFilePath) +ErrorCode IniReader::LoadIniFile(const std::string &p_iniFilePath) { auto ptr = f_createIO(); - if (ptr == nullptr || !ptr->Initialize(p_iniFilePath.c_str(), std::ios::in)) return ErrorCode::FailedOpenFile; + if (ptr == nullptr || !ptr->Initialize(p_iniFilePath.c_str(), std::ios::in)) + return ErrorCode::FailedOpenFile; return LoadIni(ptr); } - -bool -IniReader::DoesSectionExist(const std::string& p_section) const +bool IniReader::DoesSectionExist(const std::string &p_section) const { std::string section(p_section); StrUtils::ToLowerInPlace(section); return m_parameters.count(section) != 0; } - -bool -IniReader::DoesParameterExist(const std::string& 
p_section, const std::string& p_param) const +bool IniReader::DoesParameterExist(const std::string &p_section, const std::string &p_param) const { std::string name(p_section); StrUtils::ToLowerInPlace(name); @@ -162,7 +153,7 @@ IniReader::DoesParameterExist(const std::string& p_section, const std::string& p return false; } - const auto& paramMap = iter->second; + const auto ¶mMap = iter->second; if (paramMap == nullptr) { return false; @@ -173,9 +164,7 @@ IniReader::DoesParameterExist(const std::string& p_section, const std::string& p return paramMap->count(name) != 0; } - -bool -IniReader::GetRawValue(const std::string& p_section, const std::string& p_param, std::string& p_value) const +bool IniReader::GetRawValue(const std::string &p_section, const std::string &p_param, std::string &p_value) const { std::string name(p_section); StrUtils::ToLowerInPlace(name); @@ -185,7 +174,7 @@ IniReader::GetRawValue(const std::string& p_section, const std::string& p_param, return false; } - const auto& paramMap = sectionIter->second; + const auto ¶mMap = sectionIter->second; if (paramMap == nullptr) { return false; @@ -203,9 +192,7 @@ IniReader::GetRawValue(const std::string& p_section, const std::string& p_param, return true; } - -const IniReader::ParameterValueMap& -IniReader::GetParameters(const std::string& p_section) const +const IniReader::ParameterValueMap &IniReader::GetParameters(const std::string &p_section) const { std::string name(p_section); StrUtils::ToLowerInPlace(name); @@ -218,13 +205,12 @@ IniReader::GetParameters(const std::string& p_section) const return *(sectionIter->second); } -void -IniReader::SetParameter(const std::string& p_section, const std::string& p_param, const std::string& p_val) +void IniReader::SetParameter(const std::string &p_section, const std::string &p_param, const std::string &p_val) { std::string name(p_section); StrUtils::ToLowerInPlace(name); auto sectionIter = m_parameters.find(name); - if (sectionIter == m_parameters.cend() || 
sectionIter->second == nullptr) + if (sectionIter == m_parameters.cend() || sectionIter->second == nullptr) { m_parameters[name] = std::shared_ptr(new ParameterValueMap); } diff --git a/AnnService/src/Helper/VectorSetReader.cpp b/AnnService/src/Helper/VectorSetReader.cpp index 08c763f99..00089fe54 100644 --- a/AnnService/src/Helper/VectorSetReader.cpp +++ b/AnnService/src/Helper/VectorSetReader.cpp @@ -9,9 +9,10 @@ using namespace SPTAG; using namespace SPTAG::Helper; - -ReaderOptions::ReaderOptions(VectorValueType p_valueType, DimensionType p_dimension, VectorFileType p_fileType, std::string p_vectorDelimiter, std::uint32_t p_threadNum, bool p_normalized) - : m_inputValueType(p_valueType), m_dimension(p_dimension), m_inputFileType(p_fileType), m_vectorDelimiter(p_vectorDelimiter), m_threadNum(p_threadNum), m_normalized(p_normalized) +ReaderOptions::ReaderOptions(VectorValueType p_valueType, DimensionType p_dimension, VectorFileType p_fileType, + std::string p_vectorDelimiter, std::uint32_t p_threadNum, bool p_normalized) + : m_inputValueType(p_valueType), m_dimension(p_dimension), m_inputFileType(p_fileType), + m_vectorDelimiter(p_vectorDelimiter), m_threadNum(p_threadNum), m_normalized(p_normalized) { AddOptionalOption(m_threadNum, "-t", "--thread", "Thread Number."); AddOptionalOption(m_vectorDelimiter, "-dl", "--delimiter", "Vector delimiter."); @@ -21,36 +22,31 @@ ReaderOptions::ReaderOptions(VectorValueType p_valueType, DimensionType p_dimens AddRequiredOption(m_inputFileType, "-f", "--filetype", "Input file type (DEFAULT, TXT, XVEC). 
Default is DEFAULT."); } - ReaderOptions::~ReaderOptions() { } - -VectorSetReader::VectorSetReader(std::shared_ptr p_options) - : m_options(p_options) +VectorSetReader::VectorSetReader(std::shared_ptr p_options) : m_options(p_options) { } - -VectorSetReader:: ~VectorSetReader() +VectorSetReader::~VectorSetReader() { } - -std::shared_ptr -VectorSetReader::CreateInstance(std::shared_ptr p_options) +std::shared_ptr VectorSetReader::CreateInstance(std::shared_ptr p_options) { - if (p_options->m_inputFileType == VectorFileType::DEFAULT) { + if (p_options->m_inputFileType == VectorFileType::DEFAULT) + { return std::make_shared(p_options); } - else if (p_options->m_inputFileType == VectorFileType::TXT) { + else if (p_options->m_inputFileType == VectorFileType::TXT) + { return std::make_shared(p_options); } - else if (p_options->m_inputFileType == VectorFileType::XVEC) { + else if (p_options->m_inputFileType == VectorFileType::XVEC) + { return std::make_shared(p_options); } return nullptr; } - - diff --git a/AnnService/src/Helper/VectorSetReaders/DefaultReader.cpp b/AnnService/src/Helper/VectorSetReaders/DefaultReader.cpp index b471098e5..eeb89a5a2 100644 --- a/AnnService/src/Helper/VectorSetReaders/DefaultReader.cpp +++ b/AnnService/src/Helper/VectorSetReaders/DefaultReader.cpp @@ -2,82 +2,82 @@ // Licensed under the MIT License. 
#include "inc/Helper/VectorSetReaders/DefaultReader.h" +#include "inc/Core/VectorIndex.h" #include "inc/Helper/CommonHelper.h" using namespace SPTAG; using namespace SPTAG::Helper; -DefaultVectorReader::DefaultVectorReader(std::shared_ptr p_options) - : VectorSetReader(p_options) +DefaultVectorReader::DefaultVectorReader(std::shared_ptr p_options) : VectorSetReader(p_options) { m_vectorOutput = ""; m_metadataConentOutput = ""; m_metadataIndexOutput = ""; } - DefaultVectorReader::~DefaultVectorReader() { } - -ErrorCode -DefaultVectorReader::LoadFile(const std::string& p_filePaths) +ErrorCode DefaultVectorReader::LoadFile(const std::string &p_filePaths) { - const auto& files = SPTAG::Helper::StrUtils::SplitString(p_filePaths, ","); + const auto &files = SPTAG::Helper::StrUtils::SplitString(p_filePaths, ","); m_vectorOutput = files[0]; - if (files.size() >= 3) { + if (files.size() >= 3) + { m_metadataConentOutput = files[1]; m_metadataIndexOutput = files[2]; } return ErrorCode::Success; } - -std::shared_ptr -DefaultVectorReader::GetVectorSet(SizeType start, SizeType end) const +std::shared_ptr DefaultVectorReader::GetVectorSet(SizeType start, SizeType end) const { auto ptr = f_createIO(); - if (ptr == nullptr || !ptr->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_vectorOutput.c_str()); + if (ptr == nullptr || !ptr->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_vectorOutput.c_str()); throw std::runtime_error("Failed read file"); } SizeType row; DimensionType col; - if (ptr->ReadBinary(sizeof(SizeType), (char*)&row) != sizeof(SizeType)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + if (ptr->ReadBinary(sizeof(SizeType), (char *)&row) != sizeof(SizeType)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw 
std::runtime_error("Failed read file"); } - if (ptr->ReadBinary(sizeof(DimensionType), (char*)&col) != sizeof(DimensionType)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + if (ptr->ReadBinary(sizeof(DimensionType), (char *)&col) != sizeof(DimensionType)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed read file"); } - - if (start > row) start = row; - if (end < 0 || end > row) end = row; - std::uint64_t totalRecordVectorBytes = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * (end - start) * col; + + if (start > row) + start = row; + if (end < 0 || end > row) + end = row; + std::uint64_t totalRecordVectorBytes = + ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * (end - start) * col; ByteArray vectorSet; - if (totalRecordVectorBytes > 0) { + if (totalRecordVectorBytes > 0) + { vectorSet = ByteArray::Alloc(totalRecordVectorBytes); - char* vecBuf = reinterpret_cast(vectorSet.Data()); - std::uint64_t offset = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * start * col + sizeof(SizeType) + sizeof(DimensionType); - if (ptr->ReadBinary(totalRecordVectorBytes, vecBuf, offset) != totalRecordVectorBytes) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + char *vecBuf = reinterpret_cast(vectorSet.Data()); + std::uint64_t offset = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * start * col + + sizeof(SizeType) + sizeof(DimensionType); + if (ptr->ReadBinary(totalRecordVectorBytes, vecBuf, offset) != totalRecordVectorBytes) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed read file"); } } - LOG(Helper::LogLevel::LL_Info, "Load Vector(%d,%d)\n", end - start, col); - return std::make_shared(vectorSet, - m_options->m_inputValueType, - col, - end - start); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load Vector(%d,%d)\n", end - start, col); + return 
std::make_shared(vectorSet, m_options->m_inputValueType, col, end - start); } - -std::shared_ptr -DefaultVectorReader::GetMetadataSet() const +std::shared_ptr DefaultVectorReader::GetMetadataSet() const { if (fileexists(m_metadataIndexOutput.c_str()) && fileexists(m_metadataConentOutput.c_str())) return std::shared_ptr(new FileMetadataSet(m_metadataConentOutput, m_metadataIndexOutput)); diff --git a/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp b/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp index e15c77f93..8bf9fb55c 100644 --- a/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp +++ b/AnnService/src/Helper/VectorSetReaders/TxtReader.cpp @@ -2,20 +2,17 @@ // Licensed under the MIT License. #include "inc/Helper/VectorSetReaders/TxtReader.h" -#include "inc/Helper/StringConvert.h" +#include "inc/Core/VectorIndex.h" #include "inc/Helper/CommonHelper.h" - +#include "inc/Helper/StringConvert.h" #include using namespace SPTAG; using namespace SPTAG::Helper; TxtVectorReader::TxtVectorReader(std::shared_ptr p_options) - : VectorSetReader(p_options), - m_subTaskBlocksize(0) + : VectorSetReader(p_options), m_subTaskBlocksize(0) { - omp_set_num_threads(m_options->m_threadNum); - std::string tempFolder("tempfolder"); if (!direxists(tempFolder.c_str())) { @@ -30,7 +27,6 @@ TxtVectorReader::TxtVectorReader(std::shared_ptr p_options) m_metadataIndexOutput = tempFolder + "metadataindex.bin." 
+ randstr; } - TxtVectorReader::~TxtVectorReader() { if (fileexists(m_vectorOutput.c_str())) @@ -49,29 +45,28 @@ TxtVectorReader::~TxtVectorReader() } } - -ErrorCode -TxtVectorReader::LoadFile(const std::string& p_filePaths) +ErrorCode TxtVectorReader::LoadFile(const std::string &p_filePaths) { - const auto& files = GetFileSizes(p_filePaths); + const auto &files = GetFileSizes(p_filePaths); std::vector> subWorks; subWorks.reserve(files.size() * m_options->m_threadNum); m_subTaskCount = 0; - for (const auto& fileInfo : files) + for (const auto &fileInfo : files) { if (fileInfo.second == (std::numeric_limits::max)()) { - LOG(Helper::LogLevel::LL_Error, "File %s not exists or can't access.\n", fileInfo.first.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File %s not exists or can't access.\n", fileInfo.first.c_str()); return ErrorCode::FailedOpenFile; } - std::uint32_t fileTaskCount = 1; + std::uint32_t fileTaskCount = 0; std::size_t blockSize = m_subTaskBlocksize; if (0 == blockSize) { fileTaskCount = m_options->m_threadNum; -if(fileTaskCount == 0) fileTaskCount = 1; + if (fileTaskCount == 0) + fileTaskCount = 1; blockSize = (fileInfo.second + fileTaskCount - 1) / fileTaskCount; } else @@ -81,12 +76,8 @@ if(fileTaskCount == 0) fileTaskCount = 1; for (std::uint32_t i = 0; i < fileTaskCount; ++i) { - subWorks.emplace_back(std::bind(&TxtVectorReader::LoadFileInternal, - this, - fileInfo.first, - m_subTaskCount++, - i, - blockSize)); + subWorks.emplace_back( + std::bind(&TxtVectorReader::LoadFileInternal, this, fileInfo.first, m_subTaskCount++, i, blockSize)); } } @@ -97,77 +88,94 @@ if(fileTaskCount == 0) fileTaskCount = 1; m_waitSignal.Reset(m_subTaskCount); -#pragma omp parallel for schedule(dynamic) - for (int64_t i = 0; i < (int64_t)subWorks.size(); i++) + std::vector mythreads; + mythreads.reserve(m_options->m_threadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < m_options->m_threadNum; tid++) { - ErrorCode code = subWorks[i](); - if 
(ErrorCode::Success != code) - { - throw std::runtime_error("LoadFileInternal failed"); - } - + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < subWorks.size()) + { + ErrorCode code = subWorks[i](); + if (ErrorCode::Success != code) + { + throw std::runtime_error("LoadFileInternal failed"); + } + } + else + { + return; + } + } + }); } - + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); m_waitSignal.Wait(); return MergeData(); } - -std::shared_ptr -TxtVectorReader::GetVectorSet(SizeType start, SizeType end) const +std::shared_ptr TxtVectorReader::GetVectorSet(SizeType start, SizeType end) const { auto ptr = f_createIO(); - if (ptr == nullptr || !ptr->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_vectorOutput.c_str()); + if (ptr == nullptr || !ptr->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_vectorOutput.c_str()); throw std::runtime_error("Failed to read vectorset file"); } SizeType row; DimensionType col; - if (ptr->ReadBinary(sizeof(SizeType), (char*)&row) != sizeof(SizeType)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + if (ptr->ReadBinary(sizeof(SizeType), (char *)&row) != sizeof(SizeType)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed to read vectorset file"); } - if (ptr->ReadBinary(sizeof(DimensionType), (char*)&col) != sizeof(DimensionType)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + if (ptr->ReadBinary(sizeof(DimensionType), (char *)&col) != sizeof(DimensionType)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed to read vectorset file"); } - if (start > row) start = row; - if (end < 0 || end > row) end = row; - 
std::uint64_t totalRecordVectorBytes = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * (end - start) * col; + if (start > row) + start = row; + if (end < 0 || end > row) + end = row; + std::uint64_t totalRecordVectorBytes = + ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * (end - start) * col; ByteArray vectorSet; - if (totalRecordVectorBytes > 0) { + if (totalRecordVectorBytes > 0) + { vectorSet = ByteArray::Alloc(totalRecordVectorBytes); - char* vecBuf = reinterpret_cast(vectorSet.Data()); - std::uint64_t offset = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * start * col + +sizeof(SizeType) + sizeof(DimensionType); - if (ptr->ReadBinary(totalRecordVectorBytes, vecBuf, offset) != totalRecordVectorBytes) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + char *vecBuf = reinterpret_cast(vectorSet.Data()); + std::uint64_t offset = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * start * col + + +sizeof(SizeType) + sizeof(DimensionType); + if (ptr->ReadBinary(totalRecordVectorBytes, vecBuf, offset) != totalRecordVectorBytes) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed to read vectorset file"); } } - return std::shared_ptr(new BasicVectorSet(vectorSet, - m_options->m_inputValueType, - col, - end - start)); + return std::shared_ptr(new BasicVectorSet(vectorSet, m_options->m_inputValueType, col, end - start)); } - -std::shared_ptr -TxtVectorReader::GetMetadataSet() const +std::shared_ptr TxtVectorReader::GetMetadataSet() const { if (fileexists(m_metadataIndexOutput.c_str()) && fileexists(m_metadataConentOutput.c_str())) return std::shared_ptr(new FileMetadataSet(m_metadataConentOutput, m_metadataIndexOutput)); return nullptr; } - -ErrorCode -TxtVectorReader::LoadFileInternal(const std::string& p_filePath, - std::uint32_t p_subTaskID, - std::uint32_t p_fileBlockID, - std::size_t p_fileBlockSize) +ErrorCode 
TxtVectorReader::LoadFileInternal(const std::string &p_filePath, std::uint32_t p_subTaskID, + std::uint32_t p_fileBlockID, std::size_t p_fileBlockSize) { std::uint64_t lineBufferSize = 1 << 16; std::unique_ptr currentLine(new char[lineBufferSize]); @@ -177,24 +185,30 @@ TxtVectorReader::LoadFileInternal(const std::string& p_filePath, std::size_t totalRead = 0; std::streamoff startpos = p_fileBlockID * p_fileBlockSize; - std::shared_ptr input = f_createIO(), output = f_createIO(), meta = f_createIO(), metaIndex = f_createIO(); + std::shared_ptr input = f_createIO(), output = f_createIO(), meta = f_createIO(), + metaIndex = f_createIO(); if (input == nullptr || !input->Initialize(p_filePath.c_str(), std::ios::in | std::ios::binary)) { - LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n",p_filePath.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n", p_filePath.c_str()); return ErrorCode::FailedOpenFile; } - LOG(Helper::LogLevel::LL_Info, "Begin Subtask: %u, start offset position: %lld\n", p_subTaskID, startpos); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Subtask: %u, start offset position: %lld\n", p_subTaskID, startpos); std::string subFileSuffix("_"); subFileSuffix += std::to_string(p_subTaskID); subFileSuffix += ".tmp"; - if (output == nullptr || !output->Initialize((m_vectorOutput + subFileSuffix).c_str(), std::ios::binary | std::ios::out) || - meta == nullptr || !meta->Initialize((m_metadataConentOutput + subFileSuffix).c_str(), std::ios::binary | std::ios::out) || - metaIndex == nullptr || !metaIndex->Initialize((m_metadataIndexOutput + subFileSuffix).c_str(), std::ios::binary | std::ios::out)) + if (output == nullptr || + !output->Initialize((m_vectorOutput + subFileSuffix).c_str(), std::ios::binary | std::ios::out) || + meta == nullptr || + !meta->Initialize((m_metadataConentOutput + subFileSuffix).c_str(), std::ios::binary | std::ios::out) || + metaIndex == nullptr || + !metaIndex->Initialize((m_metadataIndexOutput 
+ subFileSuffix).c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Unable to create files: %s %s %s\n", (m_vectorOutput + subFileSuffix).c_str(), (m_metadataConentOutput + subFileSuffix).c_str(), (m_metadataIndexOutput + subFileSuffix).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to create files: %s %s %s\n", + (m_vectorOutput + subFileSuffix).c_str(), (m_metadataConentOutput + subFileSuffix).c_str(), + (m_metadataIndexOutput + subFileSuffix).c_str()); return ErrorCode::FailedCreateFile; } @@ -210,7 +224,8 @@ TxtVectorReader::LoadFileInternal(const std::string& p_filePath, while (totalRead <= p_fileBlockSize) { std::uint64_t lineLength = input->ReadString(lineBufferSize, currentLine); - if (lineLength == 0) break; + if (lineLength == 0) + break; totalRead += lineLength; std::size_t tabIndex = lineLength - 1; @@ -221,17 +236,18 @@ TxtVectorReader::LoadFileInternal(const std::string& p_filePath, if (0 == tabIndex && currentLine[tabIndex] != '\t') { - LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot parsing line:%s\n", p_subTaskID, currentLine.get()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot parsing line:%s\n", p_subTaskID, + currentLine.get()); return ErrorCode::FailedParseValue; } bool parseSuccess = false; switch (m_options->m_inputValueType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - parseSuccess = TranslateVector(currentLine.get() + tabIndex + 1, reinterpret_cast(vector.get())); \ - break; \ +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + parseSuccess = TranslateVector(currentLine.get() + tabIndex + 1, reinterpret_cast(vector.get())); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType @@ -243,21 +259,25 @@ TxtVectorReader::LoadFileInternal(const std::string& p_filePath, if (!parseSuccess) { - LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot parsing vector:%s\n", p_subTaskID, currentLine.get()); + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot parsing vector:%s\n", p_subTaskID, + currentLine.get()); return ErrorCode::FailedParseValue; } ++recordCount; - if (output->WriteBinary(vectorByteSize, (char*)vector.get()) != vectorByteSize || + if (output->WriteBinary(vectorByteSize, (char *)vector.get()) != vectorByteSize || meta->WriteBinary(tabIndex, currentLine.get()) != tabIndex || - metaIndex->WriteBinary(sizeof(metaOffset), (const char*)&metaOffset) != sizeof(metaOffset)) { - LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot write line:%s\n", p_subTaskID, currentLine.get()); + metaIndex->WriteBinary(sizeof(metaOffset), (const char *)&metaOffset) != sizeof(metaOffset)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot write line:%s\n", p_subTaskID, + currentLine.get()); return ErrorCode::DiskIOFail; } metaOffset += tabIndex; } - if (metaIndex->WriteBinary(sizeof(metaOffset), (const char*)&metaOffset) != sizeof(metaOffset)) { - LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot write final offset!\n", p_subTaskID); + if (metaIndex->WriteBinary(sizeof(metaOffset), (const char *)&metaOffset) != sizeof(metaOffset)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Subtask: %u cannot write final offset!\n", p_subTaskID); return ErrorCode::DiskIOFail; } @@ -269,33 +289,36 @@ TxtVectorReader::LoadFileInternal(const std::string& p_filePath, return ErrorCode::Success; } - -ErrorCode -TxtVectorReader::MergeData() +ErrorCode TxtVectorReader::MergeData() { const std::size_t bufferSize = 1 << 30; const std::size_t bufferSizeTrim64 = (bufferSize / sizeof(std::uint64_t)) * sizeof(std::uint64_t); - std::shared_ptr input = f_createIO(), output = f_createIO(), meta = f_createIO(), metaIndex = f_createIO(); + std::shared_ptr input = f_createIO(), output = f_createIO(), meta = f_createIO(), + metaIndex = f_createIO(); if (output == nullptr || !output->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::out) || meta == nullptr || 
!meta->Initialize(m_metadataConentOutput.c_str(), std::ios::binary | std::ios::out) || metaIndex == nullptr || !metaIndex->Initialize(m_metadataIndexOutput.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Unable to create files: %s %s %s\n", m_vectorOutput.c_str(), m_metadataConentOutput.c_str(), m_metadataIndexOutput.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to create files: %s %s %s\n", m_vectorOutput.c_str(), + m_metadataConentOutput.c_str(), m_metadataIndexOutput.c_str()); return ErrorCode::FailedCreateFile; } std::unique_ptr bufferHolder(new char[bufferSize]); - char* buf = bufferHolder.get(); + char *buf = bufferHolder.get(); SizeType totalRecordCount = m_totalRecordCount; - if (output->WriteBinary(sizeof(totalRecordCount), (char*)(&totalRecordCount)) != sizeof(totalRecordCount)) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_vectorOutput.c_str()); + if (output->WriteBinary(sizeof(totalRecordCount), (char *)(&totalRecordCount)) != sizeof(totalRecordCount)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_vectorOutput.c_str()); return ErrorCode::DiskIOFail; } - if (output->WriteBinary(sizeof(m_options->m_dimension), (char*)&(m_options->m_dimension)) != sizeof(m_options->m_dimension)) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_vectorOutput.c_str()); + if (output->WriteBinary(sizeof(m_options->m_dimension), (char *)&(m_options->m_dimension)) != + sizeof(m_options->m_dimension)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_vectorOutput.c_str()); return ErrorCode::DiskIOFail; } @@ -308,16 +331,19 @@ TxtVectorReader::MergeData() if (input == nullptr || !input->Initialize(file.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n", file.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n", file.c_str()); return 
ErrorCode::FailedOpenFile; } - std::uint64_t readSize; - while ((readSize = input->ReadBinary(bufferSize, bufferHolder.get()))) { - if (output->WriteBinary(readSize, bufferHolder.get()) != readSize) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_vectorOutput.c_str()); + std::uint64_t readSize = input->ReadBinary(bufferSize, bufferHolder.get()); + while (readSize != 0) + { + if (output->WriteBinary(readSize, bufferHolder.get()) != readSize) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_vectorOutput.c_str()); return ErrorCode::DiskIOFail; } + readSize = input->ReadBinary(bufferSize, bufferHolder.get()); } input->ShutDown(); remove(file.c_str()); @@ -332,23 +358,27 @@ TxtVectorReader::MergeData() if (input == nullptr || !input->Initialize(file.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n", file.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n", file.c_str()); return ErrorCode::FailedOpenFile; } - std::uint64_t readSize; - while ((readSize = input->ReadBinary(bufferSize, bufferHolder.get()))) { - if (meta->WriteBinary(readSize, bufferHolder.get()) != readSize) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataConentOutput.c_str()); + std::uint64_t readSize = input->ReadBinary(bufferSize, bufferHolder.get()); + while (readSize != 0) + { + if (meta->WriteBinary(readSize, bufferHolder.get()) != readSize) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataConentOutput.c_str()); return ErrorCode::DiskIOFail; } + readSize = input->ReadBinary(bufferSize, bufferHolder.get()); } input->ShutDown(); remove(file.c_str()); } - if (metaIndex->WriteBinary(sizeof(totalRecordCount), (char*)(&totalRecordCount)) != sizeof(totalRecordCount)) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataIndexOutput.c_str()); + if 
(metaIndex->WriteBinary(sizeof(totalRecordCount), (char *)(&totalRecordCount)) != sizeof(totalRecordCount)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataIndexOutput.c_str()); return ErrorCode::DiskIOFail; } @@ -362,55 +392,57 @@ TxtVectorReader::MergeData() if (input == nullptr || !input->Initialize(file.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n", file.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to open file: %s\n", file.c_str()); return ErrorCode::FailedOpenFile; } for (SizeType remains = m_subTaskRecordCount[i]; remains > 0;) { std::size_t readBytesCount = min(remains * sizeof(std::uint64_t), bufferSizeTrim64); - if (input->ReadBinary(readBytesCount, buf) != readBytesCount) { - LOG(Helper::LogLevel::LL_Error, "Unable to read file: %s\n", file.c_str()); + if (input->ReadBinary(readBytesCount, buf) != readBytesCount) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to read file: %s\n", file.c_str()); return ErrorCode::DiskIOFail; } - std::uint64_t* offset = reinterpret_cast(buf); - for (std::uint64_t i = 0; i < readBytesCount / sizeof(std::uint64_t); ++i) + std::uint64_t *offset = reinterpret_cast(buf); + for (std::uint64_t j = 0; j < readBytesCount / sizeof(std::uint64_t); ++j) { - offset[i] += totalOffset; + offset[j] += totalOffset; } - if (metaIndex->WriteBinary(readBytesCount, buf) != readBytesCount) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataIndexOutput.c_str()); + if (metaIndex->WriteBinary(readBytesCount, buf) != readBytesCount) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataIndexOutput.c_str()); return ErrorCode::DiskIOFail; } remains -= static_cast(readBytesCount / sizeof(std::uint64_t)); } - if (input->ReadBinary(sizeof(std::uint64_t), buf) != sizeof(std::uint64_t)) { - LOG(Helper::LogLevel::LL_Error, "Unable to read file: %s\n", file.c_str()); + if 
(input->ReadBinary(sizeof(std::uint64_t), buf) != sizeof(std::uint64_t)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to read file: %s\n", file.c_str()); return ErrorCode::DiskIOFail; } - totalOffset += *(reinterpret_cast(buf)); + totalOffset += *(reinterpret_cast(buf)); input->ShutDown(); remove(file.c_str()); } - if (metaIndex->WriteBinary(sizeof(totalOffset), (char*)&totalOffset) != sizeof(totalOffset)) { - LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataIndexOutput.c_str()); + if (metaIndex->WriteBinary(sizeof(totalOffset), (char *)&totalOffset) != sizeof(totalOffset)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Unable to write file: %s\n", m_metadataIndexOutput.c_str()); return ErrorCode::DiskIOFail; } return ErrorCode::Success; } - -std::vector -TxtVectorReader::GetFileSizes(const std::string& p_filePaths) +std::vector TxtVectorReader::GetFileSizes(const std::string &p_filePaths) { - const auto& files = Helper::StrUtils::SplitString(p_filePaths, ","); + const auto &files = Helper::StrUtils::SplitString(p_filePaths, ","); std::vector res; res.reserve(files.size()); - for (const auto& filePath : files) + for (const auto &filePath : files) { if (!fileexists(filePath.c_str())) { @@ -430,5 +462,3 @@ TxtVectorReader::GetFileSizes(const std::string& p_filePaths) return res; } - - diff --git a/AnnService/src/Helper/VectorSetReaders/XvecReader.cpp b/AnnService/src/Helper/VectorSetReaders/XvecReader.cpp index 43387d153..98bb95db0 100644 --- a/AnnService/src/Helper/VectorSetReaders/XvecReader.cpp +++ b/AnnService/src/Helper/VectorSetReaders/XvecReader.cpp @@ -2,15 +2,14 @@ // Licensed under the MIT License. 
#include "inc/Helper/VectorSetReaders/XvecReader.h" +#include "inc/Core/VectorIndex.h" #include "inc/Helper/CommonHelper.h" - #include using namespace SPTAG; using namespace SPTAG::Helper; -XvecVectorReader::XvecVectorReader(std::shared_ptr p_options) - : VectorSetReader(p_options) +XvecVectorReader::XvecVectorReader(std::shared_ptr p_options) : VectorSetReader(p_options) { std::string tempFolder("tempfolder"); if (!direxists(tempFolder.c_str())) @@ -21,7 +20,6 @@ XvecVectorReader::XvecVectorReader(std::shared_ptr p_options) m_vectorOutput = tempFolder + FolderSep + "vectorset.bin." + std::to_string(std::rand()); } - XvecVectorReader::~XvecVectorReader() { if (fileexists(m_vectorOutput.c_str())) @@ -30,36 +28,41 @@ XvecVectorReader::~XvecVectorReader() } } - -ErrorCode -XvecVectorReader::LoadFile(const std::string& p_filePaths) +ErrorCode XvecVectorReader::LoadFile(const std::string &p_filePaths) { - const auto& files = Helper::StrUtils::SplitString(p_filePaths, ","); + const auto &files = Helper::StrUtils::SplitString(p_filePaths, ","); auto fp = f_createIO(); - if (fp == nullptr || !fp->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::out)) { - LOG(Helper::LogLevel::LL_Error, "Failed to write file: %s \n", m_vectorOutput.c_str()); + if (fp == nullptr || !fp->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::out)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write file: %s \n", m_vectorOutput.c_str()); return ErrorCode::FailedCreateFile; } SizeType vectorCount = 0; - IOBINARY(fp, WriteBinary, sizeof(vectorCount), (char*)&vectorCount); - IOBINARY(fp, WriteBinary, sizeof(m_options->m_dimension), (char*)&(m_options->m_dimension)); - + IOBINARY(fp, WriteBinary, sizeof(vectorCount), (char *)&vectorCount); + IOBINARY(fp, WriteBinary, sizeof(m_options->m_dimension), (char *)&(m_options->m_dimension)); + size_t vectorDataSize = GetValueTypeSize(m_options->m_inputValueType) * m_options->m_dimension; std::unique_ptr buffer(new 
char[vectorDataSize]); for (std::string file : files) { auto ptr = f_createIO(); - if (ptr == nullptr || !ptr->Initialize(file.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read file: %s \n", file.c_str()); + if (ptr == nullptr || !ptr->Initialize(file.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file: %s \n", file.c_str()); return ErrorCode::FailedOpenFile; } while (true) { DimensionType dim; - if (ptr->ReadBinary(sizeof(DimensionType), (char*)&dim) == 0) break; + if (ptr->ReadBinary(sizeof(DimensionType), (char *)&dim) == 0) + break; - if (dim != m_options->m_dimension) { - LOG(Helper::LogLevel::LL_Error, "Xvec file %s has No.%d vector whose dims are not as many as expected. Expected: %d, Fact: %d\n", file.c_str(), vectorCount, m_options->m_dimension, dim); + if (dim != m_options->m_dimension) + { + SPTAGLIB_LOG( + Helper::LogLevel::LL_Error, + "Xvec file %s has No.%d vector whose dims are not as many as expected. 
Expected: %d, Fact: %d\n", + file.c_str(), vectorCount, m_options->m_dimension, dim); return ErrorCode::DimensionSizeMismatch; } IOBINARY(ptr, ReadBinary, vectorDataSize, buffer.get()); @@ -67,53 +70,55 @@ XvecVectorReader::LoadFile(const std::string& p_filePaths) vectorCount++; } } - IOBINARY(fp, WriteBinary, sizeof(vectorCount), (char*)&vectorCount, 0); + IOBINARY(fp, WriteBinary, sizeof(vectorCount), (char *)&vectorCount, 0); return ErrorCode::Success; } - -std::shared_ptr -XvecVectorReader::GetVectorSet(SizeType start, SizeType end) const +std::shared_ptr XvecVectorReader::GetVectorSet(SizeType start, SizeType end) const { auto ptr = f_createIO(); - if (ptr == nullptr || !ptr->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_vectorOutput.c_str()); + if (ptr == nullptr || !ptr->Initialize(m_vectorOutput.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s.\n", m_vectorOutput.c_str()); throw std::runtime_error("Failed read file"); } SizeType row; DimensionType col; - if (ptr->ReadBinary(sizeof(SizeType), (char*)&row) != sizeof(SizeType)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + if (ptr->ReadBinary(sizeof(SizeType), (char *)&row) != sizeof(SizeType)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed read file"); } - if (ptr->ReadBinary(sizeof(DimensionType), (char*)&col) != sizeof(DimensionType)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + if (ptr->ReadBinary(sizeof(DimensionType), (char *)&col) != sizeof(DimensionType)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed read file"); } - if (start > row) start = row; - if (end < 0 || end > row) end = row; - std::uint64_t totalRecordVectorBytes = 
((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * (end - start) * col; + if (start > row) + start = row; + if (end < 0 || end > row) + end = row; + std::uint64_t totalRecordVectorBytes = + ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * (end - start) * col; ByteArray vectorSet; - if (totalRecordVectorBytes > 0) { + if (totalRecordVectorBytes > 0) + { vectorSet = ByteArray::Alloc(totalRecordVectorBytes); - char* vecBuf = reinterpret_cast(vectorSet.Data()); - std::uint64_t offset = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * start * col + +sizeof(SizeType) + sizeof(DimensionType); - if (ptr->ReadBinary(totalRecordVectorBytes, vecBuf, offset) != totalRecordVectorBytes) { - LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); + char *vecBuf = reinterpret_cast(vectorSet.Data()); + std::uint64_t offset = ((std::uint64_t)GetValueTypeSize(m_options->m_inputValueType)) * start * col + + +sizeof(SizeType) + sizeof(DimensionType); + if (ptr->ReadBinary(totalRecordVectorBytes, vecBuf, offset) != totalRecordVectorBytes) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read VectorSet!\n"); throw std::runtime_error("Failed read file"); } } - return std::shared_ptr(new BasicVectorSet(vectorSet, - m_options->m_inputValueType, - col, - end - start)); + return std::shared_ptr(new BasicVectorSet(vectorSet, m_options->m_inputValueType, col, end - start)); } - -std::shared_ptr -XvecVectorReader::GetMetadataSet() const +std::shared_ptr XvecVectorReader::GetMetadataSet() const { return nullptr; } diff --git a/AnnService/src/IndexBuilder/main.cpp b/AnnService/src/IndexBuilder/main.cpp index 3673479cc..7dd2730eb 100644 --- a/AnnService/src/IndexBuilder/main.cpp +++ b/AnnService/src/IndexBuilder/main.cpp @@ -1,21 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "inc/Helper/VectorSetReader.h" -#include "inc/Core/VectorIndex.h" #include "inc/Core/Common.h" +#include "inc/Core/VectorIndex.h" #include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/VectorSetReader.h" -#include #include +#include using namespace SPTAG; class BuilderOptions : public Helper::ReaderOptions { -public: + public: BuilderOptions() : Helper::ReaderOptions(VectorValueType::Float, 0, VectorFileType::TXT, "|", 32) - { + { AddRequiredOption(m_outputFolder, "-o", "--outputfolder", "Output folder."); AddRequiredOption(m_indexAlgoType, "-a", "--algo", "Index Algorithm type."); AddOptionalOption(m_inputFiles, "-i", "--input", "Input raw data."); @@ -24,7 +24,9 @@ class BuilderOptions : public Helper::ReaderOptions AddOptionalOption(m_metaMapping, "-m", "--metaindex", "Enable delete vectors through metadata"); } - ~BuilderOptions() {} + ~BuilderOptions() + { + } std::string m_inputFiles; @@ -39,14 +41,14 @@ class BuilderOptions : public Helper::ReaderOptions bool m_metaMapping = false; }; -int main(int argc, char* argv[]) +int main(int argc, char *argv[]) { std::shared_ptr options(new BuilderOptions); if (!options->Parse(argc - 1, argv + 1)) { exit(1); } - LOG(Helper::LogLevel::LL_Info, "Set QuantizerFile = %s\n", options->m_quantizerFile.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set QuantizerFile = %s\n", options->m_quantizerFile.c_str()); auto indexBuilder = VectorIndex::CreateInstance(options->m_indexAlgoType, options->m_inputValueType); if (!options->m_quantizerFile.empty()) @@ -59,9 +61,10 @@ int main(int argc, char* argv[]) } Helper::IniReader iniReader; - if (!options->m_builderConfigFile.empty() && iniReader.LoadIniFile(options->m_builderConfigFile) != ErrorCode::Success) + if (!options->m_builderConfigFile.empty() && + iniReader.LoadIniFile(options->m_builderConfigFile) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Cannot open index configure file!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open index 
configure file!"); return -1; } @@ -69,46 +72,55 @@ int main(int argc, char* argv[]) { std::string param(argv[i]); size_t idx = param.find("="); - if (idx == std::string::npos) continue; + if (idx == std::string::npos) + continue; std::string paramName = param.substr(0, idx); std::string paramVal = param.substr(idx + 1); std::string sectionName; idx = paramName.find("."); - if (idx != std::string::npos) { + if (idx != std::string::npos) + { sectionName = paramName.substr(0, idx); paramName = paramName.substr(idx + 1); } iniReader.SetParameter(sectionName, paramName, paramVal); - LOG(Helper::LogLevel::LL_Info, "Set [%s]%s = %s\n", sectionName.c_str(), paramName.c_str(), paramVal.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set [%s]%s = %s\n", sectionName.c_str(), paramName.c_str(), + paramVal.c_str()); } - std::string sections[] = { "Base", "SelectHead", "BuildHead", "BuildSSDIndex", "Index" }; - for (int i = 0; i < 5; i++) { - if (!iniReader.DoesParameterExist(sections[i], "NumberOfThreads")) { + std::string sections[] = {"Base", "SelectHead", "BuildHead", "BuildSSDIndex", "Index"}; + for (int i = 0; i < 5; i++) + { + if (!iniReader.DoesParameterExist(sections[i], "NumberOfThreads")) + { iniReader.SetParameter(sections[i], "NumberOfThreads", std::to_string(options->m_threadNum)); } - for (const auto& iter : iniReader.GetParameters(sections[i])) + for (const auto &iter : iniReader.GetParameters(sections[i])) { indexBuilder->SetParameter(iter.first.c_str(), iter.second.c_str(), sections[i]); } } - + ErrorCode code; std::shared_ptr vecset; - if (options->m_inputFiles != "") { + if (options->m_inputFiles != "") + { auto vectorReader = Helper::VectorSetReader::CreateInstance(options); if (ErrorCode::Success != vectorReader->LoadFile(options->m_inputFiles)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); exit(1); } vecset = vectorReader->GetVectorSet(); - code = 
indexBuilder->BuildIndex(vecset, vectorReader->GetMetadataSet(), options->m_metaMapping, options->m_normalized, true); + code = indexBuilder->BuildIndex(vecset, vectorReader->GetMetadataSet(), options->m_metaMapping, + options->m_normalized, true); } - else { - indexBuilder->SetQuantizerFileName(options->m_quantizerFile.substr(options->m_quantizerFile.find_last_of("/\\") + 1)); - code = indexBuilder->BuildIndex(options->m_normalized); + else + { + indexBuilder->SetQuantizerFileName( + options->m_quantizerFile.substr(options->m_quantizerFile.find_last_of("/\\") + 1)); + code = indexBuilder->BuildIndex(options->m_normalized); } if (code == ErrorCode::Success) { @@ -116,7 +128,7 @@ int main(int argc, char* argv[]) } else { - LOG(Helper::LogLevel::LL_Error, "Failed to build index.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to build index.\n"); exit(1); } return 0; diff --git a/AnnService/src/IndexSearcher/main.cpp b/AnnService/src/IndexSearcher/main.cpp index 8497f831f..356d2db1f 100644 --- a/AnnService/src/IndexSearcher/main.cpp +++ b/AnnService/src/IndexSearcher/main.cpp @@ -1,27 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "inc/Helper/VectorSetReader.h" -#include "inc/Helper/SimpleIniReader.h" -#include "inc/Helper/CommonHelper.h" -#include "inc/Helper/StringConvert.h" #include "inc/Core/Common/CommonUtils.h" -#include "inc/Core/Common/TruthSet.h" #include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/Common/TruthSet.h" #include "inc/Core/VectorIndex.h" +#include "inc/Helper/AsyncFileReader.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/VectorSetReader.h" #include -#include -#include #include +#include #include +#include +#include #include -#include using namespace SPTAG; class SearcherOptions : public Helper::ReaderOptions { -public: + public: SearcherOptions() : Helper::ReaderOptions(VectorValueType::Float, 0, VectorFileType::TXT, "|", 32) { AddRequiredOption(m_queryFile, "-i", "--input", "Input raw data."); @@ -41,7 +42,9 @@ class SearcherOptions : public Helper::ReaderOptions AddOptionalOption(m_outputformat, "-of", "--ouputformat", "0: TXT 1: BINARY."); } - ~SearcherOptions() {} + ~SearcherOptions() + { + } std::string m_queryFile; @@ -74,33 +77,33 @@ class SearcherOptions : public Helper::ReaderOptions int m_outputformat = 0; }; -template -int Process(std::shared_ptr options, VectorIndex& index) +template int Process(std::shared_ptr options, VectorIndex &index) { std::ofstream log("Recall-result.out", std::ios::app); if (!log.is_open()) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open logging file!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open logging file!\n"); exit(-1); } auto vectorReader = Helper::VectorSetReader::CreateInstance(options); if (ErrorCode::Success != vectorReader->LoadFile(options->m_queryFile)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); exit(1); } auto queryVectors = vectorReader->GetVectorSet(0, 
options->m_debugQuery); auto queryMetas = vectorReader->GetMetadataSet(); - std::shared_ptr dataOptions(new Helper::ReaderOptions(queryVectors->GetValueType(), queryVectors->Dimension(), options->m_dataFileType)); + std::shared_ptr dataOptions( + new Helper::ReaderOptions(queryVectors->GetValueType(), queryVectors->Dimension(), options->m_dataFileType)); auto dataReader = Helper::VectorSetReader::CreateInstance(dataOptions); std::shared_ptr dataVectors; if (options->m_dataFile != "") { if (ErrorCode::Success != dataReader->LoadFile(options->m_dataFile)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read data file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read data file.\n"); exit(1); } dataVectors = dataReader->GetVectorSet(); @@ -110,25 +113,33 @@ int Process(std::shared_ptr options, VectorIndex& index) int truthDim = 0; if (options->m_truthFile != "") { - if (options->m_genTruth) { - if (dataVectors == nullptr) { - LOG(Helper::LogLevel::LL_Error, "Cannot load data vectors to generate groundtruth! Please speicify data vector file by setting -df option.\n"); + if (options->m_genTruth) + { + if (dataVectors == nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot load data vectors to generate groundtruth! Please " + "speicify data vector file by setting -df option.\n"); exit(1); } - COMMON::TruthSet::GenerateTruth(queryVectors, dataVectors, options->m_truthFile, index.GetDistCalcMethod(), options->m_truthK, - (options->m_truthFile.find("bin") != std::string::npos) ? TruthFileType::DEFAULT : TruthFileType::TXT, index.m_pQuantizer); + COMMON::TruthSet::GenerateTruth( + queryVectors, dataVectors, options->m_truthFile, index.GetDistCalcMethod(), options->m_truthK, + (options->m_truthFile.find("bin") != std::string::npos) ? 
TruthFileType::DEFAULT : TruthFileType::TXT, + index.m_pQuantizer); } ftruth = SPTAG::f_createIO(); - if (ftruth == nullptr || !ftruth->Initialize(options->m_truthFile.c_str(), std::ios::in | std::ios::binary)) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open %s for read!\n", options->m_truthFile.c_str()); + if (ftruth == nullptr || !ftruth->Initialize(options->m_truthFile.c_str(), std::ios::in | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open %s for read!\n", options->m_truthFile.c_str()); exit(1); } - if (options->m_truthFile.find("bin") != std::string::npos) { - LOG(Helper::LogLevel::LL_Info, "Load binary truth...\n"); + if (options->m_truthFile.find("bin") != std::string::npos) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load binary truth...\n"); } - else { - LOG(Helper::LogLevel::LL_Info, "Load txt truth...\n"); + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Load txt truth...\n"); } } @@ -136,94 +147,125 @@ int Process(std::shared_ptr options, VectorIndex& index) if (options->m_resultFile != "") { fp = SPTAG::f_createIO(); - if (fp == nullptr || !fp->Initialize(options->m_resultFile.c_str(), std::ios::out | std::ios::binary)) { - LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open %s for write!\n", options->m_resultFile.c_str()); + if (fp == nullptr || !fp->Initialize(options->m_resultFile.c_str(), std::ios::out | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ERROR: Cannot open %s for write!\n", + options->m_resultFile.c_str()); } - if (options->m_outputformat == 1) { - LOG(Helper::LogLevel::LL_Info, "Using output format binary..."); + if (options->m_outputformat == 1) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using output format binary..."); int32_t i32Val = queryVectors->Count(); - if (fp->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + if (fp->WriteBinary(sizeof(i32Val), 
reinterpret_cast(&i32Val)) != sizeof(i32Val)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); exit(1); } i32Val = options->m_K; - if (fp->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + if (fp->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); exit(1); } } } std::vector maxCheck = Helper::StrUtils::SplitString(options->m_maxCheck, "#"); - if (options->m_truthK < 0) options->m_truthK = options->m_K; + if (options->m_truthK < 0) + options->m_truthK = options->m_K; std::vector> truth(options->m_batch); int internalResultNum = options->m_K; - if (index.GetIndexAlgoType() == IndexAlgoType::SPANN) { + if (index.GetIndexAlgoType() == IndexAlgoType::SPANN) + { int SPANNInternalResultNum; - if (SPTAG::Helper::Convert::ConvertStringTo(index.GetParameter("SearchInternalResultNum", "BuildSSDIndex").c_str(), SPANNInternalResultNum)) + if (SPTAG::Helper::Convert::ConvertStringTo( + index.GetParameter("SearchInternalResultNum", "BuildSSDIndex").c_str(), SPANNInternalResultNum)) internalResultNum = max(internalResultNum, SPANNInternalResultNum); } std::vector results(options->m_batch, QueryResult(NULL, internalResultNum, options->m_withMeta != 0)); std::vector latencies(options->m_batch, 0); int baseSquare = SPTAG::COMMON::Utils::GetBase() * SPTAG::COMMON::Utils::GetBase(); - LOG(Helper::LogLevel::LL_Info, "[query]\t\t[maxcheck]\t[avg] \t[99%] \t[95%] \t[recall] \t[qps] \t[mem]\n"); - std::vector totalAvg(maxCheck.size(), 0.0), total99(maxCheck.size(), 0.0), total95(maxCheck.size(), 0.0), totalRecall(maxCheck.size(), 0.0), totalLatency(maxCheck.size(), 0.0); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, + "[query]\t\t[maxcheck]\t[avg] \t[99%] \t[95%] \t[recall] \t[qps] \t[mem]\n"); + std::vector totalAvg(maxCheck.size(), 0.0), total99(maxCheck.size(), 
0.0), total95(maxCheck.size(), 0.0), + totalRecall(maxCheck.size(), 0.0), totalLatency(maxCheck.size(), 0.0); for (int startQuery = 0; startQuery < queryVectors->Count(); startQuery += options->m_batch) { int numQuerys = min(options->m_batch, queryVectors->Count() - startQuery); - for (SizeType i = 0; i < numQuerys; i++) results[i].SetTarget(queryVectors->GetVector(startQuery + i)); - if (ftruth != nullptr) COMMON::TruthSet::LoadTruth(ftruth, truth, numQuerys, truthDim, options->m_truthK, (options->m_truthFile.find("bin") != std::string::npos)? TruthFileType::DEFAULT : TruthFileType::TXT); - + for (SizeType i = 0; i < numQuerys; i++) + results[i].SetTarget(queryVectors->GetVector(startQuery + i)); + if (ftruth != nullptr) + COMMON::TruthSet::LoadTruth(ftruth, truth, numQuerys, truthDim, options->m_truthK, + (options->m_truthFile.find("bin") != std::string::npos) ? TruthFileType::DEFAULT + : TruthFileType::TXT); for (int mc = 0; mc < maxCheck.size(); mc++) { index.SetParameter("MaxCheck", maxCheck[mc].c_str()); - for (SizeType i = 0; i < numQuerys; i++) results[i].Reset(); + for (SizeType i = 0; i < numQuerys; i++) + results[i].Reset(); std::atomic_size_t queriesSent(0); std::vector threads; - auto func = [&]() + threads.reserve(options->m_threadNum); + auto batchstart = std::chrono::high_resolution_clock::now(); + + for (std::uint32_t i = 0; i < options->m_threadNum; i++) { - size_t qid = 0; - while (true) - { - qid = queriesSent.fetch_add(1); - if (qid < numQuerys) + threads.emplace_back([&, i] { + NumaStrategy ns = (index.GetIndexAlgoType() == IndexAlgoType::SPANN) + ? NumaStrategy::SCATTER + : NumaStrategy::LOCAL; // Only for SPANN, we need to avoid IO threads overlap + // with search threads. 
+ Helper::SetThreadAffinity(i, threads[i], ns, OrderStrategy::ASC); + + size_t qid = 0; + while (true) { - auto t1 = std::chrono::high_resolution_clock::now(); - index.SearchIndex(results[qid]); - auto t2 = std::chrono::high_resolution_clock::now(); - latencies[qid] = (float)(std::chrono::duration_cast(t2 - t1).count() / 1000000.0); - } - else - { - return; + qid = queriesSent.fetch_add(1); + if (qid < numQuerys) + { + auto t1 = std::chrono::high_resolution_clock::now(); + index.SearchIndex(results[qid]); + auto t2 = std::chrono::high_resolution_clock::now(); + latencies[qid] = + (float)(std::chrono::duration_cast(t2 - t1).count() / + 1000000.0); + } + else + { + return; + } } - } - }; - - auto batchstart = std::chrono::high_resolution_clock::now(); - - for (std::uint32_t i = 0; i < options->m_threadNum; i++) { threads.emplace_back(func); } - for (auto& thread : threads) { thread.join(); } + }); + } + for (auto &thread : threads) + { + thread.join(); + } auto batchend = std::chrono::high_resolution_clock::now(); - float batchLatency = (float)(std::chrono::duration_cast(batchend - batchstart).count() / 1000000.0); + float batchLatency = + (float)(std::chrono::duration_cast(batchend - batchstart).count() / + 1000000.0); float timeMean = 0, timeMin = MaxDist, timeMax = 0, timeStd = 0; for (int qid = 0; qid < numQuerys; qid++) { timeMean += latencies[qid]; - if (latencies[qid] > timeMax) timeMax = latencies[qid]; - if (latencies[qid] < timeMin) timeMin = latencies[qid]; + if (latencies[qid] > timeMax) + timeMax = latencies[qid]; + if (latencies[qid] < timeMin) + timeMin = latencies[qid]; } timeMean /= numQuerys; - for (int qid = 0; qid < numQuerys; qid++) timeStd += ((float)latencies[qid] - timeMean) * ((float)latencies[qid] - timeMean); + for (int qid = 0; qid < numQuerys; qid++) + timeStd += ((float)latencies[qid] - timeMean) * ((float)latencies[qid] - timeMean); timeStd = std::sqrt(timeStd / numQuerys); log << timeMean << " " << timeStd << " " << timeMin << " " << 
timeMax << " "; @@ -234,7 +276,9 @@ int Process(std::shared_ptr options, VectorIndex& index) float recall = 0; if (ftruth != nullptr) { - recall = COMMON::TruthSet::CalculateRecall(&index, results, truth, options->m_K, options->m_truthK, queryVectors, dataVectors, numQuerys, &log, options->m_debugQuery > 0); + recall = COMMON::TruthSet::CalculateRecall(&index, results, truth, options->m_K, options->m_truthK, + queryVectors, dataVectors, numQuerys, &log, + options->m_debugQuery > 0); } #ifndef _MSC_VER @@ -246,7 +290,9 @@ int Process(std::shared_ptr options, VectorIndex& index) GetProcessMemoryInfo(GetCurrentProcess(), &pmc, sizeof(pmc)); unsigned long long peakWSS = pmc.PeakWorkingSetSize / 1000000000; #endif - LOG(Helper::LogLevel::LL_Info, "%d-%d\t%s\t%.4f\t%.4f\t%.4f\t%.4f\t\t%.4f\t\t%lluGB\n", startQuery, (startQuery + numQuerys), maxCheck[mc].c_str(), timeMean, l99, l95, recall, (numQuerys / batchLatency), peakWSS); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d-%d\t%s\t%.4f\t%.4f\t%.4f\t%.4f\t\t%.4f\t\t%lluGB\n", startQuery, + (startQuery + numQuerys), maxCheck[mc].c_str(), timeMean, l99, l95, recall, + (numQuerys / batchLatency), peakWSS); totalAvg[mc] += timeMean * numQuerys; total95[mc] += l95 * numQuerys; total99[mc] += l99 * numQuerys; @@ -256,20 +302,25 @@ int Process(std::shared_ptr options, VectorIndex& index) if (fp != nullptr) { - if (options->m_outputformat == 0) { + if (options->m_outputformat == 0) + { for (SizeType i = 0; i < numQuerys; i++) { - if (queryMetas != nullptr) { + if (queryMetas != nullptr) + { ByteArray qmeta = queryMetas->GetMetadata(startQuery + i); - if (fp->WriteBinary(qmeta.Length(), (const char*)qmeta.Data()) != qmeta.Length()) { - LOG(Helper::LogLevel::LL_Error, "Cannot write qmeta %d bytes!\n", qmeta.Length()); + if (fp->WriteBinary(qmeta.Length(), (const char *)qmeta.Data()) != qmeta.Length()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot write qmeta %d bytes!\n", qmeta.Length()); exit(1); } } - else { + else + { 
std::string qid = std::to_string(i); - if (fp->WriteBinary(qid.length(), qid.c_str()) != qid.length()) { - LOG(Helper::LogLevel::LL_Error, "Cannot write qid %d bytes!\n", qid.length()); + if (fp->WriteBinary(qid.length(), qid.c_str()) != qid.length()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot write qid %d bytes!\n", qid.length()); exit(1); } } @@ -277,27 +328,33 @@ int Process(std::shared_ptr options, VectorIndex& index) for (int j = 0; j < options->m_K; j++) { std::string sd = std::to_string(results[i].GetResult(j)->Dist / baseSquare); - if (fp->WriteBinary(sd.length(), sd.c_str()) != sd.length()) { - LOG(Helper::LogLevel::LL_Error, "Cannot write dist %d bytes!\n", sd.length()); + if (fp->WriteBinary(sd.length(), sd.c_str()) != sd.length()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot write dist %d bytes!\n", sd.length()); exit(1); } fp->WriteString("@"); - if (results[i].GetResult(j)->VID < 0) { + if (results[i].GetResult(j)->VID < 0) + { fp->WriteString("NULL|"); continue; } - if (!options->m_withMeta) { + if (!options->m_withMeta) + { std::string vid = std::to_string(results[i].GetResult(j)->VID); - if (fp->WriteBinary(vid.length(), vid.c_str()) != vid.length()) { - LOG(Helper::LogLevel::LL_Error, "Cannot write vid %d bytes!\n", sd.length()); + if (fp->WriteBinary(vid.length(), vid.c_str()) != vid.length()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot write vid %d bytes!\n", sd.length()); exit(1); } } - else { + else + { ByteArray vm = index.GetMetadata(results[i].GetResult(j)->VID); - if (fp->WriteBinary(vm.Length(), (const char*)vm.Data()) != vm.Length()) { - LOG(Helper::LogLevel::LL_Error, "Cannot write vmeta %d bytes!\n", vm.Length()); + if (fp->WriteBinary(vm.Length(), (const char *)vm.Data()) != vm.Length()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot write vmeta %d bytes!\n", vm.Length()); exit(1); } } @@ -306,20 +363,23 @@ int Process(std::shared_ptr options, VectorIndex& index) fp->WriteString("\n"); } } - 
else { + else + { for (SizeType i = 0; i < numQuerys; ++i) { for (int j = 0; j < options->m_K; ++j) { SizeType i32Val = results[i].GetResult(j)->VID; - if (fp->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + if (fp->WriteBinary(sizeof(i32Val), reinterpret_cast(&i32Val)) != sizeof(i32Val)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); exit(1); } float fVal = results[i].GetResult(j)->Dist; - if (fp->WriteBinary(sizeof(fVal), reinterpret_cast(&fVal)) != sizeof(fVal)) { - LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); + if (fp->WriteBinary(sizeof(fVal), reinterpret_cast(&fVal)) != sizeof(fVal)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write result file!\n"); exit(1); } } @@ -328,16 +388,20 @@ int Process(std::shared_ptr options, VectorIndex& index) } } for (int mc = 0; mc < maxCheck.size(); mc++) - LOG(Helper::LogLevel::LL_Info, "%d-%d\t%s\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\n", 0, queryVectors->Count(), maxCheck[mc].c_str(), (totalAvg[mc] / queryVectors->Count()), (total99[mc] / queryVectors->Count()), (total95[mc] / queryVectors->Count()), (totalRecall[mc] / queryVectors->Count()), (queryVectors->Count() / totalLatency[mc])); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d-%d\t%s\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\n", 0, queryVectors->Count(), + maxCheck[mc].c_str(), (totalAvg[mc] / queryVectors->Count()), + (total99[mc] / queryVectors->Count()), (total95[mc] / queryVectors->Count()), + (totalRecall[mc] / queryVectors->Count()), (queryVectors->Count() / totalLatency[mc])); - LOG(Helper::LogLevel::LL_Info, "Output results finish!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Output results finish!\n"); - if (fp != nullptr) fp->ShutDown(); + if (fp != nullptr) + fp->ShutDown(); log.close(); return 0; } -int main(int argc, char** argv) +int main(int argc, char **argv) { std::shared_ptr options(new SearcherOptions); if 
(!options->Parse(argc - 1, argv + 1)) @@ -349,7 +413,7 @@ int main(int argc, char** argv) auto ret = SPTAG::VectorIndex::LoadIndex(options->m_indexFolder, vecIndex); if (SPTAG::ErrorCode::Success != ret || nullptr == vecIndex) { - LOG(Helper::LogLevel::LL_Error, "Cannot open index configure file!"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot open index configure file!"); return -1; } vecIndex->SetQuantizerADC(options->m_enableADC); @@ -359,26 +423,31 @@ int main(int argc, char** argv) { std::string param(argv[i]); size_t idx = param.find("="); - if (idx == std::string::npos) continue; + if (idx == std::string::npos) + continue; std::string paramName = param.substr(0, idx); std::string paramVal = param.substr(idx + 1); std::string sectionName; idx = paramName.find("."); - if (idx != std::string::npos) { + if (idx != std::string::npos) + { sectionName = paramName.substr(0, idx); paramName = paramName.substr(idx + 1); } iniReader.SetParameter(sectionName, paramName, paramVal); - LOG(Helper::LogLevel::LL_Info, "Set [%s]%s = %s\n", sectionName.c_str(), paramName.c_str(), paramVal.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set [%s]%s = %s\n", sectionName.c_str(), paramName.c_str(), + paramVal.c_str()); } - std::string sections[] = { "Base", "SelectHead", "BuildHead", "BuildSSDIndex", "Index" }; - for (int i = 0; i < 5; i++) { - if (!iniReader.DoesParameterExist(sections[i], "NumberOfThreads")) { + std::string sections[] = {"Base", "SelectHead", "BuildHead", "BuildSSDIndex", "Index"}; + for (int i = 0; i < 5; i++) + { + if (!iniReader.DoesParameterExist(sections[i], "NumberOfThreads")) + { iniReader.SetParameter(sections[i], "NumberOfThreads", std::to_string(options->m_threadNum)); } - for (const auto& iter : iniReader.GetParameters(sections[i])) + for (const auto &iter : iniReader.GetParameters(sections[i])) { vecIndex->SetParameter(iter.first.c_str(), iter.second.c_str(), sections[i]); } @@ -388,15 +457,16 @@ int main(int argc, char** argv) switch 
(options->m_inputValueType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - Process(options, *(vecIndex.get())); \ - break; \ +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + Process(options, *(vecIndex.get())); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - default: break; + default: + break; } return 0; } diff --git a/AnnService/src/KeyValueTest/main.cpp b/AnnService/src/KeyValueTest/main.cpp new file mode 100644 index 000000000..470dcc9bc --- /dev/null +++ b/AnnService/src/KeyValueTest/main.cpp @@ -0,0 +1,684 @@ +#include +#include + +#include "inc/Core/Common.h" +#include "inc/Core/SPANN/ExtraFileController.h" +#include "inc/Core/SPANN/IExtraSearcher.h" +#define CONFLICT_TEST + +using namespace SPTAG; + +int main(int argc, char *argv[]) +{ + int max_blocks = 5; + int kv_num = 1000; + int dataset_size = 10000; + int iter_num = 1000; + int batch_num = 10; + int thread_num = 4; + int timeout_times = 0; + double read_rate = 0.5; + std::mutex error_mtx; + std::vector dataset(dataset_size); + std::vector values(kv_num); + std::chrono::high_resolution_clock::time_point start, end; + std::chrono::microseconds timeout; + + for (int i = 0; i < dataset_size; i++) + { + int size = rand() % (max_blocks << 12); + dataset[i] = ""; + for (int j = 0; j < size; j++) + { + dataset[i] += (char)(rand() % 256); + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Data generated\n"); + SPTAG::SPANN::Options opt; + opt.m_indexDirectory = "/tmp"; + opt.m_ssdMappingFile = "pbfile"; + opt.m_postingPageLimit = max_blocks * 2; + opt.m_spdkBatchSize = 64; + SPANN::FileIO fileIO(opt); + + bool single_thread_test = false; + bool multi_thread_test = false; + bool multi_get_test = false; + bool mixed_read_write_test = false; + bool timeout_test = true; + bool conflict_test = false; + + SPTAG::SPANN::ExtraWorkSpace workspace; + workspace.Initialize(4096, 2, 64, 4 * PageSize, true, false); + + // 单线程存取 + 
if (!single_thread_test) + { + goto MultiThreadTest; + } + start = std::chrono::high_resolution_clock::now(); + for (int iter = 0; iter < iter_num; iter++) + { + // 随机生成一些数据 + for (int i = 0; i < kv_num; i++) + { + values[i] = rand() % dataset_size; + } + // 写入数据 + for (int key = 0; key < kv_num; key++) + { + fileIO.Put(key, dataset[values[key]], MaxTimeout, &(workspace.m_diskRequests)); + } + // 读取数据 + for (int key = 0; key < kv_num; key++) + { + std::string readValue; + fileIO.Get(key, &readValue, MaxTimeout, &(workspace.m_diskRequests)); + if (dataset[values[key]] != readValue) + { + std::cout << "Error: key " << key << " value not match" << std::endl; + std::cout << "True value: "; + for (auto c : dataset[values[key]]) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + std::cout << "Read value: "; + for (auto c : readValue) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + return 0; + } + } + // if (error) { + // std::cout << "Error in single thread iter " << iter << std::endl; + // return 0; + // } + } + end = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Single thread test passed\n"); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Single thread time: %d ms\n", + // std::chrono::duration_cast(end - start).count()); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Single thread IOPS: %fk\n", (double)iter_num * kv_num * 2 / + // std::chrono::duration_cast(end - start).count()); + fileIO.GetStat(); + + // 多线程存取 +MultiThreadTest: + if (!multi_thread_test) + { + goto MultiGetTest; + } + iter_num = 10; + kv_num = 100000; + thread_num = 4; + values.resize(kv_num); + start = std::chrono::high_resolution_clock::now(); + for (int iter = 0; iter < iter_num; iter++) + { + std::vector threads; + std::vector thread_keys[thread_num]; + std::vector errors(thread_num); + // 随机生成一些数据 + for (int i = 0; i < kv_num; i++) + { + values[i] = rand() % dataset_size; + } + for (int i = 0; i < kv_num; i++) + { + int 
thread_id = rand() % thread_num; + thread_keys[thread_id].push_back(i); + } + for (int i = 0; i < thread_num; i++) + { + threads.emplace_back([&fileIO, &values, &thread_keys, &dataset, i]() { + SPTAG::SPANN::ExtraWorkSpace workspace; + workspace.Initialize(4096, 2, 64, 4 * PageSize, true, false); + + for (auto key : thread_keys[i]) + { + fileIO.Put(key, dataset[values[key]], MaxTimeout, &(workspace.m_diskRequests)); + } + }); + } + for (auto &thread : threads) + { + thread.join(); + } + threads.clear(); + for (int i = 0; i < thread_num; i++) + { + thread_keys[i].clear(); + } + for (int i = 0; i < kv_num; i++) + { + int thread_id = rand() % thread_num; + thread_keys[thread_id].push_back(i); + } + for (int i = 0; i < thread_num; i++) + { + threads.emplace_back([&fileIO, &values, &thread_keys, &dataset, &errors, &error_mtx, i]() { + SPTAG::SPANN::ExtraWorkSpace workspace; + workspace.Initialize(4096, 2, 64, 4 * PageSize, true, false); + + for (auto key : thread_keys[i]) + { + std::string readValue; + fileIO.Get(key, &readValue, MaxTimeout, &(workspace.m_diskRequests)); + if (dataset[values[key]] != readValue) + { + // errors[i].push_back({key, readValue}); + errors[i] = true; + std::lock_guard lock(error_mtx); + std::cout << "Error: key " << key << " value not match" << std::endl; + std::cout << "True value: "; + for (auto c : dataset[values[key]]) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + std::cout << "Read value: "; + for (auto c : readValue) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + return; + } + } + }); + } + for (auto &thread : threads) + { + thread.join(); + } + threads.clear(); + for (int i = 0; i < thread_num; i++) + { + if (errors[i]) + { + std::cout << "Error in multi thread iter " << iter << std::endl; + return 0; + } + } + } + end = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Multi thread test passed\n"); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Multi thread 
time: %d ms\n", + // std::chrono::duration_cast(end - start).count()); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Multi thread IOPS: %fk\n", (double)iter_num * kv_num * 2 / + // std::chrono::duration_cast(end - start).count()); + fileIO.GetStat(); + + // MultiGet测试 +MultiGetTest: + if (!multi_get_test) + { + goto MixReadWriteTest; + } + iter_num = 100; + kv_num = 100000; + thread_num = 4; + values.resize(kv_num); + start = std::chrono::high_resolution_clock::now(); + for (int iter = 0; iter < iter_num; iter++) + { + std::vector threads; + std::vector thread_keys[thread_num]; + std::vector errors(thread_num); + // 随机生成一些数据 + for (int i = 0; i < kv_num; i++) + { + values[i] = rand() % dataset_size; + } + for (int i = 0; i < kv_num; i++) + { + int thread_id = rand() % thread_num; + thread_keys[thread_id].push_back(i); + } + for (int i = 0; i < thread_num; i++) + { + threads.emplace_back([&fileIO, &values, &thread_keys, &errors, &error_mtx, &dataset, i]() { + SPTAG::SPANN::ExtraWorkSpace workspace; + workspace.Initialize(4096, 2, 64, 4 * PageSize, true, false); + + for (auto key : thread_keys[i]) + { + fileIO.Put(key, dataset[values[key]], MaxTimeout, &(workspace.m_diskRequests)); + } + std::vector readValues; + std::vector keys; + int num = thread_keys[i].size(); + for (int k = 0; k < num; k += 256) + { + readValues.clear(); + keys.clear(); + for (int j = k; j < std::min(k + 256, num); j++) + { + keys.push_back(thread_keys[i][j]); + } + fileIO.MultiGet(keys, &readValues, MaxTimeout, &(workspace.m_diskRequests)); + for (int j = 0; j < keys.size(); j++) + { + if (dataset[values[keys[j]]] != readValues[j]) + { + errors[i] = true; + std::lock_guard lock(error_mtx); + std::cout << "Error: key " << keys[j] << " value not match" << std::endl; + std::cout << "True value: "; + for (auto c : dataset[values[keys[j]]]) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + std::cout << "Read value: "; + for (auto c : readValues[j]) + { + std::cout << (int)c << " "; 
+ } + std::cout << std::endl; + return; + } + } + } + }); + } + for (auto &thread : threads) + { + thread.join(); + } + threads.clear(); + for (int i = 0; i < thread_num; i++) + { + if (errors[i]) + { + std::cout << "Error in MultiGet test iter " << iter << std::endl; + return 0; + } + } + } + end = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "MultiGet test passed\n"); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "MultiGet time: %d ms\n", + // std::chrono::duration_cast(end - start).count()); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "MultiGet IOPS: %fk\n", (double)iter_num * kv_num * 2 / + // std::chrono::duration_cast(end - start).count()); + fileIO.GetStat(); + + // 多线程混合读写存取 +MixReadWriteTest: + if (!mixed_read_write_test) + { + goto TimeoutTest; + } + batch_num = 10; + iter_num = 100000; + read_rate = 0.5; + values.resize(kv_num); + start = std::chrono::high_resolution_clock::now(); + for (int batch = 0; batch < batch_num; batch++) + { + std::vector threads; + std::vector thread_keys[thread_num]; + std::vector errors(thread_num, false); + for (int i = 0; i < kv_num; i++) + { + int thread_id = rand() % thread_num; + thread_keys[thread_id].push_back(i); + } + for (int i = 0; i < thread_num; i++) + { + threads.emplace_back([&fileIO, &values, &thread_keys, &dataset, &errors, &error_mtx, i, max_blocks, + read_rate, iter_num, dataset_size]() { + SPTAG::SPANN::ExtraWorkSpace workspace; + workspace.Initialize(4096, 2, 64, 4 * PageSize, true, false); + + int num = thread_keys[i].size(); + int read_count = 0, write_count = 0; + bool is_read = false; + bool is_merge = false; + for (auto key : thread_keys[i]) + { + values[key] = rand() % dataset_size; + fileIO.Put(key, dataset[values[key]], MaxTimeout, &(workspace.m_diskRequests)); + } + for (int j = 0; j < iter_num; j++) + { + int key = thread_keys[i][rand() % num]; + if (rand() < RAND_MAX * read_rate) + { + is_read = true; + } + else + { + is_read = false; + if (rand() < 
RAND_MAX * 0.5) + { + is_merge = true; + } + else + { + is_merge = false; + } + } + if (is_read) + { + std::string readValue; + fileIO.Get(key, &readValue, MaxTimeout, &(workspace.m_diskRequests)); + if (dataset[values[key]] != readValue) + { + // errors[i].push_back({key, readValue}); + errors[i] = true; + std::lock_guard lock(error_mtx); + std::cout << "Error: key " << key << " value not match" << std::endl; + std::cout << "True value: "; + for (auto c : dataset[values[key]]) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + std::cout << "Read value: "; + for (auto c : readValue) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + return; + } + read_count++; + } + else if (is_merge) + { + std::string mergeValue = ""; + int size = rand() % (max_blocks << 12); + for (int k = 0; k < size; k++) + { + mergeValue += (char)(rand() % 256); + } + fileIO.Merge(key, mergeValue, MaxTimeout, &(workspace.m_diskRequests)); + write_count++; + read_count++; + std::string readValue; + fileIO.Get(key, &readValue, MaxTimeout, &(workspace.m_diskRequests)); + if (readValue != dataset[values[key]] + mergeValue) + { + errors[i] = true; + std::lock_guard lock(error_mtx); + std::cout << "Error: key " << key << " value not match" << std::endl; + std::cout << "True value: "; + for (auto c : mergeValue) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + std::cout << "Read value: "; + for (auto c : readValue) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + return; + } + fileIO.Put(key, dataset[values[key]], MaxTimeout, &(workspace.m_diskRequests)); + write_count++; + read_count++; + } + else + { + values[key] = rand() % dataset_size; + fileIO.Put(key, dataset[values[key]], MaxTimeout, &(workspace.m_diskRequests)); + write_count++; + } + } + }); + } + for (auto &thread : threads) + { + thread.join(); + } + threads.clear(); + for (int i = 0; i < thread_num; i++) + { + if (errors[i]) + { + std::cout << "Error in batch " << batch << 
std::endl; + return 0; + } + } + } + end = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Mix read write test passed\n"); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Mix read write time: %d ms\n", + // std::chrono::duration_cast(end - start).count()); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Mix read write IOPS: %fk\n", (double)batch_num * iter_num / + // std::chrono::duration_cast(end - start).count()); + fileIO.GetStat(); + + // 超时测试 +TimeoutTest: + if (!timeout_test) + { + goto ConflictTest; + } + batch_num = 10; + timeout = std::chrono::microseconds(0); + iter_num = 50; + kv_num = 10; + values.resize(kv_num); + for (int iter = 0; iter < iter_num; iter++) + { + for (int i = 0; i < kv_num; i++) + { + values[i] = rand() % dataset_size; + fileIO.Put(i, dataset[values[i]], MaxTimeout, &(workspace.m_diskRequests)); + } + for (int key = 0; key < kv_num; key++) + { + std::string readValue; + auto result = fileIO.Get(key, &readValue, timeout, &(workspace.m_diskRequests)); + if (result == ErrorCode::Fail) + { + timeout_times++; + } + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Timeout times: %d\n", timeout_times); + iter_num = 10; + kv_num = 1000; + values.resize(kv_num); + for (int iter = 0; iter < iter_num; iter++) + { + for (int i = 0; i < kv_num; i++) + { + values[i] = rand() % dataset_size; + fileIO.Put(i, dataset[values[i]], MaxTimeout, &(workspace.m_diskRequests)); + } + for (int key = 0; key < kv_num; key++) + { + std::string readValue; + auto result = fileIO.Get(key, &readValue, MaxTimeout, &(workspace.m_diskRequests)); + if (result == ErrorCode::Fail) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error while get key %d\n", key); + return 0; + } + if (dataset[values[key]] != readValue) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error: key %d value not match\n", key); + return 0; + } + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Timeout test passed\n"); + fileIO.GetStat(); + +#ifdef CONFLICT_TEST + // 
冲突测试 +ConflictTest: + if (!conflict_test) + { + goto End; + } + batch_num = 10; + iter_num = 1000; + kv_num = 2048; + thread_num = 4; + values.resize(kv_num); + // start = std::chrono::high_resolution_clock::now(); + for (int iter = 0; iter < batch_num; iter++) + { + std::vector threads; + std::vector errors(thread_num); + // 随机生成一些数据 + for (int i = 0; i < kv_num; i++) + { + values[i] = rand() % dataset_size; + } + std::string init_str(PageSize, -1); + for (int i = 0; i < kv_num; i++) + { + fileIO.Put(i, init_str, MaxTimeout, &(workspace.m_diskRequests)); + } + for (int i = 0; i < thread_num; i++) + { + threads.emplace_back([&fileIO, &errors, &error_mtx, &dataset, iter_num, kv_num, i]() { + SPTAG::SPANN::ExtraWorkSpace workspace; + workspace.Initialize(4096, 2, 64, 4 * PageSize, true, false); + + std::vector readValues; + std::vector keys; + std::vector values(kv_num); + for (int i = 0; i < kv_num; i++) + { + values[i] = rand() % dataset.size(); + } + for (int k = 0; k < iter_num; k++) + { + int key = rand() % kv_num; + bool is_read = false; + bool batch_read = false; + if (rand() < RAND_MAX * 0.66) + { + is_read = true; + if (rand() < RAND_MAX * 0.5) + { + batch_read = true; + } + } + if (is_read && !batch_read) + { + std::string readValue; + fileIO.Get(key, &readValue, MaxTimeout, &(workspace.m_diskRequests)); + if ((int)readValue[0] != i) + { + ; + } + else + { + readValue[0] = dataset[values[key]][0]; + if (readValue != dataset[values[key]]) + { + errors[i] = true; + std::lock_guard lock(error_mtx); + std::cout << "Error: key " << key << " value not match" << std::endl; + std::cout << "True value: "; + std::cout << i << " "; + for (int j = 1; j < PageSize; j++) + { + std::cout << (int)dataset[values[key]][j] << " "; + } + std::cout << std::endl; + std::cout << "Read value: "; + readValue[0] = (char)i; + for (auto c : readValue) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + return; + } + } + } + else if (is_read && batch_read) + { + 
std::vector readValues; + std::vector keys; + int num = 256; + std::set key_set; + while (key_set.size() < num) + { + key_set.insert(rand() % kv_num); + } + keys.assign(key_set.begin(), key_set.end()); + fileIO.MultiGet(keys, &readValues, MaxTimeout, &(workspace.m_diskRequests)); + for (int j = 0; j < keys.size(); j++) + { + if ((int)readValues[j][0] != i) + { + ; + } + else + { + readValues[j][0] = dataset[values[keys[j]]][0]; + if (readValues[j] != dataset[values[keys[j]]]) + { + errors[i] = true; + std::lock_guard lock(error_mtx); + std::cout << "Error: key " << keys[j] << " value not match" << std::endl; + std::cout << "True value: "; + std::cout << i << " "; + for (int k = 1; k < PageSize; k++) + { + std::cout << (int)dataset[values[keys[j]]][k] << " "; + } + std::cout << std::endl; + std::cout << "Read value: "; + readValues[j][0] = (char)i; + for (auto c : readValues[j]) + { + std::cout << (int)c << " "; + } + std::cout << std::endl; + return; + } + } + } + } + else + { + values[key] = rand() % dataset.size(); + std::string tmp_str(dataset[values[key]]); + tmp_str[0] = (char)i; + fileIO.Put(key, tmp_str, MaxTimeout, &(workspace.m_diskRequests)); + } + } + }); + } + for (auto &thread : threads) + { + thread.join(); + } + threads.clear(); + for (int i = 0; i < thread_num; i++) + { + if (errors[i]) + { + std::cout << "Error in Conflict test iter " << iter << std::endl; + return 0; + } + } + } + // end = std::chrono::high_resolution_clock::now(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Conflict test passed\n"); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Conflict test time: %d ms\n", + // std::chrono::duration_cast(end - start).count()); + // SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Conflict test IOPS: %fk\n", (double)iter_num * kv_num * 2 / + // std::chrono::duration_cast(end - start).count()); + fileIO.GetStat(); +#endif +End: + std::cout << "Test passed" << std::endl; + return 0; +} diff --git a/AnnService/src/Quantizer/main.cpp 
b/AnnService/src/Quantizer/main.cpp index 16ffa8476..a23f8dca3 100644 --- a/AnnService/src/Quantizer/main.cpp +++ b/AnnService/src/Quantizer/main.cpp @@ -1,74 +1,122 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "inc/Helper/VectorSetReader.h" -#include "inc/Core/VectorIndex.h" #include "inc/Core/Common.h" +#include "inc/Core/VectorIndex.h" #include "inc/Helper/SimpleIniReader.h" -#include +#include "inc/Helper/VectorSetReader.h" #include "inc/Quantizer/Training.h" +#include #include using namespace SPTAG; -void QuantizeAndSave(std::shared_ptr& vectorReader, std::shared_ptr& options, std::shared_ptr& quantizer) +void QuantizeAndSave(std::shared_ptr &vectorReader, + std::shared_ptr &options, std::shared_ptr &quantizer) { std::shared_ptr set; - for (int i = 0; (set = vectorReader->GetVectorSet(i, i + options->m_trainingSamples))->Count() > 0; i += options->m_trainingSamples) + for (int i = 0; (set = vectorReader->GetVectorSet(i, i + options->m_trainingSamples))->Count() > 0; + i += options->m_trainingSamples) { - if (i % (options->m_trainingSamples *10) == 0 || i % options->m_trainingSamples != 0) + if (i % (options->m_trainingSamples * 10) == 0 || i % options->m_trainingSamples != 0) { - LOG(Helper::LogLevel::LL_Info, "Saving vector batch starting at %d\n", i); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving vector batch starting at %d\n", i); } std::shared_ptr quantized_vectors; if (options->m_normalized) { - LOG(Helper::LogLevel::LL_Info, "Normalizing vectors.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Normalizing vectors.\n"); set->Normalize(options->m_threadNum); } ByteArray PQ_vector_array = ByteArray::Alloc(sizeof(std::uint8_t) * options->m_quantizedDim * set->Count()); - quantized_vectors = std::make_shared(PQ_vector_array, VectorValueType::UInt8, options->m_quantizedDim, set->Count()); - -#pragma omp parallel for - for (int i = 0; i < set->Count(); i++) + quantized_vectors = 
std::make_shared(PQ_vector_array, VectorValueType::UInt8, + options->m_quantizedDim, set->Count()); { - quantizer->QuantizeVector(set->GetVector(i), (uint8_t*)quantized_vectors->GetVector(i)); + std::vector mythreads; + mythreads.reserve(options->m_threadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < options->m_threadNum; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < set->Count()) + { + quantizer->QuantizeVector(set->GetVector(i), (uint8_t *)quantized_vectors->GetVector(i)); + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); } - ErrorCode code; if ((code = quantized_vectors->AppendSave(options->m_outputFile)) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to save quantized vectors, ErrorCode: %s.\n", SPTAG::Helper::Convert::ConvertToString(code).c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to save quantized vectors, ErrorCode: %s.\n", + SPTAG::Helper::Convert::ConvertToString(code).c_str()); exit(1); } if (!options->m_outputFullVecFile.empty()) { if (ErrorCode::Success != set->AppendSave(options->m_outputFullVecFile)) { - LOG(Helper::LogLevel::LL_Error, "Failed to save uncompressed vectors.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to save uncompressed vectors.\n"); exit(1); } } if (!options->m_outputReconstructVecFile.empty()) { -#pragma omp parallel for - for (int i = 0; i < set->Count(); i++) + std::vector mythreads; + mythreads.reserve(options->m_threadNum); + std::atomic_size_t sent(0); + for (int tid = 0; tid < options->m_threadNum; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < set->Count()) + { + quantizer->ReconstructVector((uint8_t *)quantized_vectors->GetVector(i), set->GetVector(i)); + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) { - 
quantizer->ReconstructVector((uint8_t*)quantized_vectors->GetVector(i), set->GetVector(i)); + t.join(); } + mythreads.clear(); + if (ErrorCode::Success != set->AppendSave(options->m_outputReconstructVecFile)) { - LOG(Helper::LogLevel::LL_Error, "Failed to save uncompressed vectors.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to save uncompressed vectors.\n"); exit(1); } } } } -int main(int argc, char* argv[]) +int main(int argc, char *argv[]) { - std::shared_ptr options = std::make_shared(10000, true, 0.0, SPTAG::QuantizerType::None, std::string(), -1, std::string(), std::string()); + std::shared_ptr options = std::make_shared( + 10000, true, 0.0f, SPTAG::QuantizerType::None, std::string(), -1, std::string(), std::string()); if (!options->Parse(argc - 1, argv + 1)) { @@ -77,19 +125,19 @@ int main(int argc, char* argv[]) auto vectorReader = Helper::VectorSetReader::CreateInstance(options); if (ErrorCode::Success != vectorReader->LoadFile(options->m_inputFiles)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read input file.\n"); exit(1); } switch (options->m_quantizerType) { - case QuantizerType::None: - { + case QuantizerType::None: { std::shared_ptr set; - for (int i = 0; (set = vectorReader->GetVectorSet(i, i + options->m_trainingSamples))->Count() > 0; i += options-> m_trainingSamples) + for (int i = 0; (set = vectorReader->GetVectorSet(i, i + options->m_trainingSamples))->Count() > 0; + i += options->m_trainingSamples) { set->AppendSave(options->m_outputFile); } - + if (!options->m_outputMetadataFile.empty() && !options->m_outputMetadataIndexFile.empty()) { auto metadataSet = vectorReader->GetMetadataSet(); @@ -101,34 +149,43 @@ int main(int argc, char* argv[]) break; } - case QuantizerType::PQQuantizer: - { + case QuantizerType::PQQuantizer: { std::shared_ptr quantizer; auto fp_load = SPTAG::f_createIO(); - if (fp_load == nullptr || 
!fp_load->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::in)) + if (fp_load == nullptr || + !fp_load->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::in)) { auto set = vectorReader->GetVectorSet(0, options->m_trainingSamples); ByteArray PQ_vector_array = ByteArray::Alloc(sizeof(std::uint8_t) * options->m_quantizedDim * set->Count()); - std::shared_ptr quantized_vectors = std::make_shared(PQ_vector_array, VectorValueType::UInt8, options->m_quantizedDim, set->Count()); - LOG(Helper::LogLevel::LL_Info, "Quantizer Does not exist. Training a new one.\n"); + std::shared_ptr quantized_vectors = std::make_shared( + PQ_vector_array, VectorValueType::UInt8, options->m_quantizedDim, set->Count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Quantizer Does not exist. Training a new one.\n"); switch (options->m_inputValueType) { -#define DefineVectorValueType(Name, Type) \ - case VectorValueType::Name: \ - quantizer.reset(new COMMON::PQQuantizer(options->m_quantizedDim, 256, (DimensionType)(options->m_dimension/options->m_quantizedDim), false, TrainPQQuantizer(options, set, quantized_vectors))); \ - break; +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + quantizer.reset(new COMMON::PQQuantizer( \ + options->m_quantizedDim, 256, (DimensionType)(options->m_dimension / options->m_quantizedDim), false, \ + TrainPQQuantizer(options, set, quantized_vectors))); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType + + default: { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Undefined InputValue Type!\n"); + exit(1); + } } auto ptr = SPTAG::f_createIO(); - if (ptr != nullptr && ptr->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::out)) + if (ptr != nullptr && + ptr->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::out)) { if (ErrorCode::Success != quantizer->SaveQuantizer(ptr)) { - 
LOG(Helper::LogLevel::LL_Error, "Failed to write quantizer file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to write quantizer file.\n"); exit(1); } } @@ -138,7 +195,7 @@ int main(int argc, char* argv[]) quantizer = SPTAG::COMMON::IQuantizer::LoadIQuantizer(fp_load); if (!quantizer) { - LOG(Helper::LogLevel::LL_Error, "Failed to open existing quantizer file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to open existing quantizer file.\n"); exit(1); } quantizer->SetEnableADC(false); @@ -151,16 +208,16 @@ int main(int argc, char* argv[]) { metadataSet->SaveMetadata(options->m_outputMetadataFile, options->m_outputMetadataIndexFile); } - + break; } - case QuantizerType::OPQQuantizer: - { + case QuantizerType::OPQQuantizer: { std::shared_ptr quantizer; auto fp_load = SPTAG::f_createIO(); - if (fp_load == nullptr || !fp_load->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::in)) + if (fp_load == nullptr || + !fp_load->Initialize(options->m_outputQuantizerFile.c_str(), std::ios::binary | std::ios::in)) { - LOG(Helper::LogLevel::LL_Info, "Quantizer Does not exist. Not supported for OPQ.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Quantizer Does not exist. 
Not supported for OPQ.\n"); exit(1); } else @@ -168,7 +225,7 @@ int main(int argc, char* argv[]) quantizer = SPTAG::COMMON::IQuantizer::LoadIQuantizer(fp_load); if (!quantizer) { - LOG(Helper::LogLevel::LL_Error, "Failed to open existing quantizer file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to open existing quantizer file.\n"); exit(1); } quantizer->SetEnableADC(false); @@ -176,7 +233,6 @@ int main(int argc, char* argv[]) QuantizeAndSave(vectorReader, options, quantizer); - auto metadataSet = vectorReader->GetMetadataSet(); if (metadataSet) { @@ -185,10 +241,9 @@ int main(int argc, char* argv[]) break; } - default: - { - LOG(Helper::LogLevel::LL_Error, "Failed to read quantizer type.\n"); + default: { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read quantizer type.\n"); exit(1); } } -} \ No newline at end of file +} diff --git a/AnnService/src/SPFresh/main.cpp b/AnnService/src/SPFresh/main.cpp new file mode 100644 index 000000000..0b3e5bb55 --- /dev/null +++ b/AnnService/src/SPFresh/main.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "inc/Core/Common.h" +#include "inc/Core/Common/TruthSet.h" +#include "inc/Core/SPANN/Index.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/VectorSetReader.h" + +#include "inc/SPFresh/SPFresh.h" + +using namespace SPTAG; + +// switch between exe and static library by _$(OutputType) +#ifdef _exe + +int main(int argc, char *argv[]) +{ + if (argc < 2) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "spfresh storePath\n"); + exit(-1); + } + + auto ret = SSDServing::SPFresh::UpdateTest(argv[1]); + return ret; +} + +#endif diff --git a/AnnService/src/SSDServing/main.cpp b/AnnService/src/SSDServing/main.cpp index d61270696..867e02fea 100644 --- a/AnnService/src/SSDServing/main.cpp +++ b/AnnService/src/SSDServing/main.cpp @@ -4,195 +4,209 @@ #include #include "inc/Core/Common.h" -#include "inc/Core/VectorIndex.h" +#include "inc/Core/Common/TruthSet.h" #include "inc/Core/SPANN/Index.h" +#include "inc/Core/VectorIndex.h" #include "inc/Helper/SimpleIniReader.h" -#include "inc/Helper/VectorSetReader.h" #include "inc/Helper/StringConvert.h" -#include "inc/Core/Common/TruthSet.h" +#include "inc/Helper/VectorSetReader.h" -#include "inc/SSDServing/main.h" -#include "inc/SSDServing/Utils.h" #include "inc/SSDServing/SSDIndex.h" +#include "inc/SSDServing/Utils.h" +#include "inc/SSDServing/main.h" using namespace SPTAG; -namespace SPTAG { - namespace SSDServing { - - int BootProgram(bool forANNIndexTestTool, - std::map>* config_map, - const char* configurationPath, - VectorValueType valueType, - DistCalcMethod distCalcMethod, - const char* dataFilePath, - const char* indexFilePath) { - - - bool searchSSD = false; - std::string QuantizerFilePath = ""; - if (forANNIndexTestTool) { - (*config_map)[SEC_BASE]["ValueType"] = Helper::Convert::ConvertToString(valueType); - (*config_map)[SEC_BASE]["DistCalcMethod"] = Helper::Convert::ConvertToString(distCalcMethod); - 
(*config_map)[SEC_BASE]["VectorPath"] = dataFilePath; - (*config_map)[SEC_BASE]["IndexDirectory"] = indexFilePath; - - (*config_map)[SEC_BUILD_HEAD]["KDTNumber"] = "2"; - (*config_map)[SEC_BUILD_HEAD]["NeighborhoodSize"] = "32"; - (*config_map)[SEC_BUILD_HEAD]["TPTNumber"] = "32"; - (*config_map)[SEC_BUILD_HEAD]["TPTLeafSize"] = "2000"; - (*config_map)[SEC_BUILD_HEAD]["MaxCheck"] = "4096"; - (*config_map)[SEC_BUILD_HEAD]["MaxCheckForRefineGraph"] = "4096"; - (*config_map)[SEC_BUILD_HEAD]["RefineIterations"] = "3"; - (*config_map)[SEC_BUILD_HEAD]["GraphNeighborhoodScale"] = "1"; - (*config_map)[SEC_BUILD_HEAD]["GraphCEFScale"] = "1"; - - (*config_map)[SEC_BASE]["DeleteHeadVectors"] = "true"; - (*config_map)[SEC_SELECT_HEAD]["isExecute"] = "true"; - (*config_map)[SEC_BUILD_HEAD]["isExecute"] = "true"; - (*config_map)[SEC_BUILD_SSD_INDEX]["isExecute"] = "true"; - (*config_map)[SEC_BUILD_SSD_INDEX]["BuildSsdIndex"] = "true"; - - std::map::iterator iter; - if ((iter = (*config_map)[SEC_BASE].find("QuantizerFilePath")) != (*config_map)[SEC_BASE].end()) { - QuantizerFilePath = iter->second; - } - } - else { - Helper::IniReader iniReader; - iniReader.LoadIniFile(configurationPath); - (*config_map)[SEC_BASE] = iniReader.GetParameters(SEC_BASE); - (*config_map)[SEC_SELECT_HEAD] = iniReader.GetParameters(SEC_SELECT_HEAD); - (*config_map)[SEC_BUILD_HEAD] = iniReader.GetParameters(SEC_BUILD_HEAD); - (*config_map)[SEC_BUILD_SSD_INDEX] = iniReader.GetParameters(SEC_BUILD_SSD_INDEX); - - valueType = iniReader.GetParameter(SEC_BASE, "ValueType", valueType); - distCalcMethod = iniReader.GetParameter(SEC_BASE, "DistCalcMethod", distCalcMethod); - bool buildSSD = iniReader.GetParameter(SEC_BUILD_SSD_INDEX, "isExecute", false); - searchSSD = iniReader.GetParameter(SEC_SEARCH_SSD_INDEX, "isExecute", false); - QuantizerFilePath = iniReader.GetParameter(SEC_BASE, "QuantizerFilePath", std::string("")); - - for (auto& KV : iniReader.GetParameters(SEC_SEARCH_SSD_INDEX)) { - std::string param 
= KV.first, value = KV.second; - if (buildSSD && Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "BuildSsdIndex")) continue; - if (buildSSD && Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "isExecute")) continue; - if (Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "PostingPageLimit")) param = "SearchPostingPageLimit"; - if (Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "InternalResultNum")) param = "SearchInternalResultNum"; - (*config_map)[SEC_BUILD_SSD_INDEX][param] = value; - } - } - - - LOG(Helper::LogLevel::LL_Info, "Set QuantizerFile = %s\n", QuantizerFilePath.c_str()); - - std::shared_ptr index = VectorIndex::CreateInstance(IndexAlgoType::SPANN, valueType); - if (!QuantizerFilePath.empty() && index->LoadQuantizer(QuantizerFilePath) != ErrorCode::Success) - { - exit(1); - } - if (index == nullptr) { - LOG(Helper::LogLevel::LL_Error, "Cannot create Index with ValueType %s!\n", (*config_map)[SEC_BASE]["ValueType"].c_str()); - return -1; - } - - for (auto& sectionKV : *config_map) { - for (auto& KV : sectionKV.second) { - index->SetParameter(KV.first, KV.second, sectionKV.first); - } - } - - if (index->BuildIndex() != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to build index.\n"); - exit(1); - } - - SPANN::Options* opts = nullptr; - - -#define DefineVectorValueType(Name, Type) \ - if (index->GetVectorValueType() == VectorValueType::Name) { \ - opts = ((SPANN::Index*)index.get())->GetOptions(); \ - } \ +namespace SPTAG +{ +namespace SSDServing +{ + +int BootProgram(bool forANNIndexTestTool, std::map> *config_map, + const char *configurationPath, VectorValueType valueType, DistCalcMethod distCalcMethod, + const char *dataFilePath, const char *indexFilePath) +{ + + bool searchSSD = false; + std::string QuantizerFilePath = ""; + if (forANNIndexTestTool) + { + (*config_map)[SEC_BASE]["ValueType"] = Helper::Convert::ConvertToString(valueType); + (*config_map)[SEC_BASE]["DistCalcMethod"] = 
Helper::Convert::ConvertToString(distCalcMethod); + (*config_map)[SEC_BASE]["VectorPath"] = dataFilePath; + (*config_map)[SEC_BASE]["IndexDirectory"] = indexFilePath; + + (*config_map)[SEC_BUILD_HEAD]["KDTNumber"] = "2"; + (*config_map)[SEC_BUILD_HEAD]["NeighborhoodSize"] = "32"; + (*config_map)[SEC_BUILD_HEAD]["TPTNumber"] = "32"; + (*config_map)[SEC_BUILD_HEAD]["TPTLeafSize"] = "2000"; + (*config_map)[SEC_BUILD_HEAD]["MaxCheck"] = "4096"; + (*config_map)[SEC_BUILD_HEAD]["MaxCheckForRefineGraph"] = "4096"; + (*config_map)[SEC_BUILD_HEAD]["RefineIterations"] = "3"; + (*config_map)[SEC_BUILD_HEAD]["GraphNeighborhoodScale"] = "1"; + (*config_map)[SEC_BUILD_HEAD]["GraphCEFScale"] = "1"; + + (*config_map)[SEC_BASE]["DeleteHeadVectors"] = "true"; + (*config_map)[SEC_SELECT_HEAD]["isExecute"] = "true"; + (*config_map)[SEC_BUILD_HEAD]["isExecute"] = "true"; + (*config_map)[SEC_BUILD_SSD_INDEX]["isExecute"] = "true"; + (*config_map)[SEC_BUILD_SSD_INDEX]["BuildSsdIndex"] = "true"; + + std::map::iterator iter; + if ((iter = (*config_map)[SEC_BASE].find("QuantizerFilePath")) != (*config_map)[SEC_BASE].end()) + { + QuantizerFilePath = iter->second; + } + } + else + { + Helper::IniReader iniReader; + iniReader.LoadIniFile(configurationPath); + (*config_map)[SEC_BASE] = iniReader.GetParameters(SEC_BASE); + (*config_map)[SEC_SELECT_HEAD] = iniReader.GetParameters(SEC_SELECT_HEAD); + (*config_map)[SEC_BUILD_HEAD] = iniReader.GetParameters(SEC_BUILD_HEAD); + (*config_map)[SEC_BUILD_SSD_INDEX] = iniReader.GetParameters(SEC_BUILD_SSD_INDEX); + + valueType = iniReader.GetParameter(SEC_BASE, "ValueType", valueType); + distCalcMethod = iniReader.GetParameter(SEC_BASE, "DistCalcMethod", distCalcMethod); + bool buildSSD = iniReader.GetParameter(SEC_BUILD_SSD_INDEX, "isExecute", false); + searchSSD = iniReader.GetParameter(SEC_SEARCH_SSD_INDEX, "isExecute", false); + QuantizerFilePath = iniReader.GetParameter(SEC_BASE, "QuantizerFilePath", std::string("")); + + for (auto &KV : 
iniReader.GetParameters(SEC_SEARCH_SSD_INDEX)) + { + std::string param = KV.first, value = KV.second; + if (buildSSD && Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "BuildSsdIndex")) + continue; + if (buildSSD && Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "isExecute")) + continue; + if (Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "PostingPageLimit")) + param = "SearchPostingPageLimit"; + if (Helper::StrUtils::StrEqualIgnoreCase(param.c_str(), "InternalResultNum")) + param = "SearchInternalResultNum"; + (*config_map)[SEC_BUILD_SSD_INDEX][param] = value; + } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Set QuantizerFile = %s\n", QuantizerFilePath.c_str()); + + std::shared_ptr index = VectorIndex::CreateInstance(IndexAlgoType::SPANN, valueType); + if (!QuantizerFilePath.empty() && index->LoadQuantizer(QuantizerFilePath) != ErrorCode::Success) + { + return -1; + } + if (index == nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot create Index with ValueType %s!\n", + (*config_map)[SEC_BASE]["ValueType"].c_str()); + return -1; + } + + for (auto §ionKV : *config_map) + { + for (auto &KV : sectionKV.second) + { + index->SetParameter(KV.first, KV.second, sectionKV.first); + } + } + + if (index->BuildIndex() != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to build index.\n"); + return -1; + } + + SPANN::Options *opts = nullptr; + +#define DefineVectorValueType(Name, Type) \ + if (index->GetVectorValueType() == VectorValueType::Name) \ + { \ + opts = ((SPANN::Index *)index.get())->GetOptions(); \ + } #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - if (opts == nullptr) { - LOG(Helper::LogLevel::LL_Error, "Cannot get options.\n"); - exit(1); - } - -printf("in BootProgram! genreateTruth:%d\n", opts->m_generateTruth); - - if (opts->m_generateTruth) - { - LOG(Helper::LogLevel::LL_Info, "Start generating truth. 
It's maybe a long time.\n"); - SizeType dim = opts->m_dim; - if (index->m_pQuantizer) - { - valueType = VectorValueType::UInt8; - dim = index->m_pQuantizer->GetNumSubvectors(); - } - std::shared_ptr vectorOptions(new Helper::ReaderOptions(valueType, dim, opts->m_vectorType, opts->m_vectorDelimiter)); - auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); - if (ErrorCode::Success != vectorReader->LoadFile(opts->m_vectorPath)) - { - LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); - exit(1); - } - std::shared_ptr queryOptions(new Helper::ReaderOptions(opts->m_valueType, opts->m_dim, opts->m_queryType, opts->m_queryDelimiter)); - auto queryReader = Helper::VectorSetReader::CreateInstance(queryOptions); - if (ErrorCode::Success != queryReader->LoadFile(opts->m_queryPath)) - { - LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); - exit(1); - } - auto vectorSet = vectorReader->GetVectorSet(); - auto querySet = queryReader->GetVectorSet(); - if (distCalcMethod == DistCalcMethod::Cosine && !index->m_pQuantizer) vectorSet->Normalize(opts->m_iSSDNumberOfThreads); - - omp_set_num_threads(opts->m_iSSDNumberOfThreads); - -#define DefineVectorValueType(Name, Type) \ - if (opts->m_valueType == VectorValueType::Name) { \ - COMMON::TruthSet::GenerateTruth(querySet, vectorSet, opts->m_truthPath, \ - distCalcMethod, opts->m_resultNum, opts->m_truthType, index->m_pQuantizer); \ - } \ + if (opts == nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Cannot get options.\n"); + return -1; + } + + if (opts->m_generateTruth) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start generating truth. 
It's maybe a long time.\n"); + SizeType dim = opts->m_dim; + if (index->m_pQuantizer) + { + valueType = VectorValueType::UInt8; + dim = index->m_pQuantizer->GetNumSubvectors(); + } + std::shared_ptr vectorOptions( + new Helper::ReaderOptions(valueType, dim, opts->m_vectorType, opts->m_vectorDelimiter)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); + if (ErrorCode::Success != vectorReader->LoadFile(opts->m_vectorPath)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + return -1; + } + auto vectorSet = vectorReader->GetVectorSet(); + + std::shared_ptr queryOptions( + new Helper::ReaderOptions(opts->m_valueType, opts->m_dim, opts->m_queryType, opts->m_queryDelimiter)); + auto queryReader = Helper::VectorSetReader::CreateInstance(queryOptions); + if (ErrorCode::Success != queryReader->LoadFile(opts->m_queryPath)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + return -1; + } + auto querySet = queryReader->GetVectorSet(); + if (distCalcMethod == DistCalcMethod::Cosine && !index->m_pQuantizer) + vectorSet->Normalize(opts->m_iSSDNumberOfThreads); + +#define DefineVectorValueType(Name, Type) \ + if (opts->m_valueType == VectorValueType::Name) \ + { \ + COMMON::TruthSet::GenerateTruth(querySet, vectorSet, opts->m_truthPath, distCalcMethod, \ + opts->m_resultNum, opts->m_truthType, index->m_pQuantizer); \ + } #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - LOG(Helper::LogLevel::LL_Info, "End generating truth.\n"); - } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "End generating truth.\n"); + } - if (searchSSD) { -#define DefineVectorValueType(Name, Type) \ - if (opts->m_valueType == VectorValueType::Name) { \ - SSDIndex::Search((SPANN::Index*)(index.get())); \ - } \ + if (searchSSD) + { +#define DefineVectorValueType(Name, Type) \ + if (opts->m_valueType == VectorValueType::Name) \ + { \ + if (ErrorCode::Success != SSDIndex::Search((SPANN::Index 
*)(index.get()))) return -1; \ + } #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - } - return 0; - } - } + } + return 0; } +} // namespace SSDServing +} // namespace SPTAG // switch between exe and static library by _$(OutputType) #ifdef _exe -int main(int argc, char* argv[]) { - if (argc < 2) - { - LOG(Helper::LogLevel::LL_Error, - "ssdserving configFilePath\n"); - exit(-1); - } - - std::map> my_map; - auto ret = SSDServing::BootProgram(false, &my_map, argv[1]); - return ret; +int main(int argc, char *argv[]) +{ + if (argc < 2) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "ssdserving configFilePath\n"); + return -1; + } + + std::map> my_map; + auto ret = SSDServing::BootProgram(false, &my_map, argv[1]); + return ret; } #endif diff --git a/AnnService/src/Server/QueryParser.cpp b/AnnService/src/Server/QueryParser.cpp index 42ff360eb..f7819036b 100644 --- a/AnnService/src/Server/QueryParser.cpp +++ b/AnnService/src/Server/QueryParser.cpp @@ -9,24 +9,17 @@ using namespace SPTAG; using namespace SPTAG::Service; +const char *QueryParser::c_defaultVectorSeparator = "|"; -const char* QueryParser::c_defaultVectorSeparator = "|"; - - -QueryParser::QueryParser() - : m_vectorBase64(nullptr), - m_vectorBase64Length(0) +QueryParser::QueryParser() : m_vectorBase64(nullptr), m_vectorBase64Length(0) { } - QueryParser::~QueryParser() { } - -ErrorCode -QueryParser::Parse(const std::string& p_query, const char* p_vectorSeparator) +ErrorCode QueryParser::Parse(const std::string &p_query, const char *p_vectorSeparator) { if (p_vectorSeparator == nullptr) { @@ -52,14 +45,14 @@ QueryParser::Parse(const std::string& p_query, const char* p_vectorSeparator) State currState = State::None; - char* optionName = nullptr; - char* vectorStrBegin = nullptr; - char* vectorStrEnd = nullptr; + char *optionName = nullptr; + char *vectorStrBegin = nullptr; + char *vectorStrEnd = nullptr; SizeType estDimension = 0; - char* iter = nullptr; + char *iter = nullptr; - for (iter = 
reinterpret_cast(m_dataHolder.Data()); *iter != '\0'; ++iter) + for (iter = reinterpret_cast(m_dataHolder.Data()); *iter != '\0'; ++iter) { if (std::isspace(*iter)) { @@ -180,30 +173,22 @@ QueryParser::Parse(const std::string& p_query, const char* p_vectorSeparator) return ErrorCode::Success; } - -const std::vector& -QueryParser::GetVectorElements() const +const std::vector &QueryParser::GetVectorElements() const { return m_vectorElements; } - -const std::vector& -QueryParser::GetOptions() const +const std::vector &QueryParser::GetOptions() const { return m_options; } - -const char* -QueryParser::GetVectorBase64() const +const char *QueryParser::GetVectorBase64() const { return m_vectorBase64; } - -SizeType -QueryParser::GetVectorBase64Length() const +SizeType QueryParser::GetVectorBase64Length() const { return m_vectorBase64Length; } diff --git a/AnnService/src/Server/SearchExecutionContext.cpp b/AnnService/src/Server/SearchExecutionContext.cpp index 45d83ec43..3f31cc987 100644 --- a/AnnService/src/Server/SearchExecutionContext.cpp +++ b/AnnService/src/Server/SearchExecutionContext.cpp @@ -2,44 +2,36 @@ // Licensed under the MIT License. 
#include "inc/Server/SearchExecutionContext.h" -#include "inc/Helper/CommonHelper.h" #include "inc/Helper/Base64Encode.h" +#include "inc/Helper/CommonHelper.h" using namespace SPTAG; using namespace SPTAG::Service; -SearchExecutionContext::SearchExecutionContext(const std::shared_ptr& p_serviceSettings) - : c_serviceSettings(p_serviceSettings), - m_vectorDimension(0), - m_inputValueType(VectorValueType::Undefined), - m_extractMetadata(false), - m_resultNum(p_serviceSettings->m_defaultMaxResultNumber) +SearchExecutionContext::SearchExecutionContext(const std::shared_ptr &p_serviceSettings) + : c_serviceSettings(p_serviceSettings), m_vectorDimension(0), m_inputValueType(VectorValueType::Undefined), + m_extractMetadata(false), m_resultNum(p_serviceSettings->m_defaultMaxResultNumber) { } - SearchExecutionContext::~SearchExecutionContext() { m_results.clear(); } - -ErrorCode -SearchExecutionContext::ParseQuery(const std::string& p_query) +ErrorCode SearchExecutionContext::ParseQuery(const std::string &p_query) { return m_queryParser.Parse(p_query, c_serviceSettings->m_vectorSeparator.c_str()); } - -ErrorCode -SearchExecutionContext::ExtractOption() +ErrorCode SearchExecutionContext::ExtractOption() { - for (const auto& optionPair : m_queryParser.GetOptions()) + for (const auto &optionPair : m_queryParser.GetOptions()) { if (Helper::StrUtils::StrEqualIgnoreCase(optionPair.first, "indexname")) { - const char* begin = optionPair.second; - const char* end = optionPair.second; + const char *begin = optionPair.second; + const char *end = optionPair.second; while (*end != '\0') { while (*end != '\0' && *end != ',') @@ -76,18 +68,16 @@ SearchExecutionContext::ExtractOption() return ErrorCode::Success; } - -ErrorCode -SearchExecutionContext::ExtractVector(VectorValueType p_targetType) +ErrorCode SearchExecutionContext::ExtractVector(VectorValueType p_targetType) { if (!m_queryParser.GetVectorElements().empty()) { switch (p_targetType) { -#define DefineVectorValueType(Name, Type) 
\ - case VectorValueType::Name: \ - return ConvertVectorFromString(m_queryParser.GetVectorElements(), m_vector, m_vectorDimension); \ - break; \ +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return ConvertVectorFromString(m_queryParser.GetVectorElements(), m_vector, m_vectorDimension); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType @@ -96,8 +86,7 @@ SearchExecutionContext::ExtractVector(VectorValueType p_targetType) break; } } - else if (m_queryParser.GetVectorBase64() != nullptr - && m_queryParser.GetVectorBase64Length() != 0) + else if (m_queryParser.GetVectorBase64() != nullptr && m_queryParser.GetVectorBase64Length() != 0) { SizeType estLen = m_queryParser.GetVectorBase64Length(); auto temp = ByteArray::Alloc(Helper::Base64::CapacityForDecode(estLen)); @@ -121,67 +110,49 @@ SearchExecutionContext::ExtractVector(VectorValueType p_targetType) return ErrorCode::Fail; } - -const std::vector& -SearchExecutionContext::GetSelectedIndexNames() const +const std::vector &SearchExecutionContext::GetSelectedIndexNames() const { return m_indexNames; } - -void -SearchExecutionContext::AddResults(std::string p_indexName, QueryResult& p_results) +void SearchExecutionContext::AddResults(std::string p_indexName, QueryResult &p_results) { m_results.emplace_back(); m_results.back().m_indexName.swap(p_indexName); m_results.back().m_results = p_results; } - -std::vector& -SearchExecutionContext::GetResults() +std::vector &SearchExecutionContext::GetResults() { return m_results; } - -const std::vector& -SearchExecutionContext::GetResults() const +const std::vector &SearchExecutionContext::GetResults() const { return m_results; } - -const ByteArray& -SearchExecutionContext::GetVector() const +const ByteArray &SearchExecutionContext::GetVector() const { return m_vector; } - -const SizeType -SearchExecutionContext::GetVectorDimension() const +const SizeType SearchExecutionContext::GetVectorDimension() const { return 
m_vectorDimension; } - -const std::vector& -SearchExecutionContext::GetOptions() const +const std::vector &SearchExecutionContext::GetOptions() const { return m_queryParser.GetOptions(); } - -const SizeType -SearchExecutionContext::GetResultNum() const +const SizeType SearchExecutionContext::GetResultNum() const { return m_resultNum; } - -const bool -SearchExecutionContext::GetExtractMetadata() const +const bool SearchExecutionContext::GetExtractMetadata() const { return m_extractMetadata; } diff --git a/AnnService/src/Server/SearchExecutor.cpp b/AnnService/src/Server/SearchExecutor.cpp index 37f4461aa..cccdf7fed 100644 --- a/AnnService/src/Server/SearchExecutor.cpp +++ b/AnnService/src/Server/SearchExecutor.cpp @@ -6,24 +6,17 @@ using namespace SPTAG; using namespace SPTAG::Service; - -SearchExecutor::SearchExecutor(std::string p_queryString, - std::shared_ptr p_serviceContext, - const CallBack& p_callback) - : m_callback(p_callback), - c_serviceContext(std::move(p_serviceContext)), - m_queryString(std::move(p_queryString)) +SearchExecutor::SearchExecutor(std::string p_queryString, std::shared_ptr p_serviceContext, + const CallBack &p_callback) + : m_callback(p_callback), c_serviceContext(std::move(p_serviceContext)), m_queryString(std::move(p_queryString)) { } - SearchExecutor::~SearchExecutor() { } - -void -SearchExecutor::Execute() +void SearchExecutor::Execute() { ExecuteInternal(); if (bool(m_callback)) @@ -32,14 +25,13 @@ SearchExecutor::Execute() } } - -void -SearchExecutor::ExecuteInternal() +void SearchExecutor::ExecuteInternal() { m_executionContext.reset(new SearchExecutionContext(c_serviceContext->GetServiceSettings())); - if (m_executionContext->ParseQuery(m_queryString) != ErrorCode::Success) { - LOG(Helper::LogLevel::LL_Error, "Failed to parse query:%s!\n", m_queryString.c_str()); + if (m_executionContext->ParseQuery(m_queryString) != ErrorCode::Success) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to parse query:%s!\n", 
m_queryString.c_str()); return; } @@ -49,32 +41,31 @@ SearchExecutor::ExecuteInternal() if (m_selectedIndex.empty()) { - LOG(Helper::LogLevel::LL_Error, "Empty selected index!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Empty selected index!\n"); return; } - const auto& firstIndex = m_selectedIndex.front(); + const auto &firstIndex = m_selectedIndex.front(); if (ErrorCode::Success != m_executionContext->ExtractVector(firstIndex->GetVectorValueType())) { - LOG(Helper::LogLevel::LL_Error, "Failed to extract vector!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to extract vector!\n"); return; } if (m_executionContext->GetVectorDimension() != firstIndex->GetFeatureDim()) { - LOG(Helper::LogLevel::LL_Error, "Failed to match vector dimension!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to match vector dimension!\n"); return; - } + } - QueryResult query(m_executionContext->GetVector().Data(), - m_executionContext->GetResultNum(), + QueryResult query(m_executionContext->GetVector().Data(), m_executionContext->GetResultNum(), m_executionContext->GetExtractMetadata()); - for (const auto& vectorIndex : m_selectedIndex) + for (const auto &vectorIndex : m_selectedIndex) { - if (vectorIndex->GetVectorValueType() != firstIndex->GetVectorValueType() - || vectorIndex->GetFeatureDim() != firstIndex->GetFeatureDim()) + if (vectorIndex->GetVectorValueType() != firstIndex->GetVectorValueType() || + vectorIndex->GetFeatureDim() != firstIndex->GetFeatureDim()) { continue; } @@ -84,18 +75,17 @@ SearchExecutor::ExecuteInternal() { m_executionContext->AddResults(vectorIndex->GetIndexName(), query); } - else { - LOG(Helper::LogLevel::LL_Error, "Failed to execute SearchIndex!\n"); + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to execute SearchIndex!\n"); } } } - -void -SearchExecutor::SelectIndex() +void SearchExecutor::SelectIndex() { - const auto& indexNames = m_executionContext->GetSelectedIndexNames(); - const auto& indexMap = 
c_serviceContext->GetIndexMap(); + const auto &indexNames = m_executionContext->GetSelectedIndexNames(); + const auto &indexMap = c_serviceContext->GetIndexMap(); if (indexMap.empty()) { return; @@ -110,7 +100,7 @@ SearchExecutor::SelectIndex() } else { - for (const auto& indexName : indexNames) + for (const auto &indexName : indexNames) { auto iter = indexMap.find(indexName); if (iter != indexMap.cend()) @@ -120,4 +110,3 @@ SearchExecutor::SelectIndex() } } } - diff --git a/AnnService/src/Server/SearchService.cpp b/AnnService/src/Server/SearchService.cpp index 93b760f8b..bb0bc8e1d 100644 --- a/AnnService/src/Server/SearchService.cpp +++ b/AnnService/src/Server/SearchService.cpp @@ -2,17 +2,16 @@ // Licensed under the MIT License. #include "inc/Server/SearchService.h" +#include "inc/Helper/ArgumentsParser.h" +#include "inc/Helper/CommonHelper.h" #include "inc/Server/SearchExecutor.h" #include "inc/Socket/RemoteSearchQuery.h" -#include "inc/Helper/CommonHelper.h" -#include "inc/Helper/ArgumentsParser.h" #include using namespace SPTAG; using namespace SPTAG::Service; - namespace { namespace Local @@ -20,11 +19,8 @@ namespace Local class SerivceCmdOptions : public Helper::ArgumentsParser { -public: - SerivceCmdOptions() - : m_serveMode("interactive"), - m_configFile("AnnService.ini"), - m_logFile("") + public: + SerivceCmdOptions() : m_serveMode("interactive"), m_configFile("AnnService.ini"), m_logFile("") { AddOptionalOption(m_serveMode, "-m", "--mode", "Service mode, interactive or socket."); AddOptionalOption(m_configFile, "-c", "--config", "Service config file path."); @@ -42,26 +38,20 @@ class SerivceCmdOptions : public Helper::ArgumentsParser std::string m_logFile; }; -} +} // namespace Local } // namespace - SearchService::SearchService() - : m_initialized(false), - m_shutdownSignals(m_ioContext), - m_serveMode(ServeMode::Interactive) + : m_initialized(false), m_shutdownSignals(m_ioContext), m_serveMode(ServeMode::Interactive) { } - 
SearchService::~SearchService() { } - -bool -SearchService::Initialize(int p_argNum, char* p_args[]) +bool SearchService::Initialize(int p_argNum, char *p_args[]) { Local::SerivceCmdOptions cmdOptions; if (!cmdOptions.Parse(p_argNum - 1, p_args + 1)) @@ -79,12 +69,13 @@ SearchService::Initialize(int p_argNum, char* p_args[]) } else { - LOG(Helper::LogLevel::LL_Error, "Failed parse Serve Mode!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed parse Serve Mode!\n"); return false; } - if (!cmdOptions.m_logFile.empty()) { - g_pLogger.reset(new Helper::FileLogger(Helper::LogLevel::LL_Debug, cmdOptions.m_logFile.c_str())); + if (!cmdOptions.m_logFile.empty()) + { + SetLogger(std::make_shared(Helper::LogLevel::LL_Debug, cmdOptions.m_logFile.c_str())); } m_serviceContext.reset(new ServiceContext(cmdOptions.m_configFile)); @@ -94,9 +85,7 @@ SearchService::Initialize(int p_argNum, char* p_args[]) return m_initialized; } - -void -SearchService::Run() +void SearchService::Run() { if (!m_initialized) { @@ -118,29 +107,24 @@ SearchService::Run() } } - -void -SearchService::RunSocketMode() +void SearchService::RunSocketMode() { auto threadNum = max((SizeType)1, m_serviceContext->GetServiceSettings()->m_threadNum); m_threadPool.reset(new boost::asio::thread_pool(threadNum)); Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); - handlerMap->emplace(Socket::PacketType::SearchRequest, - [this](Socket::ConnectionID p_srcID, Socket::Packet p_packet) - { - boost::asio::post(*m_threadPool, std::bind(&SearchService::SearchHanlder, this, p_srcID, std::move(p_packet))); - }); + handlerMap->emplace(Socket::PacketType::SearchRequest, [this](Socket::ConnectionID p_srcID, + Socket::Packet p_packet) { + boost::asio::post(*m_threadPool, std::bind(&SearchService::SearchHanlder, this, p_srcID, std::move(p_packet))); + }); m_socketServer.reset(new Socket::Server(m_serviceContext->GetServiceSettings()->m_listenAddr, - m_serviceContext->GetServiceSettings()->m_listenPort, - 
handlerMap, + m_serviceContext->GetServiceSettings()->m_listenPort, handlerMap, m_serviceContext->GetServiceSettings()->m_socketThreadNum)); - LOG(Helper::LogLevel::LL_Info, - "Start to listen %s:%s ...\n", - m_serviceContext->GetServiceSettings()->m_listenAddr.c_str(), - m_serviceContext->GetServiceSettings()->m_listenPort.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to listen %s:%s ...\n", + m_serviceContext->GetServiceSettings()->m_listenAddr.c_str(), + m_serviceContext->GetServiceSettings()->m_listenPort.c_str()); m_shutdownSignals.add(SIGINT); m_shutdownSignals.add(SIGTERM); @@ -148,22 +132,19 @@ SearchService::RunSocketMode() m_shutdownSignals.add(SIGQUIT); #endif - m_shutdownSignals.async_wait([this](boost::system::error_code p_ec, int p_signal) - { - LOG(Helper::LogLevel::LL_Info, "Received shutdown signals.\n"); - }); + m_shutdownSignals.async_wait([this](boost::system::error_code p_ec, int p_signal) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Received shutdown signals.\n"); + }); m_ioContext.run(); - LOG(Helper::LogLevel::LL_Info, "Start shutdown procedure.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start shutdown procedure.\n"); m_socketServer.reset(); m_threadPool->stop(); m_threadPool->join(); } - -void -SearchService::RunInteractiveMode() +void SearchService::RunInteractiveMode() { const std::size_t bufferSize = 1 << 16; std::unique_ptr inputBuffer(new char[bufferSize]); @@ -175,8 +156,7 @@ SearchService::RunInteractiveMode() break; } - auto callback = [](std::shared_ptr p_exeContext) - { + auto callback = [](std::shared_ptr p_exeContext) { std::cout << "Result:" << std::endl; if (nullptr == p_exeContext) { @@ -184,20 +164,20 @@ SearchService::RunInteractiveMode() return; } - const auto& results = p_exeContext->GetResults(); - for (const auto& result : results) + const auto &results = p_exeContext->GetResults(); + for (const auto &result : results) { std::cout << "Index: " << result.m_indexName << std::endl; int idx = 0; - for 
(const auto& res : result.m_results) + for (const auto &res : result.m_results) { std::cout << "------------------" << std::endl; std::cout << "DocIndex: " << res.VID << " Distance: " << res.Dist; if (result.m_results.WithMeta()) { - const auto& metadata = result.m_results.GetMetadata(idx); - std::cout << " MetaData: " << std::string((char*)metadata.Data(), metadata.Length()); - } + const auto &metadata = result.m_results.GetMetadata(idx); + std::cout << " MetaData: " << std::string((char *)metadata.Data(), metadata.Length()); + } std::cout << std::endl; ++idx; } @@ -209,13 +189,11 @@ SearchService::RunInteractiveMode() } } - -void -SearchService::SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) +void SearchService::SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet) { if (p_packet.Header().m_bodyLength == 0) { - LOG(Helper::LogLevel::LL_Error, "Empty package with body length equals 0!\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Empty package with body length equals 0!\n"); return; } @@ -225,26 +203,20 @@ SearchService::SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::P } Socket::RemoteQuery remoteQuery; - if(remoteQuery.Read(p_packet.Body()) == nullptr) { - LOG(Helper::LogLevel::LL_Error, "majorVersion is not match!\n"); + if (remoteQuery.Read(p_packet.Body()) == nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "majorVersion is not match!\n"); return; } - auto callback = std::bind(&SearchService::SearchHanlderCallback, - this, - std::placeholders::_1, - std::move(p_packet)); + auto callback = std::bind(&SearchService::SearchHanlderCallback, this, std::placeholders::_1, std::move(p_packet)); - SearchExecutor executor(std::move(remoteQuery.m_queryString), - m_serviceContext, - callback); + SearchExecutor executor(std::move(remoteQuery.m_queryString), m_serviceContext, callback); executor.Execute(); } - -void -SearchService::SearchHanlderCallback(std::shared_ptr p_exeContext, - 
Socket::Packet p_srcPacket) +void SearchService::SearchHanlderCallback(std::shared_ptr p_exeContext, + Socket::Packet p_srcPacket) { Socket::Packet ret; ret.Header().m_packetType = Socket::PacketType::SearchResponse; @@ -265,7 +237,7 @@ SearchService::SearchHanlderCallback(std::shared_ptr p_e remoteResult.m_allIndexResults.swap(p_exeContext->GetResults()); ret.AllocateBuffer(static_cast(remoteResult.EstimateBufferSize())); auto bodyEnd = remoteResult.Write(ret.Body()); - + ret.Header().m_bodyLength = static_cast(bodyEnd - ret.Body()); ret.Header().WriteBuffer(ret.HeaderBuffer()); } diff --git a/AnnService/src/Server/ServiceContext.cpp b/AnnService/src/Server/ServiceContext.cpp index 97d6bdf76..7f20eb950 100644 --- a/AnnService/src/Server/ServiceContext.cpp +++ b/AnnService/src/Server/ServiceContext.cpp @@ -2,16 +2,14 @@ // Licensed under the MIT License. #include "inc/Server/ServiceContext.h" -#include "inc/Helper/SimpleIniReader.h" #include "inc/Helper/CommonHelper.h" +#include "inc/Helper/SimpleIniReader.h" #include "inc/Helper/StringConvert.h" using namespace SPTAG; using namespace SPTAG::Service; - -ServiceContext::ServiceContext(const std::string& p_configFilePath) - : m_initialized(false) +ServiceContext::ServiceContext(const std::string &p_configFilePath) : m_initialized(false) { Helper::IniReader iniReader; if (ErrorCode::Success != iniReader.LoadIniFile(p_configFilePath)) @@ -24,17 +22,19 @@ ServiceContext::ServiceContext(const std::string& p_configFilePath) m_settings->m_listenAddr = iniReader.GetParameter("Service", "ListenAddr", std::string("0.0.0.0")); m_settings->m_listenPort = iniReader.GetParameter("Service", "ListenPort", std::string("8000")); m_settings->m_threadNum = iniReader.GetParameter("Service", "ThreadNumber", static_cast(8)); - m_settings->m_socketThreadNum = iniReader.GetParameter("Service", "SocketThreadNumber", static_cast(8)); + m_settings->m_socketThreadNum = + iniReader.GetParameter("Service", "SocketThreadNumber", static_cast(8)); - 
m_settings->m_defaultMaxResultNumber = iniReader.GetParameter("QueryConfig", "DefaultMaxResultNumber", static_cast(10)); + m_settings->m_defaultMaxResultNumber = + iniReader.GetParameter("QueryConfig", "DefaultMaxResultNumber", static_cast(10)); m_settings->m_vectorSeparator = iniReader.GetParameter("QueryConfig", "DefaultSeparator", std::string("|")); const std::string emptyStr; std::string indexListStr = iniReader.GetParameter("Index", "List", emptyStr); - const auto& indexList = Helper::StrUtils::SplitString(indexListStr, ","); + const auto &indexList = Helper::StrUtils::SplitString(indexListStr, ","); - for (const auto& indexName : indexList) + for (const auto &indexName : indexList) { std::string sectionName("Index_"); sectionName += indexName.c_str(); @@ -53,36 +53,28 @@ ServiceContext::ServiceContext(const std::string& p_configFilePath) } else { - LOG(Helper::LogLevel::LL_Error, "Failed loading index: %s\n", indexName.c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed loading index: %s\n", indexName.c_str()); } } m_initialized = true; } - ServiceContext::~ServiceContext() { - } - -const std::map>& -ServiceContext::GetIndexMap() const +const std::map> &ServiceContext::GetIndexMap() const { return m_fullIndexList; } - -const std::shared_ptr& -ServiceContext::GetServiceSettings() const +const std::shared_ptr &ServiceContext::GetServiceSettings() const { return m_settings; } - -bool -ServiceContext::IsInitialized() const +bool ServiceContext::IsInitialized() const { return m_initialized; } diff --git a/AnnService/src/Server/ServiceSettings.cpp b/AnnService/src/Server/ServiceSettings.cpp index d51153195..738f96446 100644 --- a/AnnService/src/Server/ServiceSettings.cpp +++ b/AnnService/src/Server/ServiceSettings.cpp @@ -6,9 +6,6 @@ using namespace SPTAG; using namespace SPTAG::Service; - -ServiceSettings::ServiceSettings() - : m_defaultMaxResultNumber(10), - m_threadNum(12) +ServiceSettings::ServiceSettings() : m_defaultMaxResultNumber(10), 
m_threadNum(12) { } diff --git a/AnnService/src/Server/main.cpp b/AnnService/src/Server/main.cpp index 5aa5dc1e5..4d74d9312 100644 --- a/AnnService/src/Server/main.cpp +++ b/AnnService/src/Server/main.cpp @@ -5,7 +5,7 @@ SPTAG::Service::SearchService g_service; -int main(int argc, char* argv[]) +int main(int argc, char *argv[]) { if (!g_service.Initialize(argc, argv)) { @@ -16,4 +16,3 @@ int main(int argc, char* argv[]) return 0; } - diff --git a/AnnService/src/Socket/Client.cpp b/AnnService/src/Socket/Client.cpp index 9c4101e4f..397b60e70 100644 --- a/AnnService/src/Socket/Client.cpp +++ b/AnnService/src/Socket/Client.cpp @@ -7,15 +7,10 @@ using namespace SPTAG::Socket; - -Client::Client(const PacketHandlerMapPtr& p_handlerMap, - std::size_t p_threadNum, +Client::Client(const PacketHandlerMapPtr &p_handlerMap, std::size_t p_threadNum, std::uint32_t p_heartbeatIntervalSeconds) - : c_requestHandlerMap(p_handlerMap), - m_connectionManager(new ConnectionManager), - m_deadlineTimer(m_ioContext), - m_heartbeatIntervalSeconds(p_heartbeatIntervalSeconds), - m_stopped(false) + : c_requestHandlerMap(p_handlerMap), m_connectionManager(new ConnectionManager), m_deadlineTimer(m_ioContext), + m_heartbeatIntervalSeconds(p_heartbeatIntervalSeconds), m_stopped(false) { KeepIoContext(); m_threadPool.reserve(p_threadNum); @@ -25,7 +20,6 @@ Client::Client(const PacketHandlerMapPtr& p_handlerMap, } } - Client::~Client() { m_stopped = true; @@ -37,17 +31,13 @@ Client::~Client() m_ioContext.stop(); } - for (auto& t : m_threadPool) + for (auto &t : m_threadPool) { t.join(); } } - -ConnectionID -Client::ConnectToServer(const std::string& p_address, - const std::string& p_port, - SPTAG::ErrorCode& p_ec) +ConnectionID Client::ConnectToServer(const std::string &p_address, const std::string &p_port, SPTAG::ErrorCode &p_ec) { boost::asio::ip::tcp::resolver resolver(m_ioContext); @@ -60,7 +50,7 @@ Client::ConnectToServer(const std::string& p_address, } boost::asio::ip::tcp::socket 
socket(m_ioContext); - for (const auto ep : endPoints) + for (const auto &ep : endPoints) { errCode.clear(); socket.connect(ep, errCode); @@ -75,24 +65,16 @@ Client::ConnectToServer(const std::string& p_address, if (socket.is_open()) { p_ec = ErrorCode::Success; - return m_connectionManager->AddConnection(std::move(socket), - c_requestHandlerMap, - m_heartbeatIntervalSeconds); + return m_connectionManager->AddConnection(std::move(socket), c_requestHandlerMap, m_heartbeatIntervalSeconds); } p_ec = ErrorCode::Socket_FailedConnectToEndPoint; return c_invalidConnectionID; } - -void -Client::AsyncConnectToServer(const std::string& p_address, - const std::string& p_port, - ConnectCallback p_callback) +void Client::AsyncConnectToServer(const std::string &p_address, const std::string &p_port, ConnectCallback p_callback) { - boost::asio::post(m_ioContext, - [this, p_address, p_port, p_callback]() - { + boost::asio::post(m_ioContext, [this, p_address, p_port, p_callback]() { SPTAG::ErrorCode errCode; auto connID = ConnectToServer(p_address, p_port, errCode); if (bool(p_callback)) @@ -102,9 +84,7 @@ Client::AsyncConnectToServer(const std::string& p_address, }); } - -void -Client::SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback) +void Client::SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback) { auto connection = m_connectionManager->GetConnection(p_connection); if (nullptr != connection) @@ -117,16 +97,12 @@ Client::SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_event) +void Client::SetEventOnConnectionClose(std::function p_event) { m_connectionManager->SetEventOnRemoving(std::move(p_event)); } - -void -Client::KeepIoContext() +void Client::KeepIoContext() { if (m_stopped) { @@ -134,8 +110,5 @@ Client::KeepIoContext() } m_deadlineTimer.expires_from_now(boost::posix_time::hours(24)); - m_deadlineTimer.async_wait([this](boost::system::error_code p_ec) - { - this->KeepIoContext(); - }); + 
m_deadlineTimer.async_wait([this](boost::system::error_code p_ec) { this->KeepIoContext(); }); } diff --git a/AnnService/src/Socket/Common.cpp b/AnnService/src/Socket/Common.cpp index 2cfc1178e..2ac6d8c50 100644 --- a/AnnService/src/Socket/Common.cpp +++ b/AnnService/src/Socket/Common.cpp @@ -3,7 +3,6 @@ #include "inc/Socket/Common.h" - using namespace SPTAG::Socket; const ConnectionID SPTAG::Socket::c_invalidConnectionID = 0; diff --git a/AnnService/src/Socket/Connection.cpp b/AnnService/src/Socket/Connection.cpp index 9ea935106..150889d2f 100644 --- a/AnnService/src/Socket/Connection.cpp +++ b/AnnService/src/Socket/Connection.cpp @@ -2,43 +2,34 @@ // Licensed under the MIT License. #include "inc/Socket/Connection.h" -#include "inc/Socket/ConnectionManager.h" #include "inc/Core/Common.h" +#include "inc/Socket/ConnectionManager.h" -#include -#include #include #include +#include +#include #include #include using namespace SPTAG::Socket; -Connection::Connection(ConnectionID p_connectionID, - boost::asio::ip::tcp::socket&& p_socket, - const PacketHandlerMapPtr& p_handlerMap, - std::weak_ptr p_connectionManager) - : c_connectionID(p_connectionID), - c_handlerMap(p_handlerMap), - c_connectionManager(std::move(p_connectionManager)), - m_socket(std::move(p_socket)), - m_strand((boost::asio::io_context&)p_socket.get_executor().context()), - m_heartbeatTimer((boost::asio::io_context&)p_socket.get_executor().context()), - m_remoteConnectionID(c_invalidConnectionID), - m_stopped(true), - m_heartbeatStarted(false) +Connection::Connection(ConnectionID p_connectionID, boost::asio::ip::tcp::socket &&p_socket, + const PacketHandlerMapPtr &p_handlerMap, std::weak_ptr p_connectionManager) + : c_connectionID(p_connectionID), c_handlerMap(p_handlerMap), c_connectionManager(std::move(p_connectionManager)), + m_socket(std::move(p_socket)), m_strand((boost::asio::io_context &)p_socket.get_executor().context()), + m_heartbeatTimer((boost::asio::io_context 
&)p_socket.get_executor().context()), + m_remoteConnectionID(c_invalidConnectionID), m_stopped(true), m_heartbeatStarted(false) { } - -void -Connection::Start() +void Connection::Start() { - LOG(Helper::LogLevel::LL_Debug, "Connection Start, local: %u, remote: %s:%u\n", - static_cast(m_socket.local_endpoint().port()), - m_socket.remote_endpoint().address().to_string().c_str(), - static_cast(m_socket.remote_endpoint().port())); + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, "Connection Start, local: %u, remote: %s:%u\n", + static_cast(m_socket.local_endpoint().port()), + m_socket.remote_endpoint().address().to_string().c_str(), + static_cast(m_socket.remote_endpoint().port())); if (!m_stopped.exchange(false)) { @@ -49,14 +40,12 @@ Connection::Start() AsyncReadHeader(); } - -void -Connection::Stop() +void Connection::Stop() { - LOG(Helper::LogLevel::LL_Debug, "Connection Stop, local: %u, remote: %s:%u\n", - static_cast(m_socket.local_endpoint().port()), - m_socket.remote_endpoint().address().to_string().c_str(), - static_cast(m_socket.remote_endpoint().port())); + SPTAGLIB_LOG(Helper::LogLevel::LL_Debug, "Connection Stop, local: %u, remote: %s:%u\n", + static_cast(m_socket.local_endpoint().port()), + m_socket.remote_endpoint().address().to_string().c_str(), + static_cast(m_socket.remote_endpoint().port())); if (m_stopped.exchange(true)) { @@ -68,14 +57,12 @@ Connection::Stop() { m_heartbeatTimer.cancel(errCode); } - + m_socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, errCode); m_socket.close(errCode); } - -void -Connection::StartHeartbeat(std::size_t p_intervalSeconds) +void Connection::StartHeartbeat(std::size_t p_intervalSeconds) { if (m_stopped || m_heartbeatStarted.exchange(true)) { @@ -85,23 +72,17 @@ Connection::StartHeartbeat(std::size_t p_intervalSeconds) SendHeartbeat(p_intervalSeconds); } - -ConnectionID -Connection::GetConnectionID() const +ConnectionID Connection::GetConnectionID() const { return c_connectionID; } - -ConnectionID 
-Connection::GetRemoteConnectionID() const +ConnectionID Connection::GetRemoteConnectionID() const { return m_remoteConnectionID; } - -void -Connection::AsyncSend(Packet p_packet, std::function p_callback) +void Connection::AsyncSend(Packet p_packet, std::function p_callback) { if (m_stopped) { @@ -114,33 +95,26 @@ Connection::AsyncSend(Packet p_packet, std::function p_callback) } auto sharedThis = shared_from_this(); - boost::asio::post(m_strand, - [sharedThis, p_packet, p_callback]() - { - auto handler = [p_callback, p_packet, sharedThis](boost::system::error_code p_ec, - std::size_t p_bytesTransferred) - { - if (p_ec && boost::asio::error::operation_aborted != p_ec) - { - sharedThis->OnConnectionFail(p_ec); - } - - if (bool(p_callback)) - { - p_callback(!p_ec); - } - }; - - boost::asio::async_write(sharedThis->m_socket, - boost::asio::buffer(p_packet.Buffer(), - p_packet.BufferLength()), - std::move(handler)); - }); -} + boost::asio::post(m_strand, [sharedThis, p_packet, p_callback]() { + auto handler = [p_callback, p_packet, sharedThis](boost::system::error_code p_ec, + std::size_t p_bytesTransferred) { + if (p_ec && boost::asio::error::operation_aborted != p_ec) + { + sharedThis->OnConnectionFail(p_ec); + } + if (bool(p_callback)) + { + p_callback(!p_ec); + } + }; + + boost::asio::async_write(sharedThis->m_socket, boost::asio::buffer(p_packet.Buffer(), p_packet.BufferLength()), + std::move(handler)); + }); +} -void -Connection::AsyncReadHeader() +void Connection::AsyncReadHeader() { if (m_stopped) { @@ -148,23 +122,16 @@ Connection::AsyncReadHeader() } auto sharedThis = shared_from_this(); - boost::asio::post(m_strand, - [sharedThis]() - { - auto handler = boost::bind(&Connection::HandleReadHeader, - sharedThis, - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred); - - boost::asio::async_read(sharedThis->m_socket, - boost::asio::buffer(sharedThis->m_packetHeaderReadBuffer), - std::move(handler)); - }); -} + 
boost::asio::post(m_strand, [sharedThis]() { + auto handler = boost::bind(&Connection::HandleReadHeader, sharedThis, boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred); + boost::asio::async_read(sharedThis->m_socket, boost::asio::buffer(sharedThis->m_packetHeaderReadBuffer), + std::move(handler)); + }); +} -void -Connection::AsyncReadBody() +void Connection::AsyncReadBody() { if (m_stopped) { @@ -172,24 +139,18 @@ Connection::AsyncReadBody() } auto sharedThis = shared_from_this(); - boost::asio::post(m_strand, - [sharedThis]() - { - auto handler = boost::bind(&Connection::HandleReadBody, - sharedThis, - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred); - - boost::asio::async_read(sharedThis->m_socket, - boost::asio::buffer(sharedThis->m_packetRead.Body(), - sharedThis->m_packetRead.Header().m_bodyLength), - std::move(handler)); - }); + boost::asio::post(m_strand, [sharedThis]() { + auto handler = boost::bind(&Connection::HandleReadBody, sharedThis, boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred); + + boost::asio::async_read( + sharedThis->m_socket, + boost::asio::buffer(sharedThis->m_packetRead.Body(), sharedThis->m_packetRead.Header().m_bodyLength), + std::move(handler)); + }); } - -void -Connection::HandleReadHeader(boost::system::error_code p_ec, std::size_t p_bytesTransferred) +void Connection::HandleReadHeader(boost::system::error_code p_ec, std::size_t p_bytesTransferred) { if (!p_ec) { @@ -215,9 +176,7 @@ Connection::HandleReadHeader(boost::system::error_code p_ec, std::size_t p_bytes AsyncReadHeader(); } - -void -Connection::HandleReadBody(boost::system::error_code p_ec, std::size_t p_bytesTransferred) +void Connection::HandleReadBody(boost::system::error_code p_ec, std::size_t p_bytesTransferred) { if (!p_ec) { @@ -268,9 +227,7 @@ Connection::HandleReadBody(boost::system::error_code p_ec, std::size_t p_bytesTr AsyncReadHeader(); } - -void 
-Connection::SendHeartbeat(std::size_t p_intervalSeconds) +void Connection::SendHeartbeat(std::size_t p_intervalSeconds) { if (m_stopped) { @@ -281,21 +238,17 @@ Connection::SendHeartbeat(std::size_t p_intervalSeconds) msg.Header().m_packetType = PacketType::HeartbeatRequest; msg.Header().m_processStatus = PacketProcessStatus::Ok; msg.Header().m_connectionID = 0; - + msg.AllocateBuffer(0); msg.Header().WriteBuffer(msg.HeaderBuffer()); AsyncSend(std::move(msg), nullptr); m_heartbeatTimer.expires_from_now(boost::posix_time::seconds(p_intervalSeconds)); - m_heartbeatTimer.async_wait(boost::bind(&Connection::SendHeartbeat, - shared_from_this(), - p_intervalSeconds)); + m_heartbeatTimer.async_wait(boost::bind(&Connection::SendHeartbeat, shared_from_this(), p_intervalSeconds)); } - -void -Connection::SendRegister() +void Connection::SendRegister() { Packet msg; msg.Header().m_packetType = PacketType::RegisterRequest; @@ -308,9 +261,7 @@ Connection::SendRegister() AsyncSend(std::move(msg), nullptr); } - -void -Connection::HandleHeartbeatRequest() +void Connection::HandleHeartbeatRequest() { Packet msg; msg.Header().m_packetType = PacketType::HeartbeatResponse; @@ -318,8 +269,7 @@ Connection::HandleHeartbeatRequest() msg.AllocateBuffer(0); - if (0 == m_packetRead.Header().m_connectionID - || c_connectionID == m_packetRead.Header().m_connectionID) + if (0 == m_packetRead.Header().m_connectionID || c_connectionID == m_packetRead.Header().m_connectionID) { m_packetRead.Header().m_connectionID; msg.Header().WriteBuffer(msg.HeaderBuffer()); @@ -343,9 +293,7 @@ Connection::HandleHeartbeatRequest() } } - -void -Connection::HandleRegisterRequest() +void Connection::HandleRegisterRequest() { Packet msg; msg.Header().m_packetType = PacketType::RegisterResponse; @@ -359,16 +307,12 @@ Connection::HandleRegisterRequest() AsyncSend(std::move(msg), nullptr); } - -void -Connection::HandleRegisterResponse() +void Connection::HandleRegisterResponse() { m_remoteConnectionID = 
m_packetRead.Header().m_connectionID; } - -void -Connection::HandleNoHandlerResponse() +void Connection::HandleNoHandlerResponse() { auto packetType = m_packetRead.Header().m_packetType; if (!PacketTypeHelper::IsRequestPacket(packetType)) @@ -388,9 +332,7 @@ Connection::HandleNoHandlerResponse() AsyncSend(std::move(msg), nullptr); } - -void -Connection::OnConnectionFail(const boost::system::error_code& p_ec) +void Connection::OnConnectionFail(const boost::system::error_code &p_ec) { auto mgr = c_connectionManager.lock(); if (nullptr != mgr) diff --git a/AnnService/src/Socket/ConnectionManager.cpp b/AnnService/src/Socket/ConnectionManager.cpp index 9d52dbc8b..065258391 100644 --- a/AnnService/src/Socket/ConnectionManager.cpp +++ b/AnnService/src/Socket/ConnectionManager.cpp @@ -5,24 +5,17 @@ using namespace SPTAG::Socket; - -ConnectionManager::ConnectionItem::ConnectionItem() - : m_isEmpty(true) +ConnectionManager::ConnectionItem::ConnectionItem() : m_isEmpty(true) { } - -ConnectionManager::ConnectionManager() - : m_nextConnectionID(1), - m_connectionCount(0) +ConnectionManager::ConnectionManager() : m_nextConnectionID(1), m_connectionCount(0) { } - -ConnectionID -ConnectionManager::AddConnection(boost::asio::ip::tcp::socket&& p_socket, - const PacketHandlerMapPtr& p_handler, - std::uint32_t p_heartbeatIntervalSeconds) +ConnectionID ConnectionManager::AddConnection(boost::asio::ip::tcp::socket &&p_socket, + const PacketHandlerMapPtr &p_handler, + std::uint32_t p_heartbeatIntervalSeconds) { ConnectionID currID = m_nextConnectionID.fetch_add(1); while (c_invalidConnectionID == currID || !m_connections[GetPosition(currID)].m_isEmpty.exchange(false)) @@ -37,9 +30,7 @@ ConnectionManager::AddConnection(boost::asio::ip::tcp::socket&& p_socket, ++m_connectionCount; - auto connection = std::make_shared(currID, - std::move(p_socket), - p_handler, + auto connection = std::make_shared(currID, std::move(p_socket), p_handler, std::weak_ptr(shared_from_this())); { @@ -56,9 +47,7 
@@ ConnectionManager::AddConnection(boost::asio::ip::tcp::socket&& p_socket, return currID; } - -void -ConnectionManager::RemoveConnection(ConnectionID p_connectionID) +void ConnectionManager::RemoveConnection(ConnectionID p_connectionID) { auto position = GetPosition(p_connectionID); if (m_connections[position].m_isEmpty.exchange(true)) @@ -84,9 +73,7 @@ ConnectionManager::RemoveConnection(ConnectionID p_connectionID) } } - -Connection::Ptr -ConnectionManager::GetConnection(ConnectionID p_connectionID) +Connection::Ptr ConnectionManager::GetConnection(ConnectionID p_connectionID) { auto position = GetPosition(p_connectionID); Connection::Ptr ret; @@ -104,19 +91,15 @@ ConnectionManager::GetConnection(ConnectionID p_connectionID) return ret; } - -void -ConnectionManager::SetEventOnRemoving(std::function p_event) +void ConnectionManager::SetEventOnRemoving(std::function p_event) { m_eventOnRemoving = std::move(p_event); } - -void -ConnectionManager::StopAll() +void ConnectionManager::StopAll() { Helper::Concurrent::LockGuard guard(m_spinLock); - for (auto& connection : m_connections) + for (auto &connection : m_connections) { if (nullptr != connection.m_connection) { @@ -125,9 +108,7 @@ ConnectionManager::StopAll() } } - -std::uint32_t -ConnectionManager::GetPosition(ConnectionID p_connectionID) +std::uint32_t ConnectionManager::GetPosition(ConnectionID p_connectionID) { return static_cast(p_connectionID) & c_connectionPoolMask; } diff --git a/AnnService/src/Socket/Packet.cpp b/AnnService/src/Socket/Packet.cpp index 335400bbd..78510e55e 100644 --- a/AnnService/src/Socket/Packet.cpp +++ b/AnnService/src/Socket/Packet.cpp @@ -9,39 +9,27 @@ using namespace SPTAG::Socket; PacketHeader::PacketHeader() - : m_packetType(PacketType::Undefined), - m_processStatus(PacketProcessStatus::Ok), - m_bodyLength(0), - m_connectionID(c_invalidConnectionID), - m_resourceID(c_invalidResourceID) + : m_packetType(PacketType::Undefined), m_processStatus(PacketProcessStatus::Ok), 
m_bodyLength(0), + m_connectionID(c_invalidConnectionID), m_resourceID(c_invalidResourceID) { } - -PacketHeader::PacketHeader(PacketHeader&& p_right) - : m_packetType(std::move(p_right.m_packetType)), - m_processStatus(std::move(p_right.m_processStatus)), - m_bodyLength(std::move(p_right.m_bodyLength)), - m_connectionID(std::move(p_right.m_connectionID)), +PacketHeader::PacketHeader(PacketHeader &&p_right) + : m_packetType(std::move(p_right.m_packetType)), m_processStatus(std::move(p_right.m_processStatus)), + m_bodyLength(std::move(p_right.m_bodyLength)), m_connectionID(std::move(p_right.m_connectionID)), m_resourceID(std::move(p_right.m_resourceID)) { } - -PacketHeader::PacketHeader(const PacketHeader& p_right) - : m_packetType(p_right.m_packetType), - m_processStatus(p_right.m_processStatus), - m_bodyLength(p_right.m_bodyLength), - m_connectionID(p_right.m_connectionID), - m_resourceID(p_right.m_resourceID) +PacketHeader::PacketHeader(const PacketHeader &p_right) + : m_packetType(p_right.m_packetType), m_processStatus(p_right.m_processStatus), m_bodyLength(p_right.m_bodyLength), + m_connectionID(p_right.m_connectionID), m_resourceID(p_right.m_resourceID) { } - -std::size_t -PacketHeader::WriteBuffer(std::uint8_t* p_buffer) +std::size_t PacketHeader::WriteBuffer(std::uint8_t *p_buffer) { - std::uint8_t* buff = p_buffer; + std::uint8_t *buff = p_buffer; buff = SimpleSerialization::SimpleWriteBuffer(m_packetType, buff); buff = SimpleSerialization::SimpleWriteBuffer(m_processStatus, buff); buff = SimpleSerialization::SimpleWriteBuffer(m_bodyLength, buff); @@ -51,11 +39,9 @@ PacketHeader::WriteBuffer(std::uint8_t* p_buffer) return p_buffer - buff; } - -void -PacketHeader::ReadBuffer(const std::uint8_t* p_buffer) +void PacketHeader::ReadBuffer(const std::uint8_t *p_buffer) { - const std::uint8_t* buff = p_buffer; + const std::uint8_t *buff = p_buffer; buff = SimpleSerialization::SimpleReadBuffer(buff, m_packetType); buff = SimpleSerialization::SimpleReadBuffer(buff, 
m_processStatus); buff = SimpleSerialization::SimpleReadBuffer(buff, m_bodyLength); @@ -63,44 +49,32 @@ PacketHeader::ReadBuffer(const std::uint8_t* p_buffer) buff = SimpleSerialization::SimpleReadBuffer(buff, m_resourceID); } - Packet::Packet() { } - -Packet::Packet(Packet&& p_right) - : m_header(std::move(p_right.m_header)), - m_buffer(std::move(p_right.m_buffer)), +Packet::Packet(Packet &&p_right) + : m_header(std::move(p_right.m_header)), m_buffer(std::move(p_right.m_buffer)), m_bufferCapacity(std::move(p_right.m_bufferCapacity)) { } - -Packet::Packet(const Packet& p_right) - : m_header(p_right.m_header), - m_buffer(p_right.m_buffer), - m_bufferCapacity(p_right.m_bufferCapacity) +Packet::Packet(const Packet &p_right) + : m_header(p_right.m_header), m_buffer(p_right.m_buffer), m_bufferCapacity(p_right.m_bufferCapacity) { } - -PacketHeader& -Packet::Header() +PacketHeader &Packet::Header() { return m_header; } - -std::uint8_t* -Packet::HeaderBuffer() const +std::uint8_t *Packet::HeaderBuffer() const { return m_buffer.get(); } - -std::uint8_t* -Packet::Body() const +std::uint8_t *Packet::Body() const { if (nullptr != m_buffer && PacketHeader::c_bufferSize < m_bufferCapacity) { @@ -110,38 +84,28 @@ Packet::Body() const return nullptr; } - -std::uint8_t* -Packet::Buffer() const +std::uint8_t *Packet::Buffer() const { return m_buffer.get(); } - -std::uint32_t -Packet::BufferLength() const +std::uint32_t Packet::BufferLength() const { return PacketHeader::c_bufferSize + m_header.m_bodyLength; } - -std::uint32_t -Packet::BufferCapacity() const +std::uint32_t Packet::BufferCapacity() const { return m_bufferCapacity; } - -void -Packet::AllocateBuffer(std::uint32_t p_bodyCapacity) +void Packet::AllocateBuffer(std::uint32_t p_bodyCapacity) { m_bufferCapacity = PacketHeader::c_bufferSize + p_bodyCapacity; m_buffer.reset(new std::uint8_t[m_bufferCapacity], std::default_delete()); } - -bool -PacketTypeHelper::IsRequestPacket(PacketType p_type) +bool 
PacketTypeHelper::IsRequestPacket(PacketType p_type) { if (PacketType::Undefined == p_type || PacketType::ResponseMask == p_type) { @@ -151,9 +115,7 @@ PacketTypeHelper::IsRequestPacket(PacketType p_type) return (static_cast(p_type) & static_cast(PacketType::ResponseMask)) == 0; } - -bool -PacketTypeHelper::IsResponsePacket(PacketType p_type) +bool PacketTypeHelper::IsResponsePacket(PacketType p_type) { if (PacketType::Undefined == p_type || PacketType::ResponseMask == p_type) { @@ -163,9 +125,7 @@ PacketTypeHelper::IsResponsePacket(PacketType p_type) return (static_cast(p_type) & static_cast(PacketType::ResponseMask)) != 0; } - -PacketType -PacketTypeHelper::GetCrosspondingResponseType(PacketType p_type) +PacketType PacketTypeHelper::GetCrosspondingResponseType(PacketType p_type) { if (PacketType::Undefined == p_type || PacketType::ResponseMask == p_type) { diff --git a/AnnService/src/Socket/RemoteSearchQuery.cpp b/AnnService/src/Socket/RemoteSearchQuery.cpp index 2cb450328..55df44799 100644 --- a/AnnService/src/Socket/RemoteSearchQuery.cpp +++ b/AnnService/src/Socket/RemoteSearchQuery.cpp @@ -7,15 +7,11 @@ using namespace SPTAG; using namespace SPTAG::Socket; - -RemoteQuery::RemoteQuery() - : m_type(QueryType::String) +RemoteQuery::RemoteQuery() : m_type(QueryType::String) { } - -std::size_t -RemoteQuery::EstimateBufferSize() const +std::size_t RemoteQuery::EstimateBufferSize() const { std::size_t sum = 0; sum += SimpleSerialization::EstimateBufferSize(MajorVersion()); @@ -26,9 +22,7 @@ RemoteQuery::EstimateBufferSize() const return sum; } - -std::uint8_t* -RemoteQuery::Write(std::uint8_t* p_buffer) const +std::uint8_t *RemoteQuery::Write(std::uint8_t *p_buffer) const { p_buffer = SimpleSerialization::SimpleWriteBuffer(MajorVersion(), p_buffer); p_buffer = SimpleSerialization::SimpleWriteBuffer(MirrorVersion(), p_buffer); @@ -39,9 +33,7 @@ RemoteQuery::Write(std::uint8_t* p_buffer) const return p_buffer; } - -const std::uint8_t* -RemoteQuery::Read(const 
std::uint8_t* p_buffer) +const std::uint8_t *RemoteQuery::Read(const std::uint8_t *p_buffer) { decltype(MajorVersion()) majorVer = 0; decltype(MirrorVersion()) mirrorVer = 0; @@ -59,29 +51,21 @@ RemoteQuery::Read(const std::uint8_t* p_buffer) return p_buffer; } - -RemoteSearchResult::RemoteSearchResult() - : m_status(ResultStatus::Timeout) +RemoteSearchResult::RemoteSearchResult() : m_status(ResultStatus::Timeout) { } - -RemoteSearchResult::RemoteSearchResult(const RemoteSearchResult& p_right) - : m_status(p_right.m_status), - m_allIndexResults(p_right.m_allIndexResults) +RemoteSearchResult::RemoteSearchResult(const RemoteSearchResult &p_right) + : m_status(p_right.m_status), m_allIndexResults(p_right.m_allIndexResults) { } - -RemoteSearchResult::RemoteSearchResult(RemoteSearchResult&& p_right) - : m_status(std::move(p_right.m_status)), - m_allIndexResults(std::move(p_right.m_allIndexResults)) +RemoteSearchResult::RemoteSearchResult(RemoteSearchResult &&p_right) + : m_status(std::move(p_right.m_status)), m_allIndexResults(std::move(p_right.m_allIndexResults)) { } - -RemoteSearchResult& -RemoteSearchResult::operator=(RemoteSearchResult&& p_right) +RemoteSearchResult &RemoteSearchResult::operator=(RemoteSearchResult &&p_right) { m_status = p_right.m_status; m_allIndexResults = std::move(p_right.m_allIndexResults); @@ -89,9 +73,7 @@ RemoteSearchResult::operator=(RemoteSearchResult&& p_right) return *this; } - -std::size_t -RemoteSearchResult::EstimateBufferSize() const +std::size_t RemoteSearchResult::EstimateBufferSize() const { std::size_t sum = 0; sum += SimpleSerialization::EstimateBufferSize(MajorVersion()); @@ -100,13 +82,13 @@ RemoteSearchResult::EstimateBufferSize() const sum += SimpleSerialization::EstimateBufferSize(m_status); sum += sizeof(std::uint32_t); - for (const auto& indexRes : m_allIndexResults) + for (const auto &indexRes : m_allIndexResults) { sum += SimpleSerialization::EstimateBufferSize(indexRes.m_indexName); sum += sizeof(std::uint32_t); sum 
+= sizeof(bool); - for (const auto& res : indexRes.m_results) + for (const auto &res : indexRes.m_results) { sum += SimpleSerialization::EstimateBufferSize(res.VID); sum += SimpleSerialization::EstimateBufferSize(res.Dist); @@ -124,23 +106,22 @@ RemoteSearchResult::EstimateBufferSize() const return sum; } - -std::uint8_t* -RemoteSearchResult::Write(std::uint8_t* p_buffer) const +std::uint8_t *RemoteSearchResult::Write(std::uint8_t *p_buffer) const { p_buffer = SimpleSerialization::SimpleWriteBuffer(MajorVersion(), p_buffer); p_buffer = SimpleSerialization::SimpleWriteBuffer(MirrorVersion(), p_buffer); p_buffer = SimpleSerialization::SimpleWriteBuffer(m_status, p_buffer); p_buffer = SimpleSerialization::SimpleWriteBuffer(static_cast(m_allIndexResults.size()), p_buffer); - for (const auto& indexRes : m_allIndexResults) + for (const auto &indexRes : m_allIndexResults) { p_buffer = SimpleSerialization::SimpleWriteBuffer(indexRes.m_indexName, p_buffer); - p_buffer = SimpleSerialization::SimpleWriteBuffer(static_cast(indexRes.m_results.GetResultNum()), p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(static_cast(indexRes.m_results.GetResultNum()), + p_buffer); p_buffer = SimpleSerialization::SimpleWriteBuffer(indexRes.m_results.WithMeta(), p_buffer); - for (const auto& res : indexRes.m_results) + for (const auto &res : indexRes.m_results) { p_buffer = SimpleSerialization::SimpleWriteBuffer(res.VID, p_buffer); p_buffer = SimpleSerialization::SimpleWriteBuffer(res.Dist, p_buffer); @@ -158,9 +139,7 @@ RemoteSearchResult::Write(std::uint8_t* p_buffer) const return p_buffer; } - -const std::uint8_t* -RemoteSearchResult::Read(const std::uint8_t* p_buffer) +const std::uint8_t *RemoteSearchResult::Read(const std::uint8_t *p_buffer) { decltype(MajorVersion()) majorVer = 0; decltype(MirrorVersion()) mirrorVer = 0; @@ -178,7 +157,7 @@ RemoteSearchResult::Read(const std::uint8_t* p_buffer) p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, len); 
m_allIndexResults.resize(len); - for (auto& indexRes : m_allIndexResults) + for (auto &indexRes : m_allIndexResults) { p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, indexRes.m_indexName); @@ -189,7 +168,7 @@ RemoteSearchResult::Read(const std::uint8_t* p_buffer) p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, withMeta); indexRes.m_results.Init(nullptr, resNum, withMeta); - for (auto& res : indexRes.m_results) + for (auto &res : indexRes.m_results) { p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, res.VID); p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, res.Dist); diff --git a/AnnService/src/Socket/Server.cpp b/AnnService/src/Socket/Server.cpp index 0d2a5cb2a..9781bf1d4 100644 --- a/AnnService/src/Socket/Server.cpp +++ b/AnnService/src/Socket/Server.cpp @@ -8,13 +8,9 @@ using namespace SPTAG::Socket; -Server::Server(const std::string& p_address, - const std::string& p_port, - const PacketHandlerMapPtr& p_handlerMap, +Server::Server(const std::string &p_address, const std::string &p_port, const PacketHandlerMapPtr &p_handlerMap, std::size_t p_threadNum) - : m_requestHandlerMap(p_handlerMap), - m_connectionManager(new ConnectionManager), - m_acceptor(m_ioContext) + : m_requestHandlerMap(p_handlerMap), m_connectionManager(new ConnectionManager), m_acceptor(m_ioContext) { boost::asio::ip::tcp::resolver resolver(m_ioContext); @@ -22,11 +18,8 @@ Server::Server(const std::string& p_address, auto endPoints = resolver.resolve(p_address, p_port, errCode); if (errCode) { - LOG(Helper::LogLevel::LL_Error, - "Failed to resolve %s %s, error: %s", - p_address.c_str(), - p_port.c_str(), - errCode.message().c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to resolve %s %s, error: %s", p_address.c_str(), + p_port.c_str(), errCode.message().c_str()); throw std::runtime_error("Failed to resolve address."); } @@ -38,11 +31,8 @@ Server::Server(const std::string& p_address, m_acceptor.bind(endpoint, errCode); if (errCode) { - 
LOG(Helper::LogLevel::LL_Error, - "Failed to bind %s %s, error: %s", - p_address.c_str(), - p_port.c_str(), - errCode.message().c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to bind %s %s, error: %s", p_address.c_str(), p_port.c_str(), + errCode.message().c_str()); throw std::runtime_error("Failed to bind port."); } @@ -50,17 +40,14 @@ Server::Server(const std::string& p_address, m_acceptor.listen(boost::asio::socket_base::max_listen_connections, errCode); if (errCode) { - LOG(Helper::LogLevel::LL_Error, - "Failed to listen %s %s, error: %s", - p_address.c_str(), - p_port.c_str(), - errCode.message().c_str()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to listen %s %s, error: %s", p_address.c_str(), p_port.c_str(), + errCode.message().c_str()); throw std::runtime_error("Failed to listen port."); } StartAccept(); - + m_threadPool.reserve(p_threadNum); for (std::size_t i = 0; i < p_threadNum; ++i) { @@ -68,7 +55,6 @@ Server::Server(const std::string& p_address, } } - Server::~Server() { m_acceptor.close(); @@ -78,52 +64,40 @@ Server::~Server() m_ioContext.stop(); } - for (auto& t : m_threadPool) + for (auto &t : m_threadPool) { t.join(); } } - -void -Server::SetEventOnConnectionClose(std::function p_event) +void Server::SetEventOnConnectionClose(std::function p_event) { m_connectionManager->SetEventOnRemoving(std::move(p_event)); } - -void -Server::StartAccept() +void Server::StartAccept() { - m_acceptor.async_accept([this](boost::system::error_code p_ec, - boost::asio::ip::tcp::socket p_socket) - { - if (!m_acceptor.is_open()) - { - return; - } - - if (!p_ec) - { - m_connectionManager->AddConnection(std::move(p_socket), - m_requestHandlerMap, - 0); - } - - StartAccept(); - }); + m_acceptor.async_accept([this](boost::system::error_code p_ec, boost::asio::ip::tcp::socket p_socket) { + if (!m_acceptor.is_open()) + { + return; + } + + if (!p_ec) + { + m_connectionManager->AddConnection(std::move(p_socket), m_requestHandlerMap, 0); + } + + 
StartAccept(); + }); } - -void -Server::StartListen() +void Server::StartListen() { m_ioContext.run(); } - -void -Server::SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback) +void Server::SendPacket(ConnectionID p_connection, Packet p_packet, std::function p_callback) { auto connection = m_connectionManager->GetConnection(p_connection); if (nullptr != connection) diff --git a/AnnService/src/UsefulTool/main.cpp b/AnnService/src/UsefulTool/main.cpp new file mode 100644 index 000000000..600064437 --- /dev/null +++ b/AnnService/src/UsefulTool/main.cpp @@ -0,0 +1,671 @@ +#include +#include +#include + +#include "inc/Core/Common.h" +#include "inc/Core/Common/DistanceUtils.h" +#include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/Common/TruthSet.h" +#include "inc/Core/SPANN/Index.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Helper/VectorSetReader.h" +#include "inc/SSDServing/Utils.h" + +using namespace SPTAG; + +class ToolOptions : public Helper::ReaderOptions +{ + public: + ToolOptions() : Helper::ReaderOptions(VectorValueType::Float, 0, VectorFileType::TXT, "|", 32) + { + AddOptionalOption(newDataSetFileName, "-ndf", "--NewDataSetFileName", "New dataset file name."); + AddOptionalOption(currentListFileName, "-clf", "--CurrentListFileName", "Current list file name."); + AddOptionalOption(reserveListFileName, "-rlfe", "--ReserveListFileName", "Reserve list file name."); + AddOptionalOption(traceFileName, "-tf", "--TraceFileName", "Trace file name."); + AddOptionalOption(baseNum, "-bn", "--BaseNum", "Base vector number."); + AddOptionalOption(reserveNum, "-rn", "--ReserveNum", "Reserve number."); + AddOptionalOption(updateSize, "-us", "--UpdateSize", "Update size."); + AddOptionalOption(batch, "-bs", "--Batch", "Batch size."); + AddOptionalOption(genTrace, "-gt", "--GenTrace", "Gen trace."); + AddOptionalOption(convertTruth, "-ct", "--ConvertTruth", "Convert Truth."); + AddOptionalOption(callRecall, "-cr", "--CallRecall", 
"Calculate Recall."); + AddOptionalOption(genSet, "-gs", "--GenSet", "Generate set."); + AddOptionalOption(genStress, "-gss", "--GenStress", "Generate stress data."); + AddOptionalOption(m_vectorPath, "-vp", "--VectorPath", "Vector Path"); + AddOptionalOption(m_distCalcMethod, "-m", "--dist", "Distance method (L2 or Cosine)."); + AddOptionalOption(m_resultNum, "-rn", "--resultNum", "Result num."); + AddOptionalOption(m_querySize, "-qs", "--querySize", "Query size."); + AddOptionalOption(m_truthPath, "-tp", "--truthPath", "Truth path."); + AddOptionalOption(m_truthType, "-tt", "--truthType", "Truth type."); + AddOptionalOption(m_queryPath, "-qp", "--queryPath", "Query path."); + AddOptionalOption(m_searchResult, "-sr", "--searchResult", "Search result path."); + } + + ~ToolOptions() + { + } + + std::string newDataSetFileName = ""; + std::string currentListFileName = ""; + std::string reserveListFileName = ""; + std::string traceFileName = ""; + std::string m_vectorPath = ""; + int baseNum = 1; + int reserveNum = 1; + int updateSize = 1; + int batch = 1; + bool genTrace = false; + bool convertTruth = false; + bool callRecall = false; + bool genSet = false; + bool genStress = false; + DistCalcMethod m_distCalcMethod = DistCalcMethod::L2; + int m_resultNum = 1; + int m_querySize = 1; + std::string m_truthPath = ""; + TruthFileType m_truthType = TruthFileType::DEFAULT; + std::string m_queryPath = ""; + std::string m_searchResult = ""; + std::string m_headVectorFile = ""; + std::string m_headIDFile = ""; + +} options; + +template +void PrintPercentiles(const std::vector &p_values, std::function p_get, const char *p_format, + bool reverse = false) +{ + double sum = 0; + std::vector collects; + collects.reserve(p_values.size()); + for (const auto &v : p_values) + { + T tmp = p_get(v); + sum += tmp; + collects.push_back(tmp); + } + + if (reverse) + { + std::sort(collects.begin(), collects.end(), std::greater()); + } + else + { + std::sort(collects.begin(), collects.end()); 
+ } + if (reverse) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Avg\t50tiles\t90tiles\t95tiles\t99tiles\t99.9tiles\tMin\n"); + } + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Avg\t50tiles\t90tiles\t95tiles\t99tiles\t99.9tiles\tMax\n"); + } + + std::string formatStr("%.3lf"); + for (int i = 1; i < 7; ++i) + { + formatStr += '\t'; + formatStr += p_format; + } + + formatStr += '\n'; + + SPTAGLIB_LOG( + Helper::LogLevel::LL_Info, formatStr.c_str(), sum / collects.size(), + collects[static_cast(collects.size() * 0.50)], collects[static_cast(collects.size() * 0.90)], + collects[static_cast(collects.size() * 0.95)], collects[static_cast(collects.size() * 0.99)], + collects[static_cast(collects.size() * 0.999)], collects[static_cast(collects.size() - 1)]); +} + +void LoadTruth(ToolOptions &p_opts, std::vector> &truth, int numQueries, std::string truthfilename, + int truthK) +{ + auto ptr = f_createIO(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading TruthFile...: %s\n", truthfilename.c_str()); + + if (ptr == nullptr || !ptr->Initialize(truthfilename.c_str(), std::ios::in | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open truth file: %s\n", truthfilename.c_str()); + exit(1); + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "K: %d, TruthResultNum: %d\n", p_opts.m_resultNum, p_opts.m_resultNum); + COMMON::TruthSet::LoadTruth(ptr, truth, numQueries, p_opts.m_resultNum, p_opts.m_resultNum, p_opts.m_truthType); + char tmp[4]; + if (ptr->ReadBinary(4, tmp) == 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number is larger than query number(%d)!\n", numQueries); + } +} + +std::shared_ptr LoadVectorSet(ToolOptions &p_opts, int numThreads) +{ + std::shared_ptr vectorSet; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading VectorSet...\n"); + if (!p_opts.m_vectorPath.empty() && fileexists(p_opts.m_vectorPath.c_str())) + { + std::shared_ptr vectorOptions(new Helper::ReaderOptions( + p_opts.m_inputValueType, 
p_opts.m_dimension, p_opts.m_inputFileType, p_opts.m_vectorDelimiter)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); + if (ErrorCode::Success == vectorReader->LoadFile(p_opts.m_vectorPath)) + { + vectorSet = vectorReader->GetVectorSet(); + if (p_opts.m_distCalcMethod == DistCalcMethod::Cosine) + vectorSet->Normalize(numThreads); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nLoad VectorSet(%d,%d).\n", vectorSet->Count(), + vectorSet->Dimension()); + } + } + return vectorSet; +} + +std::shared_ptr LoadQuerySet(ToolOptions &p_opts) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading QuerySet...\n"); + std::shared_ptr queryOptions(new Helper::ReaderOptions( + p_opts.m_inputValueType, p_opts.m_dimension, p_opts.m_inputFileType, p_opts.m_vectorDelimiter)); + auto queryReader = Helper::VectorSetReader::CreateInstance(queryOptions); + if (ErrorCode::Success != queryReader->LoadFile(p_opts.m_queryPath)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + exit(1); + } + return queryReader->GetVectorSet(); +} + +void LoadOutputResult(ToolOptions &p_opts, std::vector> &ids, + std::vector> &dists) +{ + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(p_opts.m_searchResult.c_str(), std::ios::binary | std::ios::in)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open file: %s\n", p_opts.m_searchResult.c_str()); + exit(1); + } + int32_t NumQuerys; + if (ptr->ReadBinary(sizeof(NumQuerys), (char *)&NumQuerys) != sizeof(NumQuerys)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read result file!\n"); + exit(1); + } + int32_t resultNum; + if (ptr->ReadBinary(sizeof(resultNum), (char *)&resultNum) != sizeof(resultNum)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read result file!\n"); + exit(1); + } + + for (size_t i = 0; i < NumQuerys; ++i) + { + std::vector tempVec_id; + std::vector tempVec_dist; + for (int j = 0; j < resultNum; ++j) + { + int32_t i32Val; + if 
(ptr->ReadBinary(sizeof(i32Val), (char *)&i32Val) != sizeof(i32Val)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read result file!\n"); + exit(1); + } + + float fVal; + if (ptr->ReadBinary(sizeof(fVal), (char *)&fVal) != sizeof(fVal)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read result file!\n"); + exit(1); + } + tempVec_id.push_back(i32Val); + tempVec_dist.push_back(fVal); + } + ids.push_back(tempVec_id); + dists.push_back(tempVec_dist); + } +} + +void LoadUpdateMapping(std::string fileName, std::vector &reverseIndices) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading %s\n", fileName.c_str()); + + int vectorNum; + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(fileName.c_str(), std::ios::in | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open trace file: %s\n", fileName.c_str()); + exit(1); + } + + if (ptr->ReadBinary(4, (char *)&vectorNum) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + } + + reverseIndices.clear(); + reverseIndices.resize(vectorNum); + + if (ptr->ReadBinary(vectorNum * 4, (char *)reverseIndices.data()) != vectorNum * 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "update mapping Error!\n"); + exit(1); + } +} + +void SaveUpdateMapping(std::string fileName, std::vector &reverseIndices, SizeType vectorNum) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving %s\n", fileName.c_str()); + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(fileName.c_str(), std::ios::out | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open trace file: %s\n", fileName.c_str()); + exit(1); + } + + if (ptr->WriteBinary(4, reinterpret_cast(&vectorNum)) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + } + + for (int i = 0; i < vectorNum; i++) + { + if (ptr->WriteBinary(4, reinterpret_cast(&reverseIndices[i])) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "update mapping 
Error!\n"); + exit(1); + } + } +} + +void SaveUpdateTrace(std::string fileName, SizeType &updateSize, std::vector &insertSet, + std::vector &deleteSet) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving %s\n", fileName.c_str()); + + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(fileName.c_str(), std::ios::out | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open trace file: %s\n", fileName.c_str()); + exit(1); + } + + if (ptr->WriteBinary(4, reinterpret_cast(&updateSize)) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + exit(1); + } + + for (int i = 0; i < updateSize; i++) + { + if (ptr->WriteBinary(4, reinterpret_cast(&deleteSet[i])) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + exit(1); + } + } + + for (int i = 0; i < updateSize; i++) + { + if (ptr->WriteBinary(4, reinterpret_cast(&insertSet[i])) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + exit(1); + } + } +} + +void InitializeList(std::vector ¤t_list, std::vector &reserve_list, ToolOptions &p_opts) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Initialize list\n"); + if (p_opts.batch == 0) + { + current_list.resize(p_opts.baseNum); + reserve_list.resize(p_opts.reserveNum); + for (int i = 0; i < p_opts.baseNum; i++) + current_list[i] = i; + for (int i = 0; i < p_opts.reserveNum; i++) + reserve_list[i] = i + p_opts.baseNum; + } + else + { + LoadUpdateMapping(p_opts.currentListFileName + std::to_string(p_opts.batch - 1), current_list); + LoadUpdateMapping(p_opts.reserveListFileName + std::to_string(p_opts.batch - 1), reserve_list); + } +} + +template void GenerateTrace(ToolOptions &p_opts) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Generating Trace\n"); + std::shared_ptr vectorSet; + std::vector current_list; + std::vector reserve_list; + InitializeList(current_list, reserve_list, p_opts); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Loading Vector Set\n"); + if 
(!p_opts.m_vectorPath.empty() && fileexists(p_opts.m_vectorPath.c_str())) + { + std::shared_ptr vectorOptions(new Helper::ReaderOptions( + p_opts.m_inputValueType, p_opts.m_dimension, p_opts.m_inputFileType, p_opts.m_vectorDelimiter)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(vectorOptions); + if (ErrorCode::Success == vectorReader->LoadFile(p_opts.m_vectorPath)) + { + vectorSet = vectorReader->GetVectorSet(); + if (p_opts.m_distCalcMethod == DistCalcMethod::Cosine) + vectorSet->Normalize(4); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nLoad VectorSet(%d,%d).\n", vectorSet->Count(), + vectorSet->Dimension()); + } + } + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "batch: %d, updateSize: %d\n", p_opts.batch, p_opts.updateSize); + std::vector deleteList; + std::vector insertList; + deleteList.resize(p_opts.updateSize); + insertList.resize(p_opts.updateSize); + std::mt19937 rg; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Generate delete list\n"); + std::shuffle(current_list.begin(), current_list.end(), rg); + for (int j = 0; j < p_opts.updateSize; j++) + { + deleteList[j] = current_list[p_opts.baseNum - j - 1]; + } + current_list.resize(p_opts.baseNum - p_opts.updateSize); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Generate insert list\n"); + std::shuffle(reserve_list.begin(), reserve_list.end(), rg); + for (int j = 0; j < p_opts.updateSize; j++) + { + insertList[j] = reserve_list[p_opts.reserveNum - j - 1]; + current_list.push_back(reserve_list[p_opts.reserveNum - j - 1]); + } + + reserve_list.resize(p_opts.reserveNum - p_opts.updateSize); + for (int j = 0; j < p_opts.updateSize; j++) + { + reserve_list.push_back(deleteList[j]); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Sorting list\n"); + std::sort(current_list.begin(), current_list.end()); + std::sort(reserve_list.begin(), reserve_list.end()); + + SaveUpdateMapping(p_opts.currentListFileName + std::to_string(p_opts.batch), current_list, p_opts.baseNum); + 
SaveUpdateMapping(p_opts.reserveListFileName + std::to_string(p_opts.batch), reserve_list, p_opts.reserveNum); + + SaveUpdateTrace(p_opts.traceFileName + std::to_string(p_opts.batch), p_opts.updateSize, insertList, deleteList); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Generate new dataset for truth\n"); + COMMON::Dataset newSample(0, p_opts.m_dimension, 1024 * 1024, MaxSize); + + for (int i = 0; i < p_opts.baseNum; i++) + { + newSample.AddBatch(1, (ValueType *)(vectorSet->GetVector(current_list[i]))); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Saving Truth\n"); + newSample.Save(p_opts.newDataSetFileName + std::to_string(p_opts.batch)); +} + +template void GenerateStressTest(ToolOptions &p_opts) +{ + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Begin Generating Stress Test\n"); + std::shared_ptr vectorSet; + std::vector current_list; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "batch: %d, updateSize: %d\n", p_opts.batch, p_opts.updateSize); + current_list.resize(p_opts.baseNum); + for (int i = 0; i < p_opts.baseNum; i++) + current_list[i] = i; + std::mt19937 rg; + for (int i = 0; i < p_opts.batch; i++) + { + std::vector insertList; + insertList.resize(p_opts.updateSize); + std::shuffle(current_list.begin(), current_list.end(), rg); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Generate delete/re-insert list\n"); + for (int j = 0; j < p_opts.updateSize; j++) + { + insertList[j] = current_list[p_opts.baseNum - j - 1]; + } + std::string fileName = p_opts.traceFileName + std::to_string(i); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(fileName.c_str(), std::ios::out | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed open trace file: %s\n", fileName.c_str()); + exit(1); + } + if (ptr->WriteBinary(4, reinterpret_cast(&p_opts.updateSize)) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + exit(1); + } + + for (int i = 0; i < p_opts.updateSize; i++) + { + if (ptr->WriteBinary(4, 
reinterpret_cast(&insertList[i])) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "vector Size Error!\n"); + exit(1); + } + } + } +} + +template void ConvertTruth(ToolOptions &p_opts) +{ + std::vector current_list; + LoadUpdateMapping(p_opts.currentListFileName + std::to_string(p_opts.batch), current_list); + int K = p_opts.m_resultNum; + + int numQueries = p_opts.m_querySize; + + std::string truthFile = p_opts.m_truthPath + std::to_string(p_opts.batch); + std::vector> truth; + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start loading TruthFile...\n"); + truth.resize(numQueries); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(truthFile.c_str(), std::ios::in | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read the file:%s\n", truthFile.c_str()); + exit(1); + } + + if (ptr->TellP() == 0) + { + int row, originalK; + if (ptr->ReadBinary(4, (char *)&row) != 4 || ptr->ReadBinary(4, (char *)&originalK) != 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to read truth file!\n"); + exit(1); + } + } + + for (int i = 0; i < numQueries; ++i) + { + truth[i].clear(); + truth[i].resize(K); + if (ptr->ReadBinary(K * 4, (char *)truth[i].data()) != K * 4) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth number(%d) and query number(%d) are not match!\n", i, + numQueries); + exit(1); + } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "ChangeMapping & Writing\n"); + std::string truthFileAfter = p_opts.m_truthPath + "_after" + std::to_string(p_opts.batch); + ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize(truthFileAfter.c_str(), std::ios::out | std::ios::binary)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to create the file:%s\n", truthFileAfter.c_str()); + exit(1); + } + ptr->WriteBinary(4, (char *)&numQueries); + ptr->WriteBinary(4, (char *)&K); + for (int i = 0; i < numQueries; i++) + { + for (int j = 0; j < K; j++) + { + if (ptr->WriteBinary(4, (char *)(¤t_list[truth[i][j]])) != 4) + { + 
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Fail to write the truth file!\n"); + exit(1); + } + } + } +} + +template void CallRecall(ToolOptions &p_opts) +{ + std::string truthfile = p_opts.m_truthPath; + std::vector> truth; + int truthK = p_opts.m_resultNum; + int K = p_opts.m_resultNum; + auto vectorSet = LoadVectorSet(p_opts, 10); + auto querySet = LoadQuerySet(p_opts); + int NumQuerys = querySet->Count(); + LoadTruth(p_opts, truth, NumQuerys, truthfile, truthK); + std::vector> ids; + std::vector> dists; + + LoadOutputResult(p_opts, ids, dists); + + float meanrecall = 0, minrecall = MaxDist, maxrecall = 0, stdrecall = 0; + std::vector thisrecall(NumQuerys, 0); + std::unique_ptr visited(new bool[K]); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start Calculating Recall\n"); + for (SizeType i = 0; i < NumQuerys; i++) + { + memset(visited.get(), 0, K * sizeof(bool)); + for (SizeType id : truth[i]) + { + for (int j = 0; j < K; j++) + { + if (visited[j] || ids[i][j] < 0) + continue; + if (vectorSet != nullptr) + { + float dist = dists[i][j]; + float truthDist = COMMON::DistanceUtils::ComputeDistance( + (const ValueType *)querySet->GetVector(i), (const ValueType *)vectorSet->GetVector(id), + vectorSet->Dimension(), SPTAG::DistCalcMethod::L2); + if (fabs(dist - truthDist) < Epsilon * (dist + Epsilon)) + { + thisrecall[i] += 1; + visited[j] = true; + break; + } + } + } + } + thisrecall[i] /= truthK; + meanrecall += thisrecall[i]; + if (thisrecall[i] < minrecall) + minrecall = thisrecall[i]; + if (thisrecall[i] > maxrecall) + maxrecall = thisrecall[i]; + } + meanrecall /= NumQuerys; + for (SizeType i = 0; i < NumQuerys; i++) + { + stdrecall += (thisrecall[i] - meanrecall) * (thisrecall[i] - meanrecall); + } + stdrecall = std::sqrt(stdrecall / NumQuerys); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "stdrecall: %.6lf, maxrecall: %.2lf, minrecall: %.2lf\n", stdrecall, + maxrecall, minrecall); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "\nRecall Distribution:\n"); + 
PrintPercentiles( + thisrecall, [](const float recall) -> float { return recall; }, "%.3lf", true); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall%d@%d: %f\n", K, truthK, meanrecall); +} + +int main(int argc, char *argv[]) +{ + if (!options.Parse(argc - 1, argv + 1)) + { + exit(1); + } + if (options.genTrace) + { + switch (options.m_inputValueType) + { + case SPTAG::VectorValueType::Float: + GenerateTrace(options); + break; + case SPTAG::VectorValueType::Int16: + GenerateTrace(options); + break; + case SPTAG::VectorValueType::Int8: + GenerateTrace(options); + break; + case SPTAG::VectorValueType::UInt8: + GenerateTrace(options); + break; + default: + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); + } + } + + if (options.genStress) + { + switch (options.m_inputValueType) + { + case SPTAG::VectorValueType::Float: + GenerateStressTest(options); + break; + case SPTAG::VectorValueType::Int16: + GenerateStressTest(options); + break; + case SPTAG::VectorValueType::Int8: + GenerateStressTest(options); + break; + case SPTAG::VectorValueType::UInt8: + GenerateStressTest(options); + break; + default: + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); + } + } + + if (options.convertTruth) + { + switch (options.m_inputValueType) + { + case SPTAG::VectorValueType::Float: + ConvertTruth(options); + break; + case SPTAG::VectorValueType::Int16: + ConvertTruth(options); + break; + case SPTAG::VectorValueType::Int8: + ConvertTruth(options); + break; + case SPTAG::VectorValueType::UInt8: + ConvertTruth(options); + break; + default: + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); + } + } + + if (options.callRecall) + { + switch (options.m_inputValueType) + { + case SPTAG::VectorValueType::Float: + CallRecall(options); + break; + case SPTAG::VectorValueType::Int16: + CallRecall(options); + break; + case SPTAG::VectorValueType::Int8: + CallRecall(options); + break; + case SPTAG::VectorValueType::UInt8: + CallRecall(options); + break; 
+ default: + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error data type!\n"); + } + } + + return 0; +} diff --git a/CMakeLists.txt b/CMakeLists.txt index f1b97a451..a49d6dc8b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. +# Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. cmake_minimum_required (VERSION 3.12) @@ -18,20 +18,53 @@ if(NOT WIN32) CXX_COMPILER_DUMPVERSION(CXX_COMPILER_VERSION) endif() +option(USE_ASAN "Enable AddressSanitizer (ASAN) for memory debugging" OFF) + if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") # require at least gcc 5.0 if (CXX_COMPILER_VERSION VERSION_LESS 5.0) message(FATAL_ERROR "GCC version must be at least 5.0!") endif() - set (CMAKE_CXX_FLAGS "-Wall -Wunreachable-code -Wno-reorder -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -lm -lrt -std=c++14 -fopenmp") + set (CMAKE_CXX_FLAGS "-Wall -Wunreachable-code -Wno-reorder -Wno-pessimizing-move -Wno-delete-non-virtual-dtor -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -lm -lrt -std=c++17 -fopenmp") set (CMAKE_CXX_FLAGS_RELEASE "-DNDEBUG -O3 -march=native") set (CMAKE_CXX_FLAGS_DEBUG "-g -DDEBUG") + + if(USE_ASAN) + message(STATUS "ASAN is enabled") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer") + set(CMAKE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} -fsanitize=address") + endif() + + find_path(NUMA_INCLUDE_DIR NAME numa.h + HINTS $ENV{HOME}/local/include /opt/local/include /usr/local/include /usr/include) + + find_library(NUMA_LIBRARY NAME libnuma.so + HINTS $ENV{HOME}/local/lib64 $ENV{HOME}/local/lib /usr/local/lib64 /usr/local/lib /opt/local/lib64 /opt/local/lib /usr/lib64 /usr/lib) + + find_library(NUMA_LIBRARY_STATIC NAME libnuma.a + HINTS $ENV{HOME}/local/lib64 $ENV{HOME}/local/lib /usr/local/lib64 /usr/local/lib /opt/local/lib64 /opt/local/lib /usr/lib64 /usr/lib) + + if (NUMA_INCLUDE_DIR AND NUMA_LIBRARY AND 
NUMA_LIBRARY_STATIC) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -lnuma") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lnuma") + + include_directories (${NUMA_INCLUDE_DIR}) + message (STATUS "Found numa library: inc=${NUMA_INCLUDE_DIR}, lib=${NUMA_LIBRARY}, staticlib=${NUMA_LIBRARY_STATIC}") + add_definitions(-DNUMA) + else () + set (NUMA_LIBRARY "") + set (NUMA_LIBRARY_STATIC "") + message (STATUS "WARNING: Numa library not found.") + message (STATUS "Try: 'sudo yum install numactl numactl-devel' (or sudo apt-get install libnuma libnuma-dev)") + endif () + elseif(WIN32) if(NOT MSVC14) message(FATAL_ERROR "On Windows, only MSVC version 14 are supported!") endif() set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++14") - + set (NUMA_LIBRARY "") + set (NUMA_LIBRARY_STATIC "") message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") else () message(FATAL_ERROR "Unrecognized compiler (use GCC or MSVC)!") @@ -90,8 +123,60 @@ else() message (FATAL_ERROR "Could not find Boost >= 1.67!") endif() -option(GPU "GPU" ON) +option(GPU "GPU" OFF) option(LIBRARYONLY "LIBRARYONLY" OFF) +option(ROCKSDB "ROCKSDB" OFF) +option(SPDK "SPDK" OFF) +option(TBB "TBB" ON) +option(URING "URING" OFF) + +if (URING) + find_path(uring_INCLUDE_DIR NAMES liburing.h) + find_library(uring_LIBRARIES NAMES liburing.a liburing) + + if (uring_INCLUDE_DIR AND uring_LIBRARIES) + add_library(uring::uring INTERFACE IMPORTED) + set_target_properties(uring::uring PROPERTIES + INTERFACE_LINK_LIBRARIES "${uring_LIBRARIES}" + INTERFACE_INCLUDE_DIRECTORIES "${uring_INCLUDE_DIR}" + ) + add_definitions(-DURING) + set(uring_LIBRARIES uring::uring) + message (STATUS "Found uring ${uring_LIBRARIES}") + else() + message (FATAL_ERROR "Could not find uring!") + endif() +else() + set(uring_LIBRARIES "") +endif() + + +set(RocksDB_LIBRARIES "") +if (ROCKSDB) + find_package(RocksDB CONFIG) + if((DEFINED RocksDB_DIR) AND RocksDB_DIR) + list(APPEND RocksDB_LIBRARIES RocksDB::rocksdb) + add_definitions(-DROCKSDB) + message (STATUS "Found 
RocksDB ${RocksDB_VERSION} ${RocksDB_LIBRARIES}") + message (STATUS "RocksDB: ${RocksDB_DIR}") + else() + set(RocksDB_LIBRARIES "") + set(uring_LIBRARIES "") + message (FATAL_ERROR "Could not find RocksDB!") + endif() +endif() + +set (TBB_LIBRARIES "") +if (TBB) + find_package(TBB REQUIRED) + if (TBB_FOUND) + list(APPEND TBB_LIBRARIES tbb) + add_definitions(-DTBB) + message (STATUS "Found TBB ${TBB_LIBRARIES}") + else() + message (FATAL_ERROR "Could not find TBB!") + endif() +endif() add_subdirectory (ThirdParty/zstd/build/cmake) diff --git a/Dockerfile b/Dockerfile index 00ec1f842..1ac755f5e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,10 +3,11 @@ WORKDIR /app ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get -y install wget build-essential \ - swig cmake git \ +RUN apt-get update && apt-get -y install wget build-essential swig cmake git libnuma-dev python3.8-dev python3-distutils gcc-8 g++-8 \ libboost-filesystem-dev libboost-test-dev libboost-serialization-dev libboost-regex-dev libboost-serialization-dev libboost-regex-dev libboost-thread-dev libboost-system-dev +RUN wget https://bootstrap.pypa.io/get-pip.py && python3.8 get-pip.py && python3.8 -m pip install numpy + ENV PYTHONPATH=/app/Release COPY CMakeLists.txt ./ @@ -16,4 +17,4 @@ COPY Wrappers ./Wrappers/ COPY GPUSupport ./GPUSupport/ COPY ThirdParty ./ThirdParty/ -RUN mkdir build && cd build && cmake .. && make -j$(nproc) && cd .. +RUN export CC=/usr/bin/gcc-8 && export CXX=/usr/bin/g++-8 && mkdir build && cd build && cmake .. && make -j && cd .. 
diff --git a/Dockerfile.cuda b/Dockerfile.cuda index bfd265104..fe922a453 100644 --- a/Dockerfile.cuda +++ b/Dockerfile.cuda @@ -1,12 +1,33 @@ -FROM mcr.microsoft.com/oss/nvidia/k8s-device-plugin:v0.10.0 +FROM mcr.microsoft.com/mirror/nvcr/nvidia/cuda:11.8.0-base-ubuntu20.04 WORKDIR /app ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get -y install build-essential \ - swig cmake git \ +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub + +RUN apt-get -y install wget build-essential swig cmake git libnuma-dev python3.8-dev python3-distutils gcc-8 g++-8 \ libboost-filesystem-dev libboost-test-dev libboost-serialization-dev libboost-regex-dev libboost-serialization-dev libboost-regex-dev libboost-thread-dev libboost-system-dev +RUN wget https://bootstrap.pypa.io/get-pip.py && python3.8 get-pip.py && python3.8 -m pip install numpy + +RUN git clone https://github.com/NVIDIA/cub && cd cub && git checkout c3cceac && cp -r cub /usr/include/ && cd .. 
+ +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin && mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600 \ + && wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda-repo-ubuntu2004-11-8-local_11.8.0-520.61.05-1_amd64.deb \ + && dpkg -i cuda-repo-ubuntu2004-11-8-local_11.8.0-520.61.05-1_amd64.deb && cp /var/cuda-repo-ubuntu2004-11-8-local/cuda-*-keyring.gpg /usr/share/keyrings/ \ + && apt-get update && apt-get install -y cuda-toolkit-11-8 + +LABEL com.nvidia.volumes.needed="nvidia_driver" +LABEL com.nvidia.cuda.version="11.8.0" + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} +ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} + +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility +ENV NVIDIA_REQUIRE_CUDA "cuda>=11.0" + + ENV PYTHONPATH=/app/Release COPY CMakeLists.txt ./ @@ -16,4 +37,4 @@ COPY Wrappers ./Wrappers/ COPY GPUSupport ./GPUSupport/ COPY ThirdParty ./ThirdParty/ -RUN mkdir build && cd build && cmake .. && -j$(nproc) && cd .. +RUN export CC=/usr/bin/gcc-8 && export CXX=/usr/bin/g++-8 && mkdir build && cd build && cmake .. && make -j && cd .. 
diff --git a/GPUSupport/CMakeLists.txt b/GPUSupport/CMakeLists.txt index 380df2b3d..066a59f35 100644 --- a/GPUSupport/CMakeLists.txt +++ b/GPUSupport/CMakeLists.txt @@ -14,13 +14,13 @@ if (CUDA_FOUND) if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") set (CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC -Xcompiler -fopenmp -std=c++14 -Xptxas -O3 --use_fast_math --disable-warnings -lineinfo -gencode arch=compute_70,code=sm_70 - -gencode arch=compute_61,code=sm_61 - -gencode arch=compute_60,code=sm_60" ) + -gencode arch=compute_75,code=sm_75 + -gencode arch=compute_80,code=sm_80" ) elseif(WIN32) set (CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler /openmp -Xcompiler /std:c++14 -Xcompiler /Zc:__cplusplus -Xcompiler /FS --use_fast_math -gencode arch=compute_70,code=sm_70 - -gencode arch=compute_61,code=sm_61 - -gencode arch=compute_60,code=sm_60" ) + -gencode arch=compute_75,code=sm_75 + -gencode arch=compute_80,code=sm_80" ) endif() message (STATUS "CUDA_NVCC_FLAGS:" ${CUDA_NVCC_FLAGS}) @@ -28,6 +28,9 @@ if (CUDA_FOUND) set(AnnService ${PROJECT_SOURCE_DIR}/AnnService) include_directories(${AnnService}) + include_directories(${PROJECT_SOURCE_DIR}/Test/cuda) + + include_directories(${PROJECT_SOURCE_DIR}/ThirdParty/zstd/lib) file(GLOB_RECURSE GPU_HDR_FILES ${AnnService}/inc/Core/*.h ${AnnService}/inc/Helper/*.h ${AnnService}/inc/Core/Common/cuda/*) file(GLOB_RECURSE GPU_SRC_FILES ${AnnService}/src/Core/*.cpp ${AnnService}/src/Helper/*.cpp ${AnnService}/src/Core/Common/Kernel.cu) @@ -36,11 +39,11 @@ if (CUDA_FOUND) ${AnnService}/inc/Core/Common/DistanceUtils.h ${AnnService}/inc/Core/Common/InstructionUtils.h ${AnnService}/inc/Core/Common/CommonUtils.h - ${AnnService}/inc/Core/Common/cuda/cuda_unit_tests.cu ) list(REMOVE_ITEM GPU_SRC_FILES ${AnnService}/src/Core/Common/DistanceUtils.cpp ${AnnService}/src/Core/Common/InstructionUtils.cpp + ${AnnService}/Test/cuda ) set_source_files_properties(${GPU_SRC_FILES} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ) @@ -51,11 +54,11 @@ if 
(CUDA_FOUND) endif() CUDA_ADD_LIBRARY(GPUSPTAGLib SHARED ${GPU_SRC_FILES} ${GPU_HDR_FILES}) - target_link_libraries(GPUSPTAGLib DistanceUtils ${Boost_LIBRARIES} ${CUDA_LIBRARIES}) + target_link_libraries(GPUSPTAGLib DistanceUtils ${Boost_LIBRARIES} ${CUDA_LIBRARIES} libzstd_shared ${NUMA_LIBRARY}) target_compile_definitions(GPUSPTAGLib PRIVATE ${Definition}) CUDA_ADD_LIBRARY(GPUSPTAGLibStatic STATIC ${GPU_SRC_FILES} ${GPU_HDR_FILES}) - target_link_libraries(GPUSPTAGLibStatic DistanceUtils ${Boost_LIBRARIES} ${CUDA_LIBRARIES}) + target_link_libraries(GPUSPTAGLibStatic DistanceUtils ${Boost_LIBRARIES} ${CUDA_LIBRARIES} libzstd_static ${NUMA_LIBRARY_STATIC}) target_compile_definitions(GPUSPTAGLibStatic PRIVATE ${Definition}) add_dependencies(GPUSPTAGLibStatic GPUSPTAGLib) @@ -80,6 +83,19 @@ if (CUDA_FOUND) CUDA_ADD_EXECUTABLE(gpussdserving ${SSD_SERVING_HDR_FILES} ${SSD_SERVING_FILES}) target_link_libraries(gpussdserving GPUSPTAGLibStatic ${Boost_LIBRARIES} ${CUDA_LIBRARIES}) target_compile_definitions(gpussdserving PRIVATE ${Definition} _exe) + + CUDA_ADD_LIBRARY(GPUSPTAGTests SHARED ${AnnService}/../Test/cuda/knn_tests.cu ${AnnService}/../Test/cuda/distance_tests.cu ${AnnService}/../Test/cuda/tptree_tests.cu ${AnnService}/../Test/cuda/buildssd_test.cu ${AnnService}/../Test/cuda/gpu_pq_perf.cu) + target_link_libraries(GPUSPTAGTests ${Boost_LIBRARIES} ${CUDA_LIBRARIES}) + target_compile_definitions(GPUSPTAGTests PRIVATE ${Definition}) + + CUDA_ADD_EXECUTABLE(gpu_test ${AnnService}/../Test/cuda/cuda_tests.cpp) + target_link_libraries(gpu_test GPUSPTAGTests GPUSPTAGLibStatic ${Boost_LIBRARIES} ${CUDA_LIBRARIES}) + target_compile_definitions(gpu_test PRIVATE ${Definition} _exe) + + CUDA_ADD_EXECUTABLE(gpu_pq_test ${AnnService}/../Test/cuda/pq_perf.cpp) + target_link_libraries(gpu_pq_test GPUSPTAGTests GPUSPTAGLibStatic ${Boost_LIBRARIES} ${CUDA_LIBRARIES}) + target_compile_definitions(gpu_pq_test PRIVATE ${Definition} _exe) + else() message (STATUS "Could not find 
cuda.") endif() diff --git a/LICENSE b/LICENSE index d1ca00f20..21071075c 100644 --- a/LICENSE +++ b/LICENSE @@ -18,4 +18,4 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE \ No newline at end of file + SOFTWARE diff --git a/MANIFEST.in b/MANIFEST.in index 0be181b5f..f5d87b092 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include sptag *.py _SPTAG* Server.exe server \ No newline at end of file +recursive-include sptag *.py _SPTAG* *.so *.dll Server.exe server \ No newline at end of file diff --git a/README.md b/README.md index ccae4bdd6..d5cb08ac4 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,10 @@ architecture

- +## What's NEW +* Result Iterator with Relaxed Monotonicity Signal Support +* New Research Paper [SPFresh: Incremental In-Place Update for Billion-Scale Vector Search](https://dl.acm.org/doi/10.1145/3600006.3613166) - _published in SOSP 2023_ +* New Research Paper [VBASE: Unifying Online Vector Similarity Search and Relational Queries via Relaxed Monotonicity](https://www.usenix.org/system/files/osdi23-zhang-qianxi_1.pdf) - _published in OSDI 2023_ ## **Introduction** @@ -40,7 +43,7 @@ The searches in the trees and the graph are iteratively conducted. ### **Requirements** -* swig >= 3.0 +* swig >= 4.0.2 * cmake >= 3.12.0 * boost >= 1.67.0 @@ -48,7 +51,7 @@ The searches in the trees and the graph are iteratively conducted. ``` set GIT_LFS_SKIP_SMUDGE=1 -git clone https://github.com/microsoft/SPTAG +git clone --recurse-submodules https://github.com/microsoft/SPTAG OR @@ -59,16 +62,41 @@ git config --global filter.lfs.process "git-lfs filter-process --skip" ### **Install** > For Linux: +> Compile SPDK +```bash +cd ThirdParty/spdk +./scripts/pkgdep.sh +CC=gcc-9 ./configure +CC=gcc-9 make -j +``` + +> Compile isal-l_crypto +```bash +cd ThirdParty/isal-l_crypto +./autogen.sh +./configure +make -j +``` + +> Build RocksDB +```bash +mkdir build && cd build +cmake -DUSE_RTTI=1 -DWITH_JEMALLOC=1 -DWITH_SNAPPY=1 -DCMAKE_C_COMPILER=gcc-7 -DCMAKE_CXX_COMPILER=g++-7 -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-fPIC" .. +make -j +sudo make install +``` + +> Build SPTAG ```bash mkdir build -cd build && cmake .. && make +cd build && cmake -DSPDK=OFF -DROCKSDB=OFF .. && make ``` It will generate a Release folder in the code directory which contains all the build targets. > For Windows: ```bash mkdir build -cd build && cmake -A x64 .. +cd build && cmake -A x64 -DSPDK=OFF -DROCKSDB=OFF .. ``` It will generate a SPTAGLib.sln in the build directory. 
Compiling the ALL_BUILD project in the Visual Studio (at least 2019) will generate a Release directory which contains all the build targets. @@ -93,6 +121,21 @@ The detailed parameters tunning can be found in [Parameters](docs/Parameters.md) ## **References** Please cite SPTAG in your publications if it helps your research: ``` +@inproceedings{xu2023spfresh, + title={SPFresh: Incremental In-Place Update for Billion-Scale Vector Search}, + author={Xu, Yuming and Liang, Hengyu and Li, Jin and Xu, Shuotao and Chen, Qi and Zhang, Qianxi and Li, Cheng and Yang, Ziyue and Yang, Fan and Yang, Yuqing and others}, + booktitle={Proceedings of the 29th Symposium on Operating Systems Principles}, + pages={545--561}, + year={2023} +} + +@inproceedings{zhang2023vbase, + title={$\{$VBASE$\}$: Unifying Online Vector Similarity Search and Relational Queries via Relaxed Monotonicity}, + author={Zhang, Qianxi and Xu, Shuotao and Chen, Qi and Sui, Guoxin and Xie, Jiadong and Cai, Zhizhen and Chen, Yaoqi and He, Yinxuan and Yang, Yuqing and Yang, Fan and others}, + booktitle={17th USENIX Symposium on Operating Systems Design and Implementation (OSDI 23)}, + year={2023} +} + @inproceedings{ChenW21, author = {Qi Chen and Bing Zhao and diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..e138ec5d6 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 
+ +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. 
Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). + + diff --git a/SPTAG.WinRT.nuspec b/SPTAG.WinRT.nuspec new file mode 100644 index 000000000..0cc2b3a65 --- /dev/null +++ b/SPTAG.WinRT.nuspec @@ -0,0 +1,31 @@ + + + + SPTAG.WinRT + 1.2.13-mainOPQ-CoreLib-SPANN-withsource-a6b7604 + SPTAG.WinRT + cheqi,haidwa,mingqli + cheqi,haidwa,mingqli,zhah + false + https://github.com/microsoft/SPTAG + https://github.com/microsoft/SPTAG + SPTAG (Space Partition Tree And Graph) is a library for large scale vector approximate nearest neighbor search scenario released by Microsoft Research (MSR) and Microsoft Bing. + publish with commit microsoft/add python version in wheel package/a6b7604 + Copyright © Microsoft + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/SPTAG.nuspec b/SPTAG.nuspec index 9f4284bf9..39001e150 100644 --- a/SPTAG.nuspec +++ b/SPTAG.nuspec @@ -2,7 +2,7 @@ SPTAG.Library - 1.2.12-mainOPQ-CoreLib-SPANN-c6a30481-vc140 + 1.2.13-mainOPQ-CoreLib-SPANN-withsource-a6b7604 SPTAG.Library cheqi,haidwa,mingqli cheqi,haidwa,mingqli,zhah @@ -10,7 +10,7 @@ https://github.com/microsoft/SPTAG https://github.com/microsoft/SPTAG SPTAG (Space Partition Tree And Graph) is a library for large scale vector approximate nearest neighbor search scenario released by Microsoft Research (MSR) and Microsoft Bing. 
- publish with commit microsoft/legacy-avx512/c6a30481 + publish with commit microsoft/add python version in wheel package/a6b7604 Copyright © Microsoft @@ -78,9 +78,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + + + + \ No newline at end of file diff --git a/SPTAG.sln b/SPTAG.sln index 0fdfe2212..49eaf3359 100644 --- a/SPTAG.sln +++ b/SPTAG.sln @@ -1,267 +1,286 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 16 -VisualStudioVersion = 16.0.31624.102 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CoreLibrary", "AnnService\CoreLibrary.vcxproj", "{C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}" -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Server", "AnnService\Server.vcxproj", "{E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}" - ProjectSection(ProjectDependencies) = postProject - {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PythonCore", "Wrappers\PythonCore.vcxproj", "{AF31947C-0495-42FE-A1AD-8F0DA2A679C7}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SocketLib", "AnnService\SocketLib.vcxproj", "{F9A72303-6381-4C80-86FF-606A2F6F7B96}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Client", "AnnService\Client.vcxproj", "{A89D70C3-C53B-42DE-A5CE-9A472540F5CB}" - ProjectSection(ProjectDependencies) = postProject - {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} - 
{C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Aggregator", "AnnService\Aggregator.vcxproj", "{D7F09A63-BDCA-4F6C-A864-8551D1FE447A}" - ProjectSection(ProjectDependencies) = postProject - {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PythonClient", "Wrappers\PythonClient.vcxproj", "{9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}" - ProjectSection(ProjectDependencies) = postProject - {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IndexBuilder", "AnnService\IndexBuilder.vcxproj", "{F492F794-E78B-4B1F-A556-5E045B9163D5}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IndexSearcher", "AnnService\IndexSearcher.vcxproj", "{97615D3B-9FA0-469E-B229-95A91A5087E0}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Test", "Test\Test.vcxproj", "{29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}" - ProjectSection(ProjectDependencies) = postProject - {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} - {217B42B7-8F2B-4323-804C-08992CA2F65E} = {217B42B7-8F2B-4323-804C-08992CA2F65E} - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject 
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "JavaCore", "Wrappers\JavaCore.vcxproj", "{93FEB26B-965E-4157-8BE5-052F5CA112BB}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "JavaClient", "Wrappers\JavaClient.vcxproj", "{8866BF98-AA2E-450F-9F33-083E007CCA74}" - ProjectSection(ProjectDependencies) = postProject - {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CsharpCore", "Wrappers\CsharpCore.vcxproj", "{1896C009-AD46-4A70-B83C-4652A7F37503}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CsharpClient", "Wrappers\CsharpClient.vcxproj", "{363BA3BB-75C4-4CC7-AECB-28C7534B3710}" - ProjectSection(ProjectDependencies) = postProject - {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CLRCore", "Wrappers\CLRCore.vcxproj", "{38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GPUCoreLibrary", "AnnService\GPUCoreLibrary.vcxproj", "{84A84DBE-D46B-4EC9-8948-FF150DE5D386}" -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GPUIndexBuilder", "AnnService\GPUIndexBuilder.vcxproj", "{85106062-3D9F-4B52-AFF3-97D0F9575888}" - 
ProjectSection(ProjectDependencies) = postProject - {84A84DBE-D46B-4EC9-8948-FF150DE5D386} = {84A84DBE-D46B-4EC9-8948-FF150DE5D386} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SSDServing", "AnnService\SSDServing.vcxproj", "{217B42B7-8F2B-4323-804C-08992CA2F65E}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "BalancedDataPartition", "AnnService\BalancedDataPartition.vcxproj", "{EB9D699A-16CB-4D51-AB39-0B170BE72FC8}" -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GPUSSDServing", "AnnService\GPUSSDServing.vcxproj", "{0CA39557-067E-4EBD-8264-9D09B00275A2}" -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Quantizer", "AnnService\Quantizer.vcxproj", "{94D8008A-3AD8-4CA6-BA02-FE1458804E1D}" - ProjectSection(ProjectDependencies) = postProject - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} - EndProjectSection -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Debug|x86 = Debug|x86 - Release|x64 = Release|x64 - Release|x86 = Release|x86 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x64.ActiveCfg = Debug|x64 - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x64.Build.0 = Debug|x64 - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x86.ActiveCfg = Debug|x64 - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x86.Build.0 = Debug|x64 - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x64.ActiveCfg = Release|x64 - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x64.Build.0 = Release|x64 - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x86.ActiveCfg = Debug|x64 - {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x86.Build.0 = Debug|x64 - 
{E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x64.ActiveCfg = Debug|x64 - {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x64.Build.0 = Debug|x64 - {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x86.ActiveCfg = Debug|x64 - {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x86.Build.0 = Debug|x64 - {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x64.ActiveCfg = Release|x64 - {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x64.Build.0 = Release|x64 - {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x86.ActiveCfg = Debug|x64 - {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x86.Build.0 = Debug|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x64.ActiveCfg = Debug|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x64.Build.0 = Debug|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x86.ActiveCfg = Debug|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x86.Build.0 = Debug|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x64.ActiveCfg = Release|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x64.Build.0 = Release|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x86.ActiveCfg = Debug|x64 - {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x86.Build.0 = Debug|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x64.ActiveCfg = Debug|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x64.Build.0 = Debug|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x86.ActiveCfg = Debug|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x86.Build.0 = Debug|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x64.ActiveCfg = Release|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x64.Build.0 = Release|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x86.ActiveCfg = Debug|x64 - {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x86.Build.0 = Debug|x64 - {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x64.ActiveCfg = Debug|x64 - {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x64.Build.0 = Debug|x64 - {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x86.ActiveCfg = Debug|x64 - 
{A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x86.Build.0 = Debug|x64 - {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x64.ActiveCfg = Release|x64 - {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x64.Build.0 = Release|x64 - {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x86.ActiveCfg = Debug|x64 - {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x86.Build.0 = Debug|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x64.ActiveCfg = Debug|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x64.Build.0 = Debug|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x86.ActiveCfg = Debug|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x86.Build.0 = Debug|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x64.ActiveCfg = Release|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x64.Build.0 = Release|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x86.ActiveCfg = Debug|x64 - {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x86.Build.0 = Debug|x64 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x64.ActiveCfg = Debug|x64 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x64.Build.0 = Debug|x64 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x86.ActiveCfg = Debug|Win32 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x86.Build.0 = Debug|Win32 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x64.ActiveCfg = Release|x64 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x64.Build.0 = Release|x64 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x86.ActiveCfg = Release|Win32 - {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x86.Build.0 = Release|Win32 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x64.ActiveCfg = Debug|x64 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x64.Build.0 = Debug|x64 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x86.ActiveCfg = Debug|Win32 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x86.Build.0 = Debug|Win32 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x64.ActiveCfg = Release|x64 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x64.Build.0 = 
Release|x64 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x86.ActiveCfg = Release|Win32 - {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x86.Build.0 = Release|Win32 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.ActiveCfg = Debug|x64 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.Build.0 = Debug|x64 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.ActiveCfg = Debug|Win32 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.Build.0 = Debug|Win32 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.ActiveCfg = Release|x64 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.Build.0 = Release|x64 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.ActiveCfg = Release|Win32 - {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.Build.0 = Release|Win32 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x64.ActiveCfg = Debug|x64 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x64.Build.0 = Debug|x64 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x86.ActiveCfg = Debug|Win32 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x86.Build.0 = Debug|Win32 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x64.ActiveCfg = Release|x64 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x64.Build.0 = Release|x64 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.ActiveCfg = Release|Win32 - {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.Build.0 = Release|Win32 - {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Debug|x64.ActiveCfg = Debug|x64 - {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Debug|x86.ActiveCfg = Debug|Win32 - {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Release|x64.ActiveCfg = Release|x64 - {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Release|x86.ActiveCfg = Release|Win32 - {8866BF98-AA2E-450F-9F33-083E007CCA74}.Debug|x64.ActiveCfg = Debug|x64 - {8866BF98-AA2E-450F-9F33-083E007CCA74}.Debug|x86.ActiveCfg = Debug|Win32 - {8866BF98-AA2E-450F-9F33-083E007CCA74}.Release|x64.ActiveCfg = Release|x64 - {8866BF98-AA2E-450F-9F33-083E007CCA74}.Release|x86.ActiveCfg = Release|Win32 - 
{1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x64.ActiveCfg = Debug|x64 - {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x64.Build.0 = Debug|x64 - {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x86.ActiveCfg = Debug|Win32 - {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x86.Build.0 = Debug|Win32 - {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x64.ActiveCfg = Release|x64 - {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x64.Build.0 = Release|x64 - {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x86.ActiveCfg = Release|Win32 - {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x86.Build.0 = Release|Win32 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x64.ActiveCfg = Debug|x64 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x64.Build.0 = Debug|x64 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x86.ActiveCfg = Debug|Win32 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x86.Build.0 = Debug|Win32 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x64.ActiveCfg = Release|x64 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x64.Build.0 = Release|x64 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x86.ActiveCfg = Release|Win32 - {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x86.Build.0 = Release|Win32 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x64.ActiveCfg = Debug|x64 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x64.Build.0 = Debug|x64 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x86.ActiveCfg = Debug|Win32 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x86.Build.0 = Debug|Win32 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x64.ActiveCfg = Release|x64 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x64.Build.0 = Release|x64 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x86.ActiveCfg = Release|Win32 - {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x86.Build.0 = Release|Win32 - {84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Debug|x64.ActiveCfg = Debug|x64 - {84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Debug|x86.ActiveCfg = Debug|x64 - 
{84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Release|x64.ActiveCfg = Release|x64 - {84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Release|x86.ActiveCfg = Release|x64 - {85106062-3D9F-4B52-AFF3-97D0F9575888}.Debug|x64.ActiveCfg = Debug|x64 - {85106062-3D9F-4B52-AFF3-97D0F9575888}.Debug|x86.ActiveCfg = Debug|x64 - {85106062-3D9F-4B52-AFF3-97D0F9575888}.Release|x64.ActiveCfg = Release|x64 - {85106062-3D9F-4B52-AFF3-97D0F9575888}.Release|x86.ActiveCfg = Release|x64 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x64.ActiveCfg = Debug|x64 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x64.Build.0 = Debug|x64 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x86.ActiveCfg = Debug|Win32 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x86.Build.0 = Debug|Win32 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x64.ActiveCfg = Release|x64 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x64.Build.0 = Release|x64 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x86.ActiveCfg = Release|Win32 - {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x86.Build.0 = Release|Win32 - {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Debug|x64.ActiveCfg = Debug|x64 - {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Debug|x86.ActiveCfg = Debug|Win32 - {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Debug|x86.Build.0 = Debug|Win32 - {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Release|x64.ActiveCfg = Release|x64 - {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Release|x86.ActiveCfg = Release|Win32 - {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Release|x86.Build.0 = Release|Win32 - {0CA39557-067E-4EBD-8264-9D09B00275A2}.Debug|x64.ActiveCfg = Debug|x64 - {0CA39557-067E-4EBD-8264-9D09B00275A2}.Debug|x86.ActiveCfg = Debug|x64 - {0CA39557-067E-4EBD-8264-9D09B00275A2}.Release|x64.ActiveCfg = Release|x64 - {0CA39557-067E-4EBD-8264-9D09B00275A2}.Release|x86.ActiveCfg = Release|x64 - {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x64.ActiveCfg = Debug|x64 - {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x64.Build.0 = Debug|x64 - 
{94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x86.ActiveCfg = Debug|Win32 - {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x86.Build.0 = Debug|Win32 - {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x64.ActiveCfg = Release|x64 - {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x64.Build.0 = Release|x64 - {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x86.ActiveCfg = Release|Win32 - {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x86.Build.0 = Release|Win32 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection - GlobalSection(ExtensibilityGlobals) = postSolution - SolutionGuid = {38BDFF12-6FEC-4B67-A7BD-436D9E2544FD} - EndGlobalSection -EndGlobal +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.4.33403.182 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CoreLibrary", "AnnService\CoreLibrary.vcxproj", "{C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Server", "AnnService\Server.vcxproj", "{E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PythonCore", "Wrappers\PythonCore.vcxproj", "{AF31947C-0495-42FE-A1AD-8F0DA2A679C7}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SocketLib", "AnnService\SocketLib.vcxproj", "{F9A72303-6381-4C80-86FF-606A2F6F7B96}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + 
EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Client", "AnnService\Client.vcxproj", "{A89D70C3-C53B-42DE-A5CE-9A472540F5CB}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Aggregator", "AnnService\Aggregator.vcxproj", "{D7F09A63-BDCA-4F6C-A864-8551D1FE447A}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PythonClient", "Wrappers\PythonClient.vcxproj", "{9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IndexBuilder", "AnnService\IndexBuilder.vcxproj", "{F492F794-E78B-4B1F-A556-5E045B9163D5}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IndexSearcher", "AnnService\IndexSearcher.vcxproj", "{97615D3B-9FA0-469E-B229-95A91A5087E0}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Test", "Test\Test.vcxproj", "{29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}" + ProjectSection(ProjectDependencies) = postProject + 
{217B42B7-8F2B-4323-804C-08992CA2F65E} = {217B42B7-8F2B-4323-804C-08992CA2F65E} + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "JavaCore", "Wrappers\JavaCore.vcxproj", "{93FEB26B-965E-4157-8BE5-052F5CA112BB}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "JavaClient", "Wrappers\JavaClient.vcxproj", "{8866BF98-AA2E-450F-9F33-083E007CCA74}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CsharpCore", "Wrappers\CsharpCore.vcxproj", "{1896C009-AD46-4A70-B83C-4652A7F37503}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CsharpClient", "Wrappers\CsharpClient.vcxproj", "{363BA3BB-75C4-4CC7-AECB-28C7534B3710}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + {F9A72303-6381-4C80-86FF-606A2F6F7B96} = {F9A72303-6381-4C80-86FF-606A2F6F7B96} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CLRCore", "Wrappers\CLRCore.vcxproj", "{38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject 
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GPUCoreLibrary", "AnnService\GPUCoreLibrary.vcxproj", "{84A84DBE-D46B-4EC9-8948-FF150DE5D386}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GPUIndexBuilder", "AnnService\GPUIndexBuilder.vcxproj", "{85106062-3D9F-4B52-AFF3-97D0F9575888}" + ProjectSection(ProjectDependencies) = postProject + {84A84DBE-D46B-4EC9-8948-FF150DE5D386} = {84A84DBE-D46B-4EC9-8948-FF150DE5D386} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SSDServing", "AnnService\SSDServing.vcxproj", "{217B42B7-8F2B-4323-804C-08992CA2F65E}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "BalancedDataPartition", "AnnService\BalancedDataPartition.vcxproj", "{EB9D699A-16CB-4D51-AB39-0B170BE72FC8}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "GPUSSDServing", "AnnService\GPUSSDServing.vcxproj", "{0CA39557-067E-4EBD-8264-9D09B00275A2}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Quantizer", "AnnService\Quantizer.vcxproj", "{94D8008A-3AD8-4CA6-BA02-FE1458804E1D}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SPTAG", "Wrappers\WinRT\SPTAG.WinRT.vcxproj", "{8DC74C33-6E15-43ED-9300-2A140589E3DA}" + ProjectSection(ProjectDependencies) = postProject + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} = {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9} + EndProjectSection +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "WinRTTest", "Test\WinRTTest\WinRTTest.vcxproj", "{C9DF3099-A142-4AA7-B936-1541816A1F21}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|x64 = Debug|x64 + Debug|x86 = 
Debug|x86 + Release|x64 = Release|x64 + Release|x86 = Release|x86 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x64.ActiveCfg = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x64.Build.0 = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x86.ActiveCfg = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Debug|x86.Build.0 = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x64.ActiveCfg = Release|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x64.Build.0 = Release|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x86.ActiveCfg = Debug|x64 + {C2BC5FDE-C853-4F3D-B7E4-2C9B5524DDF9}.Release|x86.Build.0 = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x64.ActiveCfg = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x64.Build.0 = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x86.ActiveCfg = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Debug|x86.Build.0 = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x64.ActiveCfg = Release|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x64.Build.0 = Release|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x86.ActiveCfg = Debug|x64 + {E28B1222-8BEA-4A92-8FE0-088EBDAA7FE0}.Release|x86.Build.0 = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x64.ActiveCfg = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x64.Build.0 = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x86.ActiveCfg = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Debug|x86.Build.0 = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x64.ActiveCfg = Release|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x64.Build.0 = Release|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x86.ActiveCfg = Debug|x64 + {AF31947C-0495-42FE-A1AD-8F0DA2A679C7}.Release|x86.Build.0 = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x64.ActiveCfg = Debug|x64 + 
{F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x64.Build.0 = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x86.ActiveCfg = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Debug|x86.Build.0 = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x64.ActiveCfg = Release|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x64.Build.0 = Release|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x86.ActiveCfg = Debug|x64 + {F9A72303-6381-4C80-86FF-606A2F6F7B96}.Release|x86.Build.0 = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x64.ActiveCfg = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x64.Build.0 = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x86.ActiveCfg = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Debug|x86.Build.0 = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x64.ActiveCfg = Release|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x64.Build.0 = Release|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x86.ActiveCfg = Debug|x64 + {A89D70C3-C53B-42DE-A5CE-9A472540F5CB}.Release|x86.Build.0 = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x64.ActiveCfg = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x64.Build.0 = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x86.ActiveCfg = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Debug|x86.Build.0 = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x64.ActiveCfg = Release|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x64.Build.0 = Release|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x86.ActiveCfg = Debug|x64 + {D7F09A63-BDCA-4F6C-A864-8551D1FE447A}.Release|x86.Build.0 = Debug|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x64.ActiveCfg = Debug|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x64.Build.0 = Debug|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x86.ActiveCfg = Debug|Win32 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Debug|x86.Build.0 = Debug|Win32 + 
{9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x64.ActiveCfg = Release|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x64.Build.0 = Release|x64 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x86.ActiveCfg = Release|Win32 + {9B014CF6-E3FB-4BD4-B3B1-D26297BB31AA}.Release|x86.Build.0 = Release|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x64.ActiveCfg = Debug|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x64.Build.0 = Debug|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x86.ActiveCfg = Debug|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Debug|x86.Build.0 = Debug|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x64.ActiveCfg = Release|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x64.Build.0 = Release|x64 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x86.ActiveCfg = Release|Win32 + {F492F794-E78B-4B1F-A556-5E045B9163D5}.Release|x86.Build.0 = Release|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.ActiveCfg = Debug|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.Build.0 = Debug|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.ActiveCfg = Debug|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.Build.0 = Debug|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.ActiveCfg = Release|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.Build.0 = Release|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.ActiveCfg = Release|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.Build.0 = Release|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x64.ActiveCfg = Debug|x64 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x64.Build.0 = Debug|x64 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x86.ActiveCfg = Debug|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Debug|x86.Build.0 = Debug|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x64.ActiveCfg = Release|x64 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x64.Build.0 = Release|x64 + 
{29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.ActiveCfg = Release|Win32 + {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.Build.0 = Release|Win32 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Debug|x64.ActiveCfg = Debug|x64 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Debug|x86.ActiveCfg = Debug|Win32 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Release|x64.ActiveCfg = Release|x64 + {93FEB26B-965E-4157-8BE5-052F5CA112BB}.Release|x86.ActiveCfg = Release|Win32 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Debug|x64.ActiveCfg = Debug|x64 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Debug|x86.ActiveCfg = Debug|Win32 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Release|x64.ActiveCfg = Release|x64 + {8866BF98-AA2E-450F-9F33-083E007CCA74}.Release|x86.ActiveCfg = Release|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x64.ActiveCfg = Debug|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x64.Build.0 = Debug|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x86.ActiveCfg = Debug|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Debug|x86.Build.0 = Debug|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x64.ActiveCfg = Release|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x64.Build.0 = Release|x64 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x86.ActiveCfg = Release|Win32 + {1896C009-AD46-4A70-B83C-4652A7F37503}.Release|x86.Build.0 = Release|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x64.ActiveCfg = Debug|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x64.Build.0 = Debug|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x86.ActiveCfg = Debug|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Debug|x86.Build.0 = Debug|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x64.ActiveCfg = Release|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x64.Build.0 = Release|x64 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x86.ActiveCfg = Release|Win32 + {363BA3BB-75C4-4CC7-AECB-28C7534B3710}.Release|x86.Build.0 = Release|Win32 + 
{38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x64.ActiveCfg = Debug|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x64.Build.0 = Debug|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x86.ActiveCfg = Debug|Win32 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Debug|x86.Build.0 = Debug|Win32 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x64.ActiveCfg = Release|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x64.Build.0 = Release|x64 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x86.ActiveCfg = Release|Win32 + {38ACBA6C-2E50-44D4-9A6D-DC735B56E38F}.Release|x86.Build.0 = Release|Win32 + {84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Debug|x64.ActiveCfg = Debug|x64 + {84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Debug|x86.ActiveCfg = Debug|x64 + {84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Release|x64.ActiveCfg = Release|x64 + {84A84DBE-D46B-4EC9-8948-FF150DE5D386}.Release|x86.ActiveCfg = Release|x64 + {85106062-3D9F-4B52-AFF3-97D0F9575888}.Debug|x64.ActiveCfg = Debug|x64 + {85106062-3D9F-4B52-AFF3-97D0F9575888}.Debug|x86.ActiveCfg = Debug|x64 + {85106062-3D9F-4B52-AFF3-97D0F9575888}.Release|x64.ActiveCfg = Release|x64 + {85106062-3D9F-4B52-AFF3-97D0F9575888}.Release|x86.ActiveCfg = Release|x64 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x64.ActiveCfg = Debug|x64 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x64.Build.0 = Debug|x64 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x86.ActiveCfg = Debug|Win32 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Debug|x86.Build.0 = Debug|Win32 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x64.ActiveCfg = Release|x64 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x64.Build.0 = Release|x64 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x86.ActiveCfg = Release|Win32 + {217B42B7-8F2B-4323-804C-08992CA2F65E}.Release|x86.Build.0 = Release|Win32 + {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Debug|x64.ActiveCfg = Debug|x64 + {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Debug|x86.ActiveCfg = Debug|Win32 + 
{EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Debug|x86.Build.0 = Debug|Win32 + {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Release|x64.ActiveCfg = Release|x64 + {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Release|x86.ActiveCfg = Release|Win32 + {EB9D699A-16CB-4D51-AB39-0B170BE72FC8}.Release|x86.Build.0 = Release|Win32 + {0CA39557-067E-4EBD-8264-9D09B00275A2}.Debug|x64.ActiveCfg = Debug|x64 + {0CA39557-067E-4EBD-8264-9D09B00275A2}.Debug|x86.ActiveCfg = Debug|x64 + {0CA39557-067E-4EBD-8264-9D09B00275A2}.Release|x64.ActiveCfg = Release|x64 + {0CA39557-067E-4EBD-8264-9D09B00275A2}.Release|x86.ActiveCfg = Release|x64 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x64.ActiveCfg = Debug|x64 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x64.Build.0 = Debug|x64 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x86.ActiveCfg = Debug|Win32 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Debug|x86.Build.0 = Debug|Win32 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x64.ActiveCfg = Release|x64 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x64.Build.0 = Release|x64 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x86.ActiveCfg = Release|Win32 + {94D8008A-3AD8-4CA6-BA02-FE1458804E1D}.Release|x86.Build.0 = Release|Win32 + {8DC74C33-6E15-43ED-9300-2A140589E3DA}.Debug|x64.ActiveCfg = Debug|x64 + {8DC74C33-6E15-43ED-9300-2A140589E3DA}.Debug|x86.ActiveCfg = Debug|Win32 + {8DC74C33-6E15-43ED-9300-2A140589E3DA}.Debug|x86.Build.0 = Debug|Win32 + {8DC74C33-6E15-43ED-9300-2A140589E3DA}.Release|x64.ActiveCfg = Release|x64 + {8DC74C33-6E15-43ED-9300-2A140589E3DA}.Release|x86.ActiveCfg = Release|Win32 + {8DC74C33-6E15-43ED-9300-2A140589E3DA}.Release|x86.Build.0 = Release|Win32 + {C9DF3099-A142-4AA7-B936-1541816A1F21}.Debug|x64.ActiveCfg = Debug|x64 + {C9DF3099-A142-4AA7-B936-1541816A1F21}.Debug|x86.ActiveCfg = Debug|Win32 + {C9DF3099-A142-4AA7-B936-1541816A1F21}.Debug|x86.Build.0 = Debug|Win32 + {C9DF3099-A142-4AA7-B936-1541816A1F21}.Release|x64.ActiveCfg = Release|x64 + 
{C9DF3099-A142-4AA7-B936-1541816A1F21}.Release|x86.ActiveCfg = Release|Win32 + {C9DF3099-A142-4AA7-B936-1541816A1F21}.Release|x86.Build.0 = Release|Win32 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {38BDFF12-6FEC-4B67-A7BD-436D9E2544FD} + EndGlobalSection +EndGlobal diff --git a/SPTAG.targets b/SPTAG.targets new file mode 100644 index 000000000..64f99aec6 --- /dev/null +++ b/SPTAG.targets @@ -0,0 +1,15 @@ + + + + + x64 + <_nugetNativeFolder>$(MSBuildThisFileDirectory)..\runtimes\win-$(Native-Platform)\native\ + + + + + %(FileName)%(Extension) + PreserveNewest + + + \ No newline at end of file diff --git a/Script_AE/Figure1/motivation.p b/Script_AE/Figure1/motivation.p new file mode 100644 index 000000000..bbe380911 --- /dev/null +++ b/Script_AE/Figure1/motivation.p @@ -0,0 +1,79 @@ +# For a single column, set the width at 3.3 inches +# For across two columns, set the width at 7 inches + +set terminal pdfcairo size 3.3, 1.75 font 'Linux Biolinum O,12' +# set terminal pdfcairo size 7, 1.75 font "UbuntuMono-Regular, 11" + +# set default line style +set style line 1 lc rgb '#056bfa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 2 lc rgb '#05a8fa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 3 lc rgb '#fb8500' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 4 lc rgb '#ffb703' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 5 lc rgb '#b30018' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 6 lc rgb '#fa3605' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 7 lc rgb '#80d653' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 8 lc rgb '#9400d3' lt 1 lw 1.7 pt 7 ps 1.5 +# set style line 1 lc rgb '#00d5ff' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 2 lc rgb '#000080' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 3 lc rgb '#ff7f0e' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 4 lc rgb '#008176' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 5 lc rgb '#b3b3b3' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 6 
lc rgb '#000000' lt 1 lw 1.5 pt 7 ps 1.5 + +# set grid style +set style line 20 lc rgb '#dddddd' lt 1 lw 1 + +set datafile separator "," +set encoding utf8 +set autoscale +set grid ls 20 +# set key box ls 20 opaque fc rgb "#3fffffff" width 2 +set tics scale 0.5 +set xtics nomirror out autofreq offset 0,0.5,0 +set ytics nomirror out autofreq offset 0.5,0,0 +set border 3 lw 2 +# set xrange [0:*] +set yrange [0:*] + +# Start the first plot +set output "Motivation.pdf" + +set multiplot + +set size 0.55, 0.88 +set origin -0.03, 0.12 + +set xlabel "Latency (ms)" offset 0,1,0 +set ylabel 'Recall 10\@10' offset 2,0,0 +set xrange [1:5] +set xtics ("1" 1, "" 1.5, "2" 2, "" 2.5, "3" 3, "" 3.5, "4" 4, "" 4.5, "5" 5) +# set key bottom right reverse Left +unset key + +# set title "Search Latency" offset 0, -0.7 +unset title +set yrange [0.92:1] +set ytics 0.02 +plot "motivation1.csv" using 3:4 every ::3 with linespoints title 'In-place update' ls 4 pointsize 0.5, \ + "motivation1.csv" using 1:2 every ::3 with linespoints title 'Static' ls 7 pointsize 0.5 +unset key + +set xlabel "Latency (ms)" offset 0,1,0 +set ylabel "CDF" offset 2,0,0 +set key bottom right reverse Left invert +set xrange [0:30] +set xtics 5 +set yrange [0:*] +set ytics 0.2 + +set key +set size 0.56, 0.88 +set origin 0.46, 0.12 +# set title "Latency Distrubution After Data Shifting" offset 0, -0.7 +unset title +plot "motivation.csv" using 1:2 every ::1 with linespoints title 'Static' at 0.25, 0.08 ls 7 pointsize 0.5, \ + "motivation.csv" using 3:4 every ::1 with linespoints title 'In-place update' at 0.65, 0.08 ls 4 pointsize 0.5 + +unset key + +unset multiplot diff --git a/Script_AE/Figure1/motivation.sh b/Script_AE/Figure1/motivation.sh new file mode 100644 index 000000000..39b4e6972 --- /dev/null +++ b/Script_AE/Figure1/motivation.sh @@ -0,0 +1,6 @@ +# static +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E 
/home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_sift_cluster_2m |tee log_static.log + +# nolimit +cp /home/sosp/data/store_sift_cluster/indexloader_nolimit.ini /home/sosp/data/store_sift_cluster/indexloader.ini +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_sift_cluster |tee log_nolimit.log diff --git a/Script_AE/Figure1/plot_motivation_result.sh b/Script_AE/Figure1/plot_motivation_result.sh new file mode 100644 index 000000000..736e05f3d --- /dev/null +++ b/Script_AE/Figure1/plot_motivation_result.sh @@ -0,0 +1,2 @@ +python process_motivation.py +gnuplot motivation.p \ No newline at end of file diff --git a/Script_AE/Figure1/process_motivation.py b/Script_AE/Figure1/process_motivation.py new file mode 100644 index 000000000..b09ac5acc --- /dev/null +++ b/Script_AE/Figure1/process_motivation.py @@ -0,0 +1,97 @@ +import sys +import re +import csv + +process_list = ["log_static.log", "log_nolimit.log"] + +percentile = ['', 0.5, 0.9, 0.95, 0.99, 0.999, 1] + +percentile_baseline = [] + +avg_latency = [] +accuracy = [] + +last_data_group = [] + +for fileName in process_list: + + log_f = open(fileName) + + templist_latency = [] + + templist_accuracy = [] + + templist_accuracy.append('') + templist_latency.append('') + + line_count = 0 + + if fileName == "log_nolimit.log": + + while True: + line = log_f.readline() + line_count+=1 + + result_group = line.split() + + if len(result_group) > 2 and result_group[1] == "Total" and result_group[2] == "Vector": + break + + while True: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line_count+=1 + line = log_f.readline() + line_count+=1 + result_group = line.split() + 
templist_latency.append(float(result_group[1])) + last_data_group = result_group + if len(result_group) > 2 and result_group[1] == "Recall10@10:": + templist_accuracy.append(float(result_group[2])) + else: + while True: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 3 and result_group[1] == "Updating" and result_group[2] == "numThread:": + break + + if len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line_count+=1 + line = log_f.readline() + line_count+=1 + result_group = line.split() + templist_latency.append(float(result_group[1])) + last_data_group = result_group + if len(result_group) > 2 and result_group[1] == "Recall10@10:": + templist_accuracy.append(float(result_group[2])) + + accuracy.append(templist_accuracy) + avg_latency.append(templist_latency) + temp_baseline = [] + temp_baseline.append('') + for i in range(2, 8): + temp_baseline.append(float(last_data_group[i])) + percentile_baseline.append(temp_baseline) + +with open("motivation1.csv", 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(avg_latency[0], accuracy[0], avg_latency[1], accuracy[1])) + +with open("motivation.csv", 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(percentile_baseline[0], percentile, percentile_baseline[1], percentile)) diff --git a/Script_AE/Figure10/parameter_study_range.p b/Script_AE/Figure10/parameter_study_range.p new file mode 100644 index 000000000..ba70ee971 --- /dev/null +++ b/Script_AE/Figure10/parameter_study_range.p @@ -0,0 +1,53 @@ +# For a single column, set the width at 3.3 inches +# For across two columns, set the width at 7 inches + +set terminal pdfcairo size 3.3, 1.75 font 'Linux Biolinum O,12' +# set terminal pdfcairo size 7, 1.75 font "UbuntuMono-Regular, 11" + +# set default line style +set style line 1 lc rgb '#056bfa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 2 
lc rgb '#05a8fa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 3 lc rgb '#fb8500' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 4 lc rgb '#ffb703' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 5 lc rgb '#b30018' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 6 lc rgb '#fa3605' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 7 lc rgb '#80d653' lt 1 lw 1.7 pt 7 ps 1.5 +# set style line 1 lc rgb '#00d5ff' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 2 lc rgb '#000080' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 3 lc rgb '#ff7f0e' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 4 lc rgb '#008176' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 5 lc rgb '#b3b3b3' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 6 lc rgb '#000000' lt 1 lw 1.5 pt 7 ps 1.5 + +# set grid style +set style line 20 lc rgb '#dddddd' lt 1 lw 1 + +set datafile separator "," +set encoding utf8 +set autoscale +set grid ls 20 +# set key box ls 20 opaque fc rgb "#3fffffff" width -2 +set tics scale 0.5 +set xtics nomirror out autofreq offset 0,0.5,0 +set ytics nomirror out autofreq offset 0.5,0,0 +set border lw 2 + +# Start the second plot + +set output "ParameterStudyRange.pdf" + +set xlabel "Latency (ms)" offset 0,1,0 +set ylabel 'Recall 10\@10' offset 1.5,0,0 +set key bottom right reverse Left + +set size 1, 1 +set origin 0, 0 +# set title "Search Latency" offset 0, -0.7 +unset title +set yrange [0.93:1] +set ytics ("" 0.93, "0.94" 0.94, "" 0.95, "0.96" 0.96, "" 0.97, "0.98" 0.98, "" 0.99, "1" 1) +plot "parameter_study_range.csv" using 1:2 every ::3 with lines title 'Reassign top0' ls 1, \ + "parameter_study_range.csv" using 3:4 every ::3 with lines title 'Reassign top8' ls 2, \ + "parameter_study_range.csv" using 5:6 every ::3 with lines title 'Reassign top64' ls 3, \ + "parameter_study_range.csv" using 7:8 every ::3 with lines title 'Reassign top128' ls 4 +unset key \ No newline at end of file diff --git a/Script_AE/Figure10/parameter_study_range.sh b/Script_AE/Figure10/parameter_study_range.sh new file mode 100644 index 000000000..e21563229 --- 
/dev/null +++ b/Script_AE/Figure10/parameter_study_range.sh @@ -0,0 +1,13 @@ +loaderPath="/home/sosp/data/store_sift_cluster/indexloader.ini" +storePath="/home/sosp/data/store_sift_cluster" +ReassignLine="118c ReassignK=" +logPath="log_top" + +cp /home/sosp/data/store_sift_cluster/indexloader_top64.ini /home/sosp/data/store_sift_cluster/indexloader.ini + +for i in 0 8 64 128 +do + newReassignLine=$ReassignLine$i + sed -i "$newReassignLine" $loaderPath + PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh $storePath |tee $logPath$i +done \ No newline at end of file diff --git a/Script_AE/Figure10/plot_range_result.sh b/Script_AE/Figure10/plot_range_result.sh new file mode 100644 index 000000000..06029d53d --- /dev/null +++ b/Script_AE/Figure10/plot_range_result.sh @@ -0,0 +1,2 @@ +python process_para_range.py log_top +gnuplot parameter_study_range.p \ No newline at end of file diff --git a/Script_AE/Figure10/process_para_range.py b/Script_AE/Figure10/process_para_range.py new file mode 100644 index 000000000..1e33db9f1 --- /dev/null +++ b/Script_AE/Figure10/process_para_range.py @@ -0,0 +1,60 @@ +import sys +import re +import csv + + +topkList = [0, 8, 64, 128] + +avg_latency = [] +accuracy = [] + +for i in topkList: + templist_latency = [] + + templist_accuracy = [] + + templist_accuracy.append('') + templist_latency.append('') + + fileName = sys.argv[1] + str(i) + + log_f = open(fileName) + + line_count = 0 + + while True: + line = log_f.readline() + line_count+=1 + + result_group = line.split() + + if len(result_group) > 2 and result_group[1] == "Total" and result_group[2] == "Vector": + break + + while True: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line_count+=1 + 
line = log_f.readline() + line_count+=1 + result_group = line.split() + templist_latency.append(float(result_group[1])) + if len(result_group) > 2 and result_group[1] == "Recall10@10:": + templist_accuracy.append(float(result_group[2])) + + accuracy.append(templist_accuracy) + avg_latency.append(templist_latency) + + +with open("parameter_study_range.csv", 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(avg_latency[0], accuracy[0], avg_latency[1], accuracy[1], avg_latency[2], accuracy[2], avg_latency[3], accuracy[3])) + diff --git a/Script_AE/Figure11/foreground_background.p b/Script_AE/Figure11/foreground_background.p new file mode 100644 index 000000000..9a22c61cb --- /dev/null +++ b/Script_AE/Figure11/foreground_background.p @@ -0,0 +1,63 @@ +# For a single column, set the width at 3.3 inches +# For across two columns, set the width at 7 inches + +set terminal pdfcairo size 3.3, 1.75 font 'Linux Biolinum O,12' +# set terminal pdfcairo size 7, 2.07 font "UbuntuMono-Regular, 11" + +# set default line style +set style line 1 lc rgb '#056bfa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 2 lc rgb '#05a8fa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 3 lc rgb '#fb8500' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 4 lc rgb '#ffb703' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 5 lc rgb '#b30018' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 6 lc rgb '#fa3605' lt 1 lw 1.7 pt 7 ps 1.5 + +# set grid style +set style line 20 lc rgb '#dddddd' lt 1 lw 1 +set style fill solid + +set datafile separator "," +set encoding utf8 +set autoscale +set grid ls 20 noxtics ytics +# set key box ls 20 opaque fc rgb "#3fffffff" +set tics scale 0.5 +set xtics nomirror out autofreq offset 0,0.5,0 +set ytics nomirror out offset 0.5,0,0 +set border lw 2 +set yrange [0:6000] +set ytics 1500 + +set style data histogram +set style histogram cluster gap 1 +set style fill solid +set boxwidth 0.9 + +# Start the first plot +set output "Scalability.pdf" + +set multiplot + +set xlabel "Foreground 
Update\nThread Num" offset 0,1,0 +set ylabel "Throughput" offset 1,0,0 + +set key reverse Left + +set size 0.6, 0.88 +set origin -0.02, 0.12 + +# set title "Scalability (Background Thread = 1)" offset 0, -0.7 +plot "foreground_background.csv" using 3:xtic(2) every ::1 title 'Foreground' at 0.2, 0.07, \ + "foreground_background.csv" using 4 every ::1 title 'Background' at 0.75, 0.07 + +set size 0.5, 0.88 +set origin 0.52, 0.12 +set xlabel "Background Update\nThread Num" offset 0,1,0 +unset ylabel +set ytics format "" + +# set title "Scalability (Background Thread = 8)" offset 0, -0.7 +plot "foreground_background.csv" using 10:xtic(9) every ::1::3 title 'Foreground' at 0.2, 0.07, \ + "foreground_background.csv" using 11 every ::1::3 title 'Background' at 0.75, 0.07 + +unset multiplot \ No newline at end of file diff --git a/Script_AE/Figure11/parameter_study_balance.sh b/Script_AE/Figure11/parameter_study_balance.sh new file mode 100644 index 000000000..78f5cb895 --- /dev/null +++ b/Script_AE/Figure11/parameter_study_balance.sh @@ -0,0 +1,39 @@ +loaderPath="/home/sosp/data/store_sift1m/indexloader.ini" +storePath="/home/sosp/data/store_sift1m" +InsertLine="109c InsertThreadNum=" +AppendLine="110c AppendThreadNum=" +DeleteQPSLine="122c DeleteQPS=" +Insert=1 +Append=1 +DeleteQPS=1000 +logPath="log_" +newDeleteQPS=0 + + +for i in {1..4} +do + let 'newDeleteQPS=Insert*DeleteQPS' + newInsertLine=$InsertLine$Insert + newAppendLine=$AppendLine$Append + newDeleteQPSLine=$DeleteQPSLine$newDeleteQPS + sed -i "$newInsertLine" $loaderPath + sed -i "$newAppendLine" $loaderPath + sed -i "$newDeleteQPSLine" $loaderPath + PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh $storePath |tee $logPath$Insert$Append + let 'Insert=Insert*2' +done + +let 'Insert=8' + +for i in {1..2} +do + let 'Append=Append*2' + let 'newDeleteQPS=Insert*DeleteQPS' + 
newInsertLine=$InsertLine$Insert + newAppendLine=$AppendLine$Append + newDeleteQPSLine=$DeleteQPSLine$newDeleteQPS + sed -i "$newInsertLine" $loaderPath + sed -i "$newAppendLine" $loaderPath + sed -i "$newDeleteQPSLine" $loaderPath + PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh $storePath |tee $logPath$Insert$Append +done \ No newline at end of file diff --git a/Script_AE/Figure11/plot_balance_result.sh b/Script_AE/Figure11/plot_balance_result.sh new file mode 100644 index 000000000..90a4fd48a --- /dev/null +++ b/Script_AE/Figure11/plot_balance_result.sh @@ -0,0 +1,2 @@ +python process_balance.py log_ +gnuplot foreground_background.p \ No newline at end of file diff --git a/Script_AE/Figure11/process_balance.py b/Script_AE/Figure11/process_balance.py new file mode 100644 index 000000000..fbcad8597 --- /dev/null +++ b/Script_AE/Figure11/process_balance.py @@ -0,0 +1,60 @@ +import sys +import re +import csv + + +insert = [1, 2, 4, 8] +append = [1, 2, 4] + +threadlist = [11, 21, 41, 81, 82, 84] + +fore_throughput = [] +back_throughput = [] + +for i in threadlist: + fileName = sys.argv[1] + str(i) + + log_f = open(fileName) + + while True: + line = log_f.readline() + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 11 and result_group[1] == "Insert:" and result_group[7] == "sending": + fore_throughput.append(float(result_group[10])) + + if len(result_group) > 11 and result_group[1] == "Insert:" and result_group[7] == "actuall": + back_throughput.append(float(result_group[10].rstrip(','))) + +with open("foreground_background.csv", 'w') as f: + writer = csv.writer(f, delimiter=',') + templist = [] + for i in range(0, 11): + templist.append('') + writer.writerow(templist) + for i in range(0,4): + templist = [] + templist.append('') + templist.append(insert[i]) + if insert[i] != 8: + 
templist.append(fore_throughput[i]) + templist.append(back_throughput[i]) + for j in range(0, 4): + templist.append('') + templist.append(insert[i]) + templist.append(fore_throughput[i+3]) + templist.append(back_throughput[i+3]) + else: + for j in range(0,3): + templist.append(fore_throughput[j+3]) + templist.append(back_throughput[j+3]) + for j in range(0, 3): + templist.append('') + writer.writerow(templist) + + + diff --git a/Script_AE/Figure6/OverallPerformance_merge_result.py b/Script_AE/Figure6/OverallPerformance_merge_result.py new file mode 100644 index 000000000..395852663 --- /dev/null +++ b/Script_AE/Figure6/OverallPerformance_merge_result.py @@ -0,0 +1,106 @@ +import sys +import re +import csv +line_count = 0 + +# spfresh 0, spann 1, diskann 2 +accuracy_list=[] + +for j in range(1, 4): + templist = [] + templist.append('') + templist.append('Accuracy') + for i in range(-1, 100): + fileName = sys.argv[j] + str(i) + log_f = open(fileName) + while True: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + if len(result_group) > 1 and result_group[1] == "Recall10@10:": + templist.append(float(result_group[2])) + while len(templist) != 402: + templist.append('') + + accuracy_list.append(templist) + +# spfresh 3, spann 4, diskann 5 + +csv_reader_spfresh = csv.reader(open(sys.argv[4])) +csv_reader_spann = csv.reader(open(sys.argv[5])) +csv_reader_diskann = csv.reader(open(sys.argv[6])) + +step = 0.2 +batch = [] +batch.append('') +batch.append('Batch') +for i in range(0,100): + #batch.append(i) + batch.append(i+step) + batch.append(i+step*2) + batch.append(i+step*3) + batch.append(i+step*4) + + +with open('overall_performance_spacev.csv', 'w') as f: + writer = csv.writer(f, delimiter=',') + i = 0 + for row1, row2, row3 in zip(csv_reader_spfresh, csv_reader_diskann, csv_reader_spann): + line = [] + line.append(batch[i]) + line.append('') + line += row1 + line.append('') + line.append(accuracy_list[0][i]) + 
line.append('') + line += row2 + line.append('') + line.append(accuracy_list[2][i]) + line.append('') + line += row3 + line.append('') + line.append(accuracy_list[1][i]) + writer.writerow(line) + i+=1 + +batch = [] +batch.append('') +batch.append('Batch') +for i in range(0,101): + batch.append(i) + +csv_reader_spfresh = csv.reader(open(sys.argv[4])) +csv_reader_spann = csv.reader(open(sys.argv[5])) +csv_reader_diskann = csv.reader(open(sys.argv[6])) + +with open('overall_performance_spacev2.csv', 'w') as f: + writer = csv.writer(f, delimiter=',') + i = 0 + for row1, row2, row3 in zip(csv_reader_spfresh, csv_reader_diskann, csv_reader_spann): + if i == 101: + break + line = [] + line.append(batch[i]) + line.append(row1[4]) + line.append('') + line.append(accuracy_list[0][i]) + line.append(row2[4]) + line.append('') + line.append(accuracy_list[2][i]) + line.append(row3[4]) + line.append('') + line.append(accuracy_list[1][i]) + writer.writerow(line) + i+=1 + + + + + + + + diff --git a/Script_AE/Figure6/overall_performance_spacev_new.p b/Script_AE/Figure6/overall_performance_spacev_new.p new file mode 100644 index 000000000..12ba2e0dc --- /dev/null +++ b/Script_AE/Figure6/overall_performance_spacev_new.p @@ -0,0 +1,92 @@ +# For a single column, set the width at 3.3 inches +# For across two columns, set the width at 7 inches + +# set terminal pdfcairo size 3.3, 2.07 font "UbuntuMono-Regular, 10" +set terminal pdfcairo size 7, 1.75 font 'Linux Biolinum O,12' + +# set default line style +set style line 1 lc rgb '#056bfa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 2 lc rgb '#05a8fa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 3 lc rgb '#fb8500' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 4 lc rgb '#ffb703' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 5 lc rgb '#b30018' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 6 lc rgb '#fa3605' lt 1 lw 1.7 pt 7 ps 1.5 +# set style line 1 lc rgb '#00d5ff' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 2 lc rgb '#000080' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 
3 lc rgb '#ff7f0e' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 4 lc rgb '#008176' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 5 lc rgb '#b3b3b3' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 6 lc rgb '#000000' lt 1 lw 1.5 pt 7 ps 1.5 + +# set grid style +set style line 20 lc rgb '#dddddd' lt 1 lw 1 + +set datafile separator "," +set encoding utf8 +set autoscale +set grid ls 20 +# set key box ls 20 opaque fc rgb "#3fffffff" +set tics scale 0.5 +set xtics nomirror out autofreq offset 0,0.5,0 +set xtics ("0" 0, "20" 20, "40" 40, "60" 60, "80" 80, "100" 100) +set xrange [0:100] +set ytics nomirror out autofreq offset 0.5,0,0 +set border lw 2 +set yrange [0:*] + +# Start the first plot +set output "OverallPerformanceSPACEV1.pdf" +set multiplot + +set xlabel "Batches/Days" offset 0,1,0 +set ylabel "Latency (ms)" offset 1.1,0,0 +# set key reverse Left +unset key + +set size 0.26, 0.95 +set origin 0, 0.13 +set yrange [0:20] +set ytics ("0" 0, "" 2, "4" 4, "" 6, "8" 8, "" 10, "12" 12, "" 14, "16" 16, "" 18, "20" 20) +set title "Search Tail Latency" offset 0, -0.7 +plot "overall_performance_spacev.csv" using 1:4 every ::3 with lines ls 1, \ + "overall_performance_spacev.csv" using 1:12 every ::3 with lines ls 3, \ + "overall_performance_spacev.csv" using 1:20 every ::3 with lines ls 5 + +set size 0.26, 0.95 +set origin 0.25, 0.13 +set ylabel 'Throughput (QPS)' offset 1.1,0,0 # {/Symbol m} +set yrange [0:1000] +set ytics 200 +set title "Insert Throughput (per Thread)" offset 0, -0.7 +plot "overall_performance_spacev2.csv" using 1:(1000000./$2) every ::3 with lines ls 1, \ + "overall_performance_spacev2.csv" using 1:(1000000./$5) every ::3 with lines ls 3, \ + "overall_performance_spacev2.csv" using 1:(1000000./$8) every ::3 with lines ls 5 + +# "overall_performance_spacev2.csv" using 1:($3/1000.) every ::3 with lines ls 2, \ +# "overall_performance_spacev2.csv" using 1:($6/1000.) every ::3 with lines ls 4, \ +# "overall_performance_spacev2.csv" using 1:($9/1000.) 
every ::3 with lines ls 6 + +set size 0.26, 0.95 +set origin 0.49, 0.13 +set ylabel 'Recall 10\@10' offset 2,0,0 +set yrange [0.5:1] +set ytics 0.5,0.1,1 +set title "Accuracy" offset 0, -0.7 +plot "overall_performance_spacev2.csv" using 1:4 every ::3 with lines ls 1, \ + "overall_performance_spacev2.csv" using 1:7 every ::3 with lines ls 3, \ + "overall_performance_spacev2.csv" using 1:10 every ::3 with lines ls 5 + +set size 0.26, 0.95 +set origin 0.73, 0.13 +set ylabel "Memory (GB)" offset 2,0,0 +set yrange [0:*] +set ytics autofreq +set key reverse Left +set title "Memory Usage" offset 0, -0.7 +plot "overall_performance_spacev.csv" using 1:6 every ::3 with lines title 'SPFresh' at 0.35, 0.08 ls 1, \ + "overall_performance_spacev.csv" using 1:14 every ::3 with lines title 'DiskANN' at 0.5, 0.08 ls 3, \ + "overall_performance_spacev.csv" using 1:22 every ::3 with lines title 'SPANN+' at 0.65, 0.08 ls 5 + +unset key + +unset multiplot diff --git a/Script_AE/Figure6/overall_spacev_diskann.sh b/Script_AE/Figure6/overall_spacev_diskann.sh new file mode 100644 index 000000000..e0c3e431b --- /dev/null +++ b/Script_AE/Figure6/overall_spacev_diskann.sh @@ -0,0 +1,18 @@ +/home/sosp/DiskANN_Baseline/build/tests/overall_performance int8 ~/data/spacev_data/spacev100m_base.i8bin 75 32 1.2 75 64 1.2 100000000 1 25 0 ~/testbed/store_diskann_100m/diskann_spacev_100m true false ~/data/spacev_data/spacev200m_base.i8bin ~/data/spacev_data/query.i8bin ~/data/truth 10 40 2 ~/data/spacev_data/spacev100m_update_trace 100 |tee log_overall_performance_spacev_diskann.log +python process_diskann.py log_overall_performance_spacev_diskann.log overall_performance_spacev_diskann_result.csv + +mkdir diskann_result +mv /home/sosp/result_overall_spacev_diskann* diskann_result + +resultnamePrefix=/diskann_result/ +i=-1 +for FILE in `ls -v1 ./diskann_result/` +do + if [ $i -eq -1 ]; + then + /home/sosp/SPFresh/Release/usefultool -CallRecall true -resultNum 10 -queryPath 
/home/sosp/data/spacev_data/query.i8bin -searchResult $PWD$resultnamePrefix$FILE -truthType DEFAULT -truthPath /home/sosp/data/spacev_data/msspacev-100M -VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --vectortype int8 -d 100 -f DEFAULT |tee log_diskann_$i + else + /home/sosp/SPFresh/Release/usefultool -CallRecall true -resultNum 10 -queryPath /home/sosp/data/spacev_data/query.i8bin -searchResult $PWD$resultnamePrefix$FILE -truthType DEFAULT -truthPath /home/sosp/data/spacev_data/spacev100m_update_truth_after$i -VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --vectortype int8 -d 100 -f DEFAULT |tee log_diskann_$i + fi + let "i=i+1" +done \ No newline at end of file diff --git a/Script_AE/Figure6/overall_spacev_spann.sh b/Script_AE/Figure6/overall_spacev_spann.sh new file mode 100644 index 000000000..3bc45125a --- /dev/null +++ b/Script_AE/Figure6/overall_spacev_spann.sh @@ -0,0 +1,19 @@ +cp /home/sosp/store_spacev100m/data/indexloader_spann.ini /home/sosp/data/store_spacev100m/indexloader.ini +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_spacev100m/|tee log_overall_performance_spacev_spann.log +python process_spann.py log_overall_performance_spacev_spann.log overall_performance_spacev_spann_result.csv + +mkdir spann_result +mv /home/sosp/result_overall_spacev_spann* spann_result + +resultnamePrefix=/spann_result/ +i=-1 +for FILE in `ls -v1 ./spann_result/` +do + if [ $i -eq -1 ]; + then + /home/sosp/SPFresh/Release/usefultool -CallRecall true -resultNum 10 -queryPath /home/sosp/data/spacev_data/query.i8bin -searchResult $PWD$resultnamePrefix$FILE -truthType DEFAULT -truthPath /home/sosp/data/spacev_data/msspacev-100M -VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --vectortype int8 -d 100 -f DEFAULT |tee log_spann_$i + else + /home/sosp/SPFresh/Release/usefultool -CallRecall 
true -resultNum 10 -queryPath /home/sosp/data/spacev_data/query.i8bin -searchResult $PWD$resultnamePrefix$FILE -truthType DEFAULT -truthPath /home/sosp/data/spacev_data/spacev100m_update_truth_after$i -VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --vectortype int8 -d 100 -f DEFAULT |tee log_spann_$i + fi + let "i=i+1" +done \ No newline at end of file diff --git a/Script_AE/Figure6/overall_spacev_spfresh.sh b/Script_AE/Figure6/overall_spacev_spfresh.sh new file mode 100644 index 000000000..020ff049b --- /dev/null +++ b/Script_AE/Figure6/overall_spacev_spfresh.sh @@ -0,0 +1,19 @@ +cp /home/sosp/data/store_spacev100m/indexloader_spfresh.ini /home/sosp/data/store_spacev100m/indexloader.ini +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_spacev100m/|tee log_overall_performance_spacev_spfresh.log +python process_spfresh.py log_overall_performance_spacev_spfresh.log overall_performance_spacev_spfresh_result.csv + +mkdir spfresh_result +mv /home/sosp/result_overall_spacev_spfresh* spfresh_result + +resultnamePrefix=/spfresh_result/ +i=-1 +for FILE in `ls -v1 ./spfresh_result/` +do + if [ $i -eq -1 ]; + then + /home/sosp/SPFresh/Release/usefultool -CallRecall true -resultNum 10 -queryPath /home/sosp/data/spacev_data/query.i8bin -searchResult $PWD$resultnamePrefix$FILE -truthType DEFAULT -truthPath /home/sosp/data/spacev_data/msspacev-100M -VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --vectortype int8 -d 100 -f DEFAULT |tee log_spfresh_$i + else + /home/sosp/SPFresh/Release/usefultool -CallRecall true -resultNum 10 -queryPath /home/sosp/data/spacev_data/query.i8bin -searchResult $PWD$resultnamePrefix$FILE -truthType DEFAULT -truthPath /home/sosp/data/spacev_data/spacev100m_update_truth_after$i -VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --vectortype int8 -d 100 -f DEFAULT |tee 
log_spfresh_$i + fi + let "i=i+1" +done \ No newline at end of file diff --git a/Script_AE/Figure6/plot_overall_result.sh b/Script_AE/Figure6/plot_overall_result.sh new file mode 100644 index 000000000..380e5a41a --- /dev/null +++ b/Script_AE/Figure6/plot_overall_result.sh @@ -0,0 +1,2 @@ +python OverallPerformance_merge_result.py log_spfresh_ log_spann_ log_diskann_ overall_performance_spacev_spfresh_result.csv overall_performance_spacev_spann_result.csv overall_performance_spacev_diskann_result.csv +gnuplot gnuplot overall_performance_spacev_new.p \ No newline at end of file diff --git a/Script_AE/Figure6/process_diskann.py b/Script_AE/Figure6/process_diskann.py new file mode 100644 index 000000000..f2d6a3ac4 --- /dev/null +++ b/Script_AE/Figure6/process_diskann.py @@ -0,0 +1,120 @@ +import sys +import re +import csv + +log_f = open(sys.argv[1]) + +avg_latency = [] +tail_latency = [] +throughput = [] +insert_throughput = [] +RSS = [] + +avg_latency_batch = [] +tail_latency_batch = [] +throughput_batch = [] +RSS_batch = [] + +number_point_per_batch = 4 + +number_thread = 3 + +last_process = 0 + +line_count = 0 + +memory = 0 + +while True: + try: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if result_group[0] == "memory": + memory = float(result_group[6]) + + if result_group[0] != "Queries" and result_group[0] != "Inserted": + pre_result_group = result_group + continue + + if result_group[0] == "Queries": + if int(result_group[2]) < last_process: + # new batch + batch_len = len(avg_latency_batch) + duration = int(batch_len) / number_point_per_batch + + for i in range(0, number_point_per_batch): + avg_latency.append(float(avg_latency_batch[0+i*duration])) + tail_latency.append(float(tail_latency_batch[0+i*duration])) + throughput.append(float(throughput_batch[0+i*duration])) + RSS.append(RSS_batch[0+i*duration] / 1024 / 1024) + + avg_latency_batch = [] + tail_latency_batch = [] + throughput_batch = [] + 
RSS_batch = [] + + avg_latency_batch.append(pre_result_group[2]) + tail_latency_batch.append(pre_result_group[7]) + throughput_batch.append(pre_result_group[1]) + RSS_batch.append(memory) + last_process = int(result_group[2]) + + if result_group[0] == "Inserted": + insert_throughput.append(int(result_group[1]) / float(result_group[4].strip('s')) / number_thread) + + except AttributeError: + print('Error when processing [at line]: {}'.format(line_count)) + break + + +batch_len = len(avg_latency_batch) +duration = int(batch_len) / number_point_per_batch + +for i in range(0, number_point_per_batch): + avg_latency.append(float(avg_latency_batch[0+i*duration])) + tail_latency.append(float(tail_latency_batch[0+i*duration])) + throughput.append(float(throughput_batch[0+i*duration])) + RSS.append(RSS_batch[0+i*duration] / 1024 / 1024) + +avg_latency_search = [] +avg_latency_search.append('Search') +avg_latency_search.append('Avg Latency') +for i in range(0, len(avg_latency)): + avg_latency_search.append(avg_latency[i]) + +tail_latency_search = [] +tail_latency_search.append(' ') +tail_latency_search.append('Tail Latency') +for i in range(0, len(tail_latency)): + tail_latency_search.append(tail_latency[i]) + +throughput_search = [] +throughput_search.append('') +throughput_search.append('Throughput') +for i in range(0, len(throughput)): + throughput_search.append(throughput[i]) + +RSS_search = [] +RSS_search.append('') +RSS_search.append('RSS') +for i in range(0, len(RSS)): + RSS_search.append(RSS[i]) + +insert_avg_latency_search = [] +insert_avg_latency_search.append('Insert') +insert_avg_latency_search.append('Avg Latency') +for i in range(0, len(insert_throughput)): + insert_avg_latency_search.append( 1000000 / insert_throughput[i]) + +for i in range(len(insert_throughput), len(RSS)): + insert_avg_latency_search.append('') + +with open(sys.argv[2], 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(avg_latency_search, tail_latency_search, 
throughput_search, RSS_search, insert_avg_latency_search)) \ No newline at end of file diff --git a/Script_AE/Figure6/process_spann.py b/Script_AE/Figure6/process_spann.py new file mode 100644 index 000000000..294965945 --- /dev/null +++ b/Script_AE/Figure6/process_spann.py @@ -0,0 +1,123 @@ +import sys +import re +import csv + +log_f = open(sys.argv[1]) + +avg_latency = [] +tail_latency = [] +throughput = [] +insert_avg_latency = [] +RSS = [] + +line_count = 0 + +while True: + try: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 1 and result_group[1] == "Samppling": + done = False + while not done: + line = log_f.readline() + result_group = line.split() + if len(result_group) > 4 and result_group[4] == "RSS": + RSS.append(int(result_group[16])) + elif len(result_group) > 7 and result_group[7] == "AvgQPS:": + throughput.append(float(result_group[8].rstrip('.'))) + elif len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line = log_f.readline() + result_group = line.split() + avg_latency.append(float(result_group[1])) + tail_latency.append(float(result_group[6])) + done = True + + except AttributeError: + print('Error when processing [at line]: {}'.format(line_count)) + break + +log_f.close() +log_f = open(sys.argv[1]) + +while True: + try: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 2 and result_group[1] == "Insert" and result_group[2] == "Latency": + line = log_f.readline() + line_count+=1 + result_group = line.split() + read = False + while not read: + if len(result_group) <= 1: + line = log_f.readline() + line_count+=1 + result_group = line.split() + continue + try: + data = float(result_group[1]) + except ValueError: + line = log_f.readline() + line_count+=1 + result_group = line.split() + continue + read = True + 
insert_avg_latency.append(float(result_group[1])) + + + except AttributeError: + print('Error when processing [at line]: {}'.format(line_count)) + break + except ValueError: + print('Error when processing [at line]: {}: {}'.format(line_count , line)) + break + +avg_latency_search = [] +avg_latency_search.append('Search') +avg_latency_search.append('Avg Latency') +for i in range(0, len(avg_latency)): + avg_latency_search.append(avg_latency[i]) + +tail_latency_search = [] +tail_latency_search.append(' ') +tail_latency_search.append('Tail Latency') +for i in range(0, len(tail_latency)): + tail_latency_search.append(tail_latency[i]) + +throughput_search = [] +throughput_search.append('') +throughput_search.append('Throughput') +for i in range(0, len(throughput)): + throughput_search.append(throughput[i]) + +RSS_search = [] +RSS_search.append('') +RSS_search.append('RSS') +for i in range(0, len(RSS)): + RSS_search.append(RSS[i]/1024) + +insert_avg_latency_search = [] +insert_avg_latency_search.append('Insert') +insert_avg_latency_search.append('Avg Latency') +for i in range(0, len(insert_avg_latency)): + insert_avg_latency_search.append(insert_avg_latency[i]) + +for i in range(len(insert_avg_latency), len(RSS)): + insert_avg_latency_search.append('') + +with open(sys.argv[2], 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(avg_latency_search, tail_latency_search, throughput_search, RSS_search, insert_avg_latency_search)) \ No newline at end of file diff --git a/Script_AE/Figure6/process_spfresh.py b/Script_AE/Figure6/process_spfresh.py new file mode 100644 index 000000000..aa4504172 --- /dev/null +++ b/Script_AE/Figure6/process_spfresh.py @@ -0,0 +1,125 @@ +import sys +import re +import csv + +log_f = open(sys.argv[1]) + +avg_latency = [] +tail_latency = [] +throughput = [] +insert_avg_latency = [] +RSS = [] + +line_count = 0 + +while True: + try: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = 
line.split() + + if len(result_group) > 1 and result_group[1] == "Samppling": + done = False + while not done: + line = log_f.readline() + result_group = line.split() + if len(result_group) > 4 and result_group[4] == "RSS": + RSS.append(int(result_group[16])) + elif len(result_group) > 7 and result_group[7] == "AvgQPS:": + throughput.append(float(result_group[8].rstrip('.'))) + elif len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line = log_f.readline() + result_group = line.split() + avg_latency.append(float(result_group[1])) + tail_latency.append(float(result_group[6])) + done = True + + + except AttributeError: + print('Error when processing [at line]: {}'.format(line_count)) + break + +log_f.close() +log_f = open(sys.argv[1]) + +while True: + try: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 2 and result_group[1] == "Insert" and result_group[2] == "Latency": + line = log_f.readline() + line_count+=1 + result_group = line.split() + read = False + while not read: + if len(result_group) <= 1: + line = log_f.readline() + line_count+=1 + result_group = line.split() + continue + try: + data = float(result_group[1]) + except ValueError: + line = log_f.readline() + line_count+=1 + result_group = line.split() + continue + read = True + insert_avg_latency.append(float(result_group[1])) + + + except AttributeError: + print('Error when processing [at line]: {}'.format(line_count)) + break + except ValueError: + print('Error when processing [at line]: {}: {}'.format(line_count , line)) + break + +avg_latency_search = [] +avg_latency_search.append('Search') +avg_latency_search.append('Avg Latency') +for i in range(0, len(avg_latency)): + avg_latency_search.append(avg_latency[i]) + +tail_latency_search = [] +tail_latency_search.append(' ') +tail_latency_search.append('Tail Latency') +for i in range(0, len(tail_latency)): + 
tail_latency_search.append(tail_latency[i]) + +throughput_search = [] +throughput_search.append('') +throughput_search.append('Throughput') +for i in range(0, len(throughput)): + throughput_search.append(throughput[i]) + +RSS_search = [] +RSS_search.append('') +RSS_search.append('RSS') +for i in range(0, len(RSS)): + RSS_search.append(RSS[i]/1024) + +insert_avg_latency_search = [] +insert_avg_latency_search.append('Insert') +insert_avg_latency_search.append('Avg Latency') +for i in range(0, len(insert_avg_latency)): + insert_avg_latency_search.append(insert_avg_latency[i]) + +for i in range(len(insert_avg_latency), len(RSS)): + insert_avg_latency_search.append('') + + +with open(sys.argv[2], 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(avg_latency_search, tail_latency_search, throughput_search, RSS_search, insert_avg_latency_search)) \ No newline at end of file diff --git a/Script_AE/Figure7/iops_limitation.sh b/Script_AE/Figure7/iops_limitation.sh new file mode 100644 index 000000000..c8c6383e8 --- /dev/null +++ b/Script_AE/Figure7/iops_limitation.sh @@ -0,0 +1,12 @@ +cp /home/sosp/data/store_spacev100m/indexloader_iopslimit.ini /home/sosp/data/store_spacev100m/indexloader.ini +SearchThreadNumLine="107c SearchThreadNum=" +loaderPath="/home/sosp/data/store_spacev100m/indexloader.ini " +storePath="/home/sosp/data/store_spacev100m" +logPath="log_searchthread" + +for i in 1 2 4 8 10 12 +do + newSearchThreadNumLine=$SearchThreadNumLine$i + sed -i "$newSearchThreadNumLine" $loaderPath + PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh $storePath |tee $logPath$i +done \ No newline at end of file diff --git a/Script_AE/Figure7/limits.p b/Script_AE/Figure7/limits.p new file mode 100644 index 000000000..c094a1f9d --- /dev/null +++ b/Script_AE/Figure7/limits.p @@ -0,0 +1,51 @@ +# For a single column, set the width at 3.3 
inches +# For across two columns, set the width at 7 inches + +set terminal pdfcairo size 3.3, 1.75 font 'Linux Biolinum O,12' +# set terminal pdfcairo size 7, 2.07 font "UbuntuMono-Regular, 11" + +# set default line style +set style line 1 lc rgb '#056bfa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 2 lc rgb '#05a8fa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 3 lc rgb '#fb8500' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 4 lc rgb '#ffb703' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 5 lc rgb '#b30018' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 6 lc rgb '#fa3605' lt 1 lw 1.7 pt 7 ps 1.5 + +# set grid style +set style line 20 lc rgb '#dddddd' lt 1 lw 1 +set style fill solid + +set datafile separator "," +set encoding utf8 +set autoscale +set grid ls 20 noxtics ytics +# set key box ls 20 opaque fc rgb "#3fffffff" +set tics scale 0.5 +set xtics nomirror out autofreq offset 0,0.5,0 +set ytics nomirror out autofreq offset 0.5,0,0 +set border lw 2 +set yrange [0:*] + +# Start the first plot +set output "IOPSLimit.pdf" + +set multiplot + +set xlabel "ThreadNum" offset 0,1,0 +set ylabel "Query per Second" offset 1,0,0 + +set y2tics out autofreq format '%gk' +set y2label "IOPS" +set y2range [0:700] +set key reverse Left +set boxwidth 0.5 + +set size 1.045, 0.92 +set origin -0.01, 0.08 + +# set title "IOPS Limit" offset 0, -0.7 +plot "IOPS_limit.csv" using 1:2:xtic(4) every ::1 with boxes title 'QPS' at 0.2, 0.07, \ + "IOPS_limit.csv" using 1:($3/1000.):xtic(4) every ::1 with linespoints title 'IOPS' at 0.8, 0.07 axes x1y2 ls 1 pointsize 0.5 + +unset multiplot \ No newline at end of file diff --git a/Script_AE/Figure7/plot_iops_result.sh b/Script_AE/Figure7/plot_iops_result.sh new file mode 100644 index 000000000..f57fbd474 --- /dev/null +++ b/Script_AE/Figure7/plot_iops_result.sh @@ -0,0 +1,2 @@ +python process_iopslimit.py log_searchthread +gnuplot limits.p \ No newline at end of file diff --git a/Script_AE/Figure7/process_iopslimit.py b/Script_AE/Figure7/process_iopslimit.py new file mode 
100644 index 000000000..6208410d8 --- /dev/null +++ b/Script_AE/Figure7/process_iopslimit.py @@ -0,0 +1,50 @@ +import sys +import re +import csv + + +process_list = [1, 2, 4, 8, 10, 12] + + +throughput = [] +KIOPS = [] + +throughput.append('') +KIOPS.append('') + +line_count = 0 + +for i in process_list: + log_f = open(sys.argv[1] + str(i)) + while True: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 7 and result_group[7] == "AvgQPS:": + throughput.append(float(result_group[8].rstrip('.'))) + while result_group[0] != 'IOPS:': + line = log_f.readline() + line_count+=1 + result_group = line.split() + KIOPS.append(float((result_group[1].rstrip('k')))*1000) + break + +process_list_search = [] +process_list_search.append('') +process_list_search += process_list + +batch = [] +batch.append('') +for i in range(0, 6): + batch.append(i) + +print(KIOPS) + +with open("IOPS_limit.csv", 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(batch, throughput, KIOPS, process_list_search)) \ No newline at end of file diff --git a/Script_AE/Figure8/plot_stress_result.sh b/Script_AE/Figure8/plot_stress_result.sh new file mode 100644 index 000000000..9ea0cb8ee --- /dev/null +++ b/Script_AE/Figure8/plot_stress_result.sh @@ -0,0 +1,2 @@ +python process_stress_test.py log_stress_spfresh.log +gnuplot stress_test_new.p \ No newline at end of file diff --git a/Script_AE/Figure8/process_stress_test.py b/Script_AE/Figure8/process_stress_test.py new file mode 100644 index 000000000..a76f19bdd --- /dev/null +++ b/Script_AE/Figure8/process_stress_test.py @@ -0,0 +1,246 @@ +import sys +import re +import csv + +log_f = open(sys.argv[1]) + +avg_latency = [] +tail_latency = [] +throughput = [] +insert_avg_latency = [] +RSS = [] +KIOPS = [] + +avg_latency_batch = [] +tail_latency_batch = [] +throughput_batch = [] +RSS_batch = [] +KIOPS_batch = [] + + +number_point_per_batch = 3 + +last_process = 0 + 
+line_count = 0 + +while True: + try: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 1 and result_group[1] == "Samppling": + this_process = int(result_group[3]) % 10000000 + if this_process < last_process: + + batch_len = len(avg_latency_batch) + duration = int(batch_len) / number_point_per_batch + # print(batch_len) + for i in range(0, number_point_per_batch): + avg_latency.append(float(avg_latency_batch[0+i*duration])) + tail_latency.append(float(tail_latency_batch[0+i*duration])) + throughput.append(float(throughput_batch[0+i*duration])) + RSS.append(RSS_batch[0+i*duration]) + KIOPS.append(KIOPS_batch[0+i*duration]) + + avg_latency_batch = [] + tail_latency_batch = [] + throughput_batch = [] + RSS_batch = [] + + if this_process != 0: + done = False + while not done: + line = log_f.readline() + line_count+=1 + result_group = line.split() + if len(result_group) > 4 and result_group[4] == "RSS": + RSS_batch.append(int(result_group[16])) + elif len(result_group) > 7 and result_group[7] == "AvgQPS:": + throughput_batch.append(float(result_group[8].rstrip('.'))) + elif len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line_count+=1 + line = log_f.readline() + line_count+=1 + result_group = line.split() + read = False + while not read: + try: + data = float(result_group[1]) + except ValueError: + line = log_f.readline() + line_count+=1 + result_group = line.split() + continue + read = True + avg_latency_batch.append(float(result_group[1])) + tail_latency_batch.append(float(result_group[5])) + done = True + line = log_f.readline() + line_count+=1 + result_group = line.split() + while result_group[0] != 'IOPS:': + line = log_f.readline() + line_count+=1 + result_group = line.split() + KIOPS_batch.append(float((result_group[1].rstrip('k')))) + + + last_process = this_process + + except AttributeError: + print('Error when 
processing [at line]: {}'.format(line_count)) + break + except ValueError: + print('Error when processing [at line]: {}: {}'.format(line_count , line)) + break + +log_f.close() +log_f = open(sys.argv[1]) + +while True: + try: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 2 and result_group[1] == "Insert" and result_group[2] == "Latency": + line = log_f.readline() + line_count+=1 + result_group = line.split() + read = False + while not read: + if len(result_group) <= 1: + line = log_f.readline() + line_count+=1 + result_group = line.split() + continue + try: + data = float(result_group[1]) + except ValueError: + line = log_f.readline() + line_count+=1 + result_group = line.split() + continue + read = True + insert_avg_latency.append(float(result_group[1])) + + + except AttributeError: + print('Error when processing [at line]: {}'.format(line_count)) + break + except ValueError: + print('Error when processing [at line]: {}: {}'.format(line_count , line)) + break + +if this_process != 0: + batch_len = len(avg_latency_batch) + duration = int(batch_len) / number_point_per_batch + print(batch_len) + + for i in range(0, number_point_per_batch): + avg_latency.append(float(avg_latency_batch[0+i*duration])) + tail_latency.append(float(tail_latency_batch[0+i*duration])) + throughput.append(float(throughput_batch[0+i*duration])) + RSS.append(RSS_batch[0+i*duration]) + KIOPS.append(KIOPS_batch[0+i*duration]) + +step = 0.25 +batch = [] +batch.append('SPFresh') +batch.append('Batch') +for i in range(0,20): + batch.append(i+step) + batch.append(i+step*2) + batch.append(i+step*3) + +avg_latency_search = [] +avg_latency_search.append('Search') +avg_latency_search.append('Avg Latency') +for i in range(0, len(avg_latency)): + avg_latency_search.append(avg_latency[i]) + +tail_latency_search = [] +tail_latency_search.append(' ') +tail_latency_search.append('Tail Latency') +for i in range(0, 
len(tail_latency)): + tail_latency_search.append(tail_latency[i]) + +tail_latency_search = [] +tail_latency_search.append('') +tail_latency_search.append('Tail Latency') +for i in range(0, len(tail_latency)): + tail_latency_search.append(tail_latency[i]) + +throughput_search = [] +throughput_search.append('') +throughput_search.append('Throughput') +for i in range(0, len(throughput)): + throughput_search.append(throughput[i]) + +RSS_search = [] +RSS_search.append('') +RSS_search.append('RSS') +for i in range(0, len(RSS)): + RSS_search.append(RSS[i]) + +insert_avg_latency_search = [] +insert_avg_latency_search.append('Insert') +insert_avg_latency_search.append('Avg Latency') +for i in range(0, len(insert_avg_latency)): + insert_avg_latency_search.append(insert_avg_latency[i]) + +for i in range(len(insert_avg_latency), len(RSS)): + insert_avg_latency_search.append('') + +KIOPS_search = [] +KIOPS_search.append('') +KIOPS_search.append('KIOPS') +for i in range(0, len(KIOPS)): + KIOPS_search.append(KIOPS[i]) + +temp_search = [] +temp_search.append('') +temp_search.append('') +for i in range(0, len(KIOPS)): + temp_search.append('') + +with open('stress_test.csv', 'w') as f: + writer = csv.writer(f, delimiter=',') + for line in zip(batch, temp_search, avg_latency_search, tail_latency_search, throughput_search, RSS_search, insert_avg_latency_search, temp_search, temp_search, temp_search, KIOPS_search): + writer.writerow(line) + +batch_new = [] +batch_new.append('SPFresh') +batch_new.append('Batch') +for i in range(0,20): + batch_new.append(i) + +insert_avg_latency_new = [] +insert_avg_latency_new.append('Insert') +insert_avg_latency_new.append('Avg Latency') +for i in range(0, len(insert_avg_latency)): + insert_avg_latency_new.append(insert_avg_latency[i]) + +temp_search_new = [] +temp_search_new.append('') +temp_search_new.append('') +for i in range(0, 20): + temp_search_new.append('') + +# print(KIOPS_search) + +with open('stress_test2.csv', 'w') as f: + writer = 
csv.writer(f, delimiter=',') + writer.writerows(zip(batch_new, insert_avg_latency_new, temp_search_new, temp_search_new)) + diff --git a/Script_AE/Figure8/stress_spfresh.sh b/Script_AE/Figure8/stress_spfresh.sh new file mode 100644 index 000000000..d7b653c68 --- /dev/null +++ b/Script_AE/Figure8/stress_spfresh.sh @@ -0,0 +1 @@ +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_sift1b/|tee log_stress_spfresh.log \ No newline at end of file diff --git a/Script_AE/Figure8/stress_test_new.p b/Script_AE/Figure8/stress_test_new.p new file mode 100644 index 000000000..40b8da114 --- /dev/null +++ b/Script_AE/Figure8/stress_test_new.p @@ -0,0 +1,80 @@ +# For a single column, set the width at 3.3 inches +# For across two columns, set the width at 7 inches + +set terminal pdfcairo size 3.3, 2.8 font 'Linux Biolinum O,12' +# set terminal pdfcairo size 7, 1.5 font "UbuntuMono-Regular, 11" + +# set default line style +set style line 1 lc rgb '#056bfa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 2 lc rgb '#05a8fa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 3 lc rgb '#fb8500' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 4 lc rgb '#ffb703' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 5 lc rgb '#b30018' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 6 lc rgb '#fa3605' lt 1 lw 1.7 pt 7 ps 1.5 + +# set grid style +set style line 20 lc rgb '#dddddd' lt 1 lw 1 + +set datafile separator "," +set encoding utf8 +set autoscale +set grid ls 20 +# set key box ls 20 opaque fc rgb "#3fffffff" width 0.5 +set tics scale 0.5 +set xtics nomirror out autofreq offset 0,0.5,0 format "" +set ytics nomirror out autofreq offset 0.5,0,0 +set border lw 2 +# set yrange [0:*] + +# Start the first plot +set output "StressTest.pdf" +set multiplot # layout 6, 1 +set lmargin 10 + +unset xlabel +set yrange [3:6] +set ytics 1 +# set ylabel "latency (ms)" offset 1.5,0,0 +unset ylabel +set label 1 
"Latency (ms)" at screen 0.05, graph 0.5 center rotate by 90 +set key at graph 0.99, 0.37 reverse Left + +unset key +set title "Search Tail Latency" offset 0, -0.7 +set origin 0, 0.68 +set size 1, 0.38 +plot "stress_test.csv" using 1:4 every ::3 with lines ls 1 + +# Start the second plot + +set yrange [0:4000] +set ytics ("0" 0, "1k" 1000, "2k" 2000, "3k" 3000, "4k" 4000) +# set ylabel "Query per Second" offset 1.5,0,0 +set label 1 "Query per Second" at screen 0.05, graph 0.5 center rotate by 90 +set title "Throughput" offset 0, -0.7 +set origin 0, 0.3 +set size 1, 0.48 +set key at graph 0.99, 0.37 reverse Left +plot "stress_test.csv" using 1:5 every ::3 with lines title 'Search Throughput' ls 1, \ + "stress_test2.csv" using 1:(4*1000000./$2) every ::3 with lines title 'Insert Throughput' ls 2 + +# Start the third plot + +set yrange [300:600] +set xlabel "Batches/Days" offset 0,1,0 +# set ylabel 'IOPS' offset 1.5,0,0 +set label 1 "IOPS" at screen 0.05, graph 0.63 center rotate by 90 +set xtics format '%g' +set ytics out format '%gk' +set ytics ("400k" 400, "600k" 600) +set title "IOPS" offset 0, -0.7 +# set key box ls 20 opaque fc rgb "#3fffffff" width 0.5 +set key at graph 0.99, 0.35 reverse Left +set origin 0, 0 +set size 1, 0.4 +plot "stress_test.csv" using 1:11 every ::3 with lines notitle ls 1, \ + (400) with lines title 'Guaranteed Limit' ls 5 dashtype 2 + +unset key + +unset multiplot \ No newline at end of file diff --git a/Script_AE/Figure9/data_shifting.sh b/Script_AE/Figure9/data_shifting.sh new file mode 100644 index 000000000..127f46858 --- /dev/null +++ b/Script_AE/Figure9/data_shifting.sh @@ -0,0 +1,16 @@ +# noreassign/static/top64/inplace + +# noreassign +cp /home/sosp/data/store_sift_cluster/indexloader_noreassign.ini /home/sosp/data/store_sift_cluster/indexloader.ini +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh 
/home/sosp/data/store_sift_cluster |tee log_noreassign.log + +# static +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_sift_cluster_2m |tee log_static.log + +# inplace +cp /home/sosp/data/store_sift_cluster/indexloader_inplace.ini /home/sosp/data/store_sift_cluster/indexloader.ini +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_sift_cluster |tee log_inplace.log + +# top64 +cp /home/sosp/data/store_sift_cluster/indexloader_top64.ini /home/sosp/data/store_sift_cluster/indexloader.ini +PCI_ALLOWED="c636:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=/home/sosp/SPFresh/bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E /home/sosp/SPFresh/Release/spfresh /home/sosp/data/store_sift_cluster |tee log_split+reassign.log \ No newline at end of file diff --git a/Script_AE/Figure9/parameter_study_shifting.p b/Script_AE/Figure9/parameter_study_shifting.p new file mode 100644 index 000000000..1c866d479 --- /dev/null +++ b/Script_AE/Figure9/parameter_study_shifting.p @@ -0,0 +1,53 @@ +# For a single column, set the width at 3.3 inches +# For across two columns, set the width at 7 inches + +set terminal pdfcairo size 3.3, 1.75 font 'Linux Biolinum O,12' +# set terminal pdfcairo size 7, 1.75 font "UbuntuMono-Regular, 11" + +# set default line style +set style line 1 lc rgb '#056bfa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 2 lc rgb '#05a8fa' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 3 lc rgb '#fb8500' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 4 lc rgb '#ffb703' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 5 lc rgb '#b30018' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 6 lc rgb '#fa3605' lt 1 lw 1.7 pt 7 ps 1.5 +set style line 7 lc rgb '#80d653' lt 1 lw 1.7 pt 7 ps 1.5 +# set style line 1 lc rgb '#00d5ff' lt 1 
lw 1.5 pt 7 ps 1.5 +# set style line 2 lc rgb '#000080' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 3 lc rgb '#ff7f0e' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 4 lc rgb '#008176' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 5 lc rgb '#b3b3b3' lt 1 lw 1.5 pt 7 ps 1.5 +# set style line 6 lc rgb '#000000' lt 1 lw 1.5 pt 7 ps 1.5 + +# set grid style +set style line 20 lc rgb '#dddddd' lt 1 lw 1 + +set datafile separator "," +set encoding utf8 +set autoscale +set grid ls 20 +set key box ls 20 opaque width -2 +set tics scale 0.5 +set xtics nomirror out autofreq offset 0,0.5,0 +set ytics nomirror out autofreq offset 0.5,0,0 +set border lw 2 + +# Start the first plot +set output "ParameterStudy1.pdf" + +set xlabel "Latency (ms)" offset 0,1,0 +set ylabel 'Recall 10\@10' offset 1.5,0,0 +set key bottom right reverse Left + +set size 1, 1 +set origin 0, 0 +# set title "Search Latency" offset 0, -0.7 +unset title +set yrange [0.92:1] +set ytics 0.02 +plot "parameter_study_shifting.csv" using 3:4 every ::3 with lines title 'Static' ls 7, \ + "parameter_study_shifting.csv" using 1:2 every ::3 with lines title 'In-place Update' ls 4, \ + "parameter_study_shifting.csv" using 5:6 every ::3 with lines title 'In-place Update + Split Only' ls 3, \ + "parameter_study_shifting.csv" using 7:8 every ::3 with lines title 'In-place Update + Split/Reassign' ls 1 + +unset key \ No newline at end of file diff --git a/Script_AE/Figure9/plot_shifting_result.sh b/Script_AE/Figure9/plot_shifting_result.sh new file mode 100644 index 000000000..351af878a --- /dev/null +++ b/Script_AE/Figure9/plot_shifting_result.sh @@ -0,0 +1,2 @@ +python process_shifting.py +gnuplot parameter_study_shifting.p \ No newline at end of file diff --git a/Script_AE/Figure9/process_shifting.py b/Script_AE/Figure9/process_shifting.py new file mode 100644 index 000000000..a970fa7c4 --- /dev/null +++ b/Script_AE/Figure9/process_shifting.py @@ -0,0 +1,84 @@ +import sys +import re +import csv + +process_list = ["log_inplace.log", 
"log_static.log", "log_noreassign.log", "log_split+reassign.log"] + +avg_latency = [] +accuracy = [] + +last_data_group = [] + +for fileName in process_list: + + log_f = open(fileName) + + templist_latency = [] + + templist_accuracy = [] + + templist_accuracy.append('') + templist_latency.append('') + + line_count = 0 + + if fileName != "log_static.log": + + while True: + line = log_f.readline() + line_count+=1 + + result_group = line.split() + + if len(result_group) > 2 and result_group[1] == "Total" and result_group[2] == "Vector": + break + + while True: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line_count+=1 + line = log_f.readline() + line_count+=1 + result_group = line.split() + templist_latency.append(float(result_group[1])) + last_data_group = result_group + if len(result_group) > 2 and result_group[1] == "Recall10@10:": + templist_accuracy.append(float(result_group[2])) + else: + while True: + line = log_f.readline() + line_count+=1 + + if line == '': + break + + result_group = line.split() + + if len(result_group) > 3 and result_group[1] == "Updating" and result_group[2] == "numThread:": + break + + if len(result_group) > 1 and result_group[0] == "Total" and result_group[1] == "Latency": + line = log_f.readline() + line_count+=1 + line = log_f.readline() + line_count+=1 + result_group = line.split() + templist_latency.append(float(result_group[1])) + last_data_group = result_group + if len(result_group) > 2 and result_group[1] == "Recall10@10:": + templist_accuracy.append(float(result_group[2])) + + accuracy.append(templist_accuracy) + avg_latency.append(templist_latency) + +with open("parameter_study_shifting.csv", 'w') as f: + writer = csv.writer(f, delimiter=',') + writer.writerows(zip(avg_latency[0], accuracy[0], avg_latency[1], accuracy[1], avg_latency[2], accuracy[2], 
avg_latency[3], accuracy[3])) diff --git a/Script_AE/README.md b/Script_AE/README.md new file mode 100644 index 000000000..016fb6019 --- /dev/null +++ b/Script_AE/README.md @@ -0,0 +1,144 @@ +# **Reproduce All Experiment Results (Result Reproduced)** + +In this folder, we provide scripts for reproducing figures in our paper. A Standard_L16s_v3 instance is needed to reproduce all the results. + +The name of each script corresponds to the number of each figure in our paper. Since some of the scripts in this directory take a long computing time (as we mark below), we strongly recommend you create a tmux session to avoid script interruption due to network instability. + +## **Dataset and Benchmark** +> **For artifact evaluation, we strongly recommend the reviewers to use those pre-downloaded and built dataset and benchmark in the ~/data/ folder on the provided Lsv3 instance, since the provided index and dataset require lots of computation resources and take about two weeks to generate these data in an 80-core machine (we build these data offline)** +To generate data, you can refer to [./Script_AE/iniFile/README.md](./Script_AE_iniFile). 
+ +## **Additional Setup** +Here is the additional setup for evaluation. + +### **How to check current state** +> If we cannot see the /dev/nvme0n1 after the following command, it means this disk has been taken over by SPDK +> Using SPDK will bind the disk to SPDK and the /dev/nvme0n1 will not be shown after "lsblk" +```bash +lsblk +``` + +> If the device has been bound to SPDK, the following command can reset it +```bash +sudo ./SPFresh/ThirdParty/spdk/scripts/setup.sh reset +``` + +### **SPFresh & SPANN+** +> Since we use SPDK to build our storage, we need to bind PCI dev to SPDK +```bash +sudo nvme format /dev/nvme0n1 +sudo ./SPFresh/ThirdParty/spdk/scripts/setup.sh +cp bdev.json /home/sosp/SPFresh/ +``` + +### **DiskANN** +> The DiskANN baseline uses the filesystem provided by the kernel to maintain its on-disk file, so if we want to run DiskANN, we need to release the disk from SPDK +```bash +sudo ./SPFresh/ThirdParty/spdk/scripts/setup.sh reset +``` + +> Prepare DiskANN for evaluation +```bash +sudo mkfs.ext4 /dev/nvme0n1 +sudo mount /dev/nvme0n1 /home/sosp/testbed +sudo chmod 777 /home/sosp/testbed +``` + +> If you want to switch Evaluation from DiskANN to SPFresh, you must umount the filesystem at first and follow the above SPDK command +```bash +sudo umount /home/sosp/testbed +sudo nvme format /dev/nvme0n1 +sudo ./SPFresh/ThirdParty/spdk/scripts/setup.sh +``` + +## **Start to Run** + +> All files needed for running are already set in those folders in the Azure VM for AE. 
+ +### **Motivation (Figure 1)** +> It takes about 22 minutes +```bash +bash motivation.sh +``` +> Plot the result +```bash +bash plot_motivation_result.sh +``` + +### **Overall Performance (Figure 6)** +> It takes about 6 days to reproduce this figure +#### **SPFresh** +> This takes about 43 hours and 50 minutes +```bash +bash overall_spacev_spfresh.sh +``` +#### **SPANN+** +> This takes about 43 hours and 30 minutes +```bash +bash overall_spacev_spann.sh +``` +#### **DiskANN** +> Before running, move DiskANN Index to the disk (if bound by SPDK, first reset and mkfs (follow the instruction above)) +```bash +cp -r /home/sosp/data/store_diskann_100m /home/sosp/testbed +``` +> This takes about 35 hours and 44 minutes +```bash +bash overall_spacev_diskann.sh +``` +#### **Plot the result of the baselines** +```bash +bash plot_overall_result.sh +``` + +### **Disk IOPS Limitation (Figure 7)** +> This takes about 5 minutes +```bash +bash iops_limitation.sh +``` +> Plot the result +```bash +bash plot_iops_result.sh +``` + +### **Stress Test (Figure 8)** +> This takes about 35 hours and 26 minutes +```bash +bash stress_spfresh.sh +``` +> Plot the result +```bash +bash plot_stress_result.sh +``` + +### **Data Shifting Micro-benchmark (Figure 9)** +> This takes about 1 hour and 10 minutes +```bash +bash data_shifting.sh +``` +> Plot the result +```bash +bash plot_shifting_result.sh +``` +### **Parameter Study: Reassign Range (Figure 10)** +> This takes about 1 hour and 30 minutes +```bash +bash parameter_study_range.sh +``` +> Plot the result +```bash +bash plot_range_result.sh +``` +### **Foreground Scalability (Figure 11)** +> This takes about 11 minutes +```bash +bash parameter_study_balance.sh +``` +> Plot the result +```bash +bash plot_balance_result.sh +``` + + + diff --git a/Script_AE/bdev.json b/Script_AE/bdev.json new file mode 100644 index 000000000..35445e4d2 --- /dev/null +++ b/Script_AE/bdev.json @@ -0,0 +1,17 @@ +{ + "subsystems": [ + { + "subsystem": "bdev", 
+ "config": [ + { + "method": "bdev_nvme_attach_controller", + "params": { + "trtype": "pcie", + "name": "Nvme0", + "traddr": "c636:00:00.0" + } + } + ] + } + ] +} diff --git a/Script_AE/data_clustering_sift.py b/Script_AE/data_clustering_sift.py new file mode 100644 index 000000000..8f81933b1 --- /dev/null +++ b/Script_AE/data_clustering_sift.py @@ -0,0 +1,112 @@ +import numpy as np +import argparse +import struct +from sklearn.cluster import KMeans +import sklearn.metrics + + +def process_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--src", help="The input file (.i8bin)") + parser.add_argument("--dst", help="The output file prefix (.i8bin)") + return parser.parse_args() + + +if __name__ == "__main__": + clusters = 2 + args = process_args() + + # Read topk vector one by one + vecs = [] + row_bin = ""; + dim_bin = ""; + with open(args.src, "rb") as f: + + row_bin = f.read(4) + assert row_bin != b'' + row, = struct.unpack('i', row_bin) + + dim_bin = f.read(4) + assert dim_bin != b'' + dim, = struct.unpack('i', dim_bin) + + i = 0 + while 1: + + # The next dim byte is for a vector for spacev + vec = struct.unpack('b' * dim, f.read(dim)) + + # Store it + vecs.append(vec) + i += 1 + if i == row: + break + + # clustering vectors + + vecs = np.array(vecs, dtype=np.int8) + assert vecs.shape[0] == row + print("vecs.shape:", vecs.shape) + estimator = KMeans(n_clusters=clusters) + estimator.fit(vecs) + label_pred = estimator.labels_ + + print("cluster finished") + #print(sklearn.metrics.silhouette_score(vecs, label_pred, metric='euclidean')) + + # generate result + vec_list = [] + vec_num_list = [] + for i in range(0, clusters): + vec_num_list.append(0) + + for i in range(0, clusters): + with open(args.src, "rb") as f: + + row_bin = f.read(4) + assert row_bin != b'' + row, = struct.unpack('i', row_bin) + + dim_bin = f.read(4) + assert dim_bin != b'' + dim, = struct.unpack('i', dim_bin) + + j = 0 + + vecs = "" + + while 1: + + + # The next dim byte is 
for a vector for spacev + + if label_pred[j] == i: + vecs += f.read(dim) + vec_num_list[label_pred[j]] += 1 + else: + f.read(dim) + + j += 1 + + if j % 100000 == 0: + print(j) + + if j == row: + break + vec_list.append(vecs) + + print("cluster_result: ", vec_num_list) + + for i in range(0, clusters): + with open(args.dst + str(i), "wb") as f: + f.write(struct.pack('i', vec_num_list[i])) + f.write(dim_bin) + f.write(vec_list[i]) + + with open(args.dst + str(clusters), "wb") as f: + f.write(struct.pack('i', row)) + f.write(dim_bin) + for i in range(0, clusters): + f.write(vec_list[i]) + + diff --git a/Script_AE/generateOverallPerformanceTraceAndTruth.sh b/Script_AE/generateOverallPerformanceTraceAndTruth.sh new file mode 100644 index 000000000..d41a24f2c --- /dev/null +++ b/Script_AE/generateOverallPerformanceTraceAndTruth.sh @@ -0,0 +1,21 @@ +setname="6c VectorPath=spacev100m_update_set" +truthname="18c TruthPath=spacev100m_update_truth" +deletesetname="spacev100m_update_set" +reservesetname="spacev100m_update_reserve" +currentsetname="spacev100m_update_current" +for i in {0..99} +do + /home/sosp/SPFresh/Release/usefultool -GenTrace true --vectortype int8 --VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --filetype DEFAULT --UpdateSize 1000000 --BaseNum 100000000 --ReserveNum 100000000 --CurrentListFileName spacev100m_update_current --ReserveListFileName spacev100m_update_reserve --TraceFileName spacev100m_update_trace -NewDataSetFileName spacev100m_update_set -d 100 --Batch $i -f DEFAULT + newsetname=$setname$i + newtruthname=$truthname$i + newdeletesetname=$deletesetname$i + newreservesetname=$reservesetname$i + newcurrentsetname=$currentsetname$i + sed -i "$newsetname" genTruth.ini + sed -i "$newtruthname" genTruth.ini + /home/sosp/SPFresh/Release/ssdserving genTruth.ini + /home/sosp/SPFresh/Release/usefultool -ConvertTruth true --vectortype int8 --VectorPath /home/sosp/data/spacev_data/spacev200m_base.i8bin --filetype DEFAULT --UpdateSize 1000000 
--BaseNum 100000000 --ReserveNum 100000000 --CurrentListFileName spacev100m_update_current --ReserveListFileName spacev100m_update_reserve --TraceFileName spacev100m_update_trace -NewDataSetFileName spacev100m_update_set -d 100 --Batch $i -f DEFAULT --truthPath spacev100m_update_truth --truthType DEFAULT --querySize 10000 --resultNum 100 + rm -rf $deletesetname + rm -rf $newreservesetname + rm -rf $newcurrentsetname +done \ No newline at end of file diff --git a/Script_AE/generate_dataset.py b/Script_AE/generate_dataset.py new file mode 100644 index 000000000..bd6e5fde7 --- /dev/null +++ b/Script_AE/generate_dataset.py @@ -0,0 +1,45 @@ +import numpy as np +import argparse +import struct + +def process_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--src", help="The input file (.fvecs)") + parser.add_argument("--dst", help="The output file (.fvecs)") + parser.add_argument("--topk", type=int, help="The number of element to pick up") + return parser.parse_args() + + +if __name__ == "__main__": + args = process_args() + + # Read topk vector one by one + vecs = "" + row_bin = ""; + dim_bin = ""; + with open(args.src, "rb") as f: + + row_bin = f.read(4) + assert row_bin != b'' + row, = struct.unpack('i', row_bin) + + dim_bin = f.read(4) + assert dim_bin != b'' + dim, = struct.unpack('i', dim_bin) + + i = 0 + while 1: + + # The next 4 * dim byte is for a vector + vec = f.read(dim) + + # Store it + vecs += vec + i += 1 + if i == args.topk: + break + + with open(args.dst, "wb") as f: + f.write(struct.pack('i', args.topk)) + f.write(dim_bin) + f.write(vecs) \ No newline at end of file diff --git a/Script_AE/iniFile/README.md b/Script_AE/iniFile/README.md new file mode 100644 index 000000000..95ffadc0e --- /dev/null +++ b/Script_AE/iniFile/README.md @@ -0,0 +1,96 @@ +## **If want to generate these data** + +> In our paper, we use SIFT and SPACEV dataset and synthetic benchmark. 
+These datasets can be downloaded here: +``` +http://big-ann-benchmarks.com/neurips21.html +``` +> Download Dataset +```bash +wget https://dl.fbaipublicfiles.com/billion-scale-ann-benchmarks/bigann/base.1B.u8bin +wget https://comp21storage.blob.core.windows.net/publiccontainer/comp21/spacev1b/spacev1b_base.i8bin +``` +> Download Query +```bash +wget https://dl.fbaipublicfiles.com/billion-scale-ann-benchmarks/bigann/query.public.10K.u8bin +wget https://comp21storage.blob.core.windows.net/publiccontainer/comp21/spacev1b/query.i8bin +``` + +> To generate dataset for Overall Performance +```bash +# generate dataset for index build +python generate_dataset.py --src source_file_path --dst output_file_path --topk 100000000 +# generate dataset for update +python generate_dataset.py --src source_file_path --dst output_file_path --topk 200000000 +``` +> To generate trace for Overall Performance & generate truth for Overall Performance + +This will take hundreds of hours to generate groundTruth (the nearest K vectors of all queries) without GPU support + +```bash +bash generateOverallPerformanceTraceAndTruth.sh +``` + +> To generate trace for Stress Test + +```bash +/home/sosp/SPFresh/Release/usefultool -GenStress true --vectortype UInt8 --VectorPath /home/sosp/data/sift_data/base.1B.u8bin --filetype DEFAULT --UpdateSize 10000000 --BaseNum 1000000000 --TraceFileName bigann1b_update_trace -d 128 --Batch 20 -f DEFAULT +``` + +> To generate SPANN Index for Overall Performance + +build base SPANN Index + +This will take 1 day to generate SPANN Index of SPACEV100M with a 160-thread machine (we build it offline) + +```bash +/home/sosp/SPFresh/Release/ssdserving build_SPANN_spacev100m.ini +mv iniFile/store_spacev100m/*.ini /home/sosp/data/store_spacev100m +``` + +> To generate DiskANN Index for Overall Performance + +This will take 0.5 days to generate DiskANN Index of SPACEV100M with a 160-thread machine (we build it offline) + +```bash +mkdir /home/yuming/data/store_diskann_100m
+/home/sosp/DiskANN_Baseline/build/tests/build_disk_index int8 /home/sosp/data/spacev_data/spacev100m_base.i8bin /home/yuming/data/store_diskann_100m/diskann_spacev_100m_ 64 75 128 128 16 l2 0 +``` + +> To generate SPANN Index for Stress Test + +This will take 5 days to generate SPANN Index of SIFT1B with a 160-thread machine (we build it offline) + +```bash +/home/sosp/SPFresh/Release/ssdserving build_SPANN_sift1b.ini +mv iniFile/store_sift1b/indexloader_stress.ini /home/sosp/data/store_sift_cluster/indexloader.ini +``` + +> To generate data for figure 1,9,10 +```bash +# using sift +python generate_dataset.py --src /home/sosp/data/sift_data/base.1B.u8bin --dst /home/sosp/data/sift_data/bigann2m_base.u8bin --topk 2000000 +# this command requires numpy and sklearn +python data_clustering_sift.py --src /home/sosp/data/sift_data/bigann2m_base.u8bin --dst /home/sosp/data/sift_data/bigann1m_update_clustering +mv /home/sosp/data/sift_data/bigann1m_update_clustering0 /home/sosp/data/sift_data/bigann1m_update_clustering +mv /home/sosp/data/sift_data/bigann1m_update_clustering1 /home/sosp/data/sift_data/bigann1m_update_clustering_trace0 +mv /home/sosp/data/sift_data/bigann1m_update_clustering2 /home/sosp/data/sift_data/bigann2m_update_clustering +#generate truth +/home/sosp/SPFresh/Release/ssdserving genTruth_clustering.ini +mv /home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth0 /home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth + +#build index +/home/sosp/SPFresh/Release/ssdserving build_clustering_1m.ini +mv iniFile/store_sift_cluster/*.ini /home/sosp/data/store_sift_cluster/ +/home/sosp/SPFresh/Release/ssdserving build_clustering_2m.ini +mv iniFile/store_sift_cluster_2m/indexloader_clustering_2m.ini /home/sosp/data/store_sift_cluster/indexloader.ini + +``` + +> To generate data for figure 11 +```bash +python generate_dataset.py --src /home/sosp/data/sift_data/base.1B.u8bin --dst /home/sosp/data/sift_data/bigann1m_base.u8bin --topk
1000000 +# build index +/home/sosp/SPFresh/Release/ssdserving build_sift1m.ini +mv iniFile/store_sift1m/indexloader_sift1m.ini /home/sosp/data/store_sift1m/indexloader.ini +``` \ No newline at end of file diff --git a/Script_AE/iniFile/build_SPANN_sift1b.ini b/Script_AE/iniFile/build_SPANN_sift1b.ini new file mode 100644 index 000000000..96d56da32 --- /dev/null +++ b/Script_AE/iniFile/build_SPANN_sift1b.ini @@ -0,0 +1,56 @@ +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/base.1B.u8bin +VectorType=DEFAULT +VectorSize=1000000000 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/ +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_Int8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_Int8_L2_base_DEFUALT.bin +IndexDirectory=/home/sift/data/store_sift1b/ + +[SelectHead] +isExecute=true +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.1 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=true +NumberOfThreads=160 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=true +InternalResultNum=64 +NumberOfThreads=40 +ReplicaCount=8 +PostingPageLimit=4 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/sift/data/store_sift1b/tmpdir \ No newline at end of file diff --git a/Script_AE/iniFile/build_SPANN_spacev100m.ini b/Script_AE/iniFile/build_SPANN_spacev100m.ini new file mode 100644 index 000000000..fc7b50446 --- /dev/null +++ b/Script_AE/iniFile/build_SPANN_spacev100m.ini @@ -0,0 +1,56 @@ +[Base] +ValueType=Int8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=100 +VectorPath=/home/sosp/data/spacev_data/spacev100m_base.i8bin 
+VectorType=DEFAULT +VectorSize=100000000 +VectorDelimiter= +QueryPath=/home/sosp/data/spacev_data/query.i8bin +QueryType=DEFAULT +QuerySize=29316 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=29316 +WarmupDelimiter= +TruthPath=/home/sosp/data/ +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_spacev100m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=true +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.1 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=true + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=true +InternalResultNum=64 +NumberOfThreads=160 +ReplicaCount=8 +PostingPageLimit=4 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_spacev100m/tmpdir \ No newline at end of file diff --git a/Script_AE/iniFile/build_clustering_1m.ini b/Script_AE/iniFile/build_clustering_1m.ini new file mode 100644 index 000000000..6b0013eb3 --- /dev/null +++ b/Script_AE/iniFile/build_clustering_1m.ini @@ -0,0 +1,96 @@ +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_update_clustering +VectorType=DEFAULT +VectorSize=1049411 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann1m_update_clustering_origin_truth +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster +HeadIndexFolder=head_index + +[SelectHead] 
+isExecute=true +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=true +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=true +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir \ No newline at end of file diff --git a/Script_AE/iniFile/build_clustering_2m.ini b/Script_AE/iniFile/build_clustering_2m.ini new file mode 100644 index 000000000..c6d94294c --- /dev/null +++ b/Script_AE/iniFile/build_clustering_2m.ini @@ -0,0 +1,100 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +VectorType=DEFAULT +VectorSize=2000000 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin 
+QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth0 +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster_2m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=true +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=true +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=true +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir diff --git a/Script_AE/iniFile/build_sift1m.ini b/Script_AE/iniFile/build_sift1m.ini new file mode 100644 index 
000000000..69d505ab2 --- /dev/null +++ b/Script_AE/iniFile/build_sift1m.ini @@ -0,0 +1,94 @@ +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_base.u8bin +VectorType=DEFAULT +VectorSize=1000000 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/1m_trace/bigann-1M +TruthType=DEFAULT +GenerateTruth=false +IndexDirectory=/home/sosp/data/store_sift1m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=true +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=true +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=0 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=160 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=true +NumberOfThreads=40 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 
+OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift1m/tmpdir \ No newline at end of file diff --git a/Script_AE/iniFile/genTruth.ini b/Script_AE/iniFile/genTruth.ini new file mode 100644 index 000000000..460d083f2 --- /dev/null +++ b/Script_AE/iniFile/genTruth.ini @@ -0,0 +1,24 @@ +[Base] +ValueType=Int8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=100 +VectorPath=spacev100m_update_set88 +VectorType=DEFAULT +VectorSize=2000000 +VectorDelimiter= +QueryPath=/home/sosp/data/spacev_data/query.i8bin +QueryType=DEFAULT +QuerySize=29316 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=spacev100m_update_truth88 +TruthType=DEFAULT +GenerateTruth=true + +[SearchSSDIndex] +ResultNum=100 +NumberOfThreads=160 \ No newline at end of file diff --git a/Script_AE/iniFile/genTruth_clustering.ini b/Script_AE/iniFile/genTruth_clustering.ini new file mode 100644 index 000000000..67b3f9bda --- /dev/null +++ b/Script_AE/iniFile/genTruth_clustering.ini @@ -0,0 +1,24 @@ +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +VectorType=DEFAULT +VectorSize=2000000 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth0 +TruthType=DEFAULT +GenerateTruth=true + +[SearchSSDIndex] +ResultNum=100 +NumberOfThreads=160 \ No newline at end of file diff --git a/Script_AE/iniFile/store_sift1b/indexloader_stress.ini b/Script_AE/iniFile/store_sift1b/indexloader_stress.ini new file mode 100644 index 000000000..7dda84e09 --- /dev/null +++ b/Script_AE/iniFile/store_sift1b/indexloader_stress.ini @@ -0,0 +1,110 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 
+VectorPath=/home/sosp/data/sift_data/base.1B.u8bin +VectorType=DEFAULT +VectorSize=1000000000 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_Int8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_Int8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift1b +HeadIndexFolder=HeadIndex + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=0.001000 +TPTNumber=64 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=3 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=16324 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=160 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=4 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +InternalResultNum=64 +NumberOfThreads=16 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +UseSPDK=true +ExcludeHead=false +UseDirectIO=true +SpdkBatchSize=256 +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=8 +SearchTimes=1 +Update=true +Days=20 +SteadyState=true +StressTest=true +InsertThreadNum=4 +AppendThreadNum=2 +ReassignThreadNum=0 +FullVectorPath=/home/sosp/data/sift_data/base.1B.u8bin +DisableReassign=false +ReassignK=64 +LatencyLimit=9.5 +CalTruth=true 
+SearchPostingPageLimit=3 +MaxDistRatio=1000000 +SearchDuringUpdate=true +UpdateFilePrefix=/home/sosp/data/sift_data/1b_trace/bigann1b_update_trace +SearchResult=/home/sosp/result_stress_sift_spfresh +DeleteQPS=3000 +MergeThreshold=10 +Sampling=4 +ShowUpdateProgress=false +BufferLength=1 \ No newline at end of file diff --git a/Script_AE/iniFile/store_sift1m/indexloader_sift1m.ini b/Script_AE/iniFile/store_sift1m/indexloader_sift1m.ini new file mode 100644 index 000000000..0707dba7a --- /dev/null +++ b/Script_AE/iniFile/store_sift1m/indexloader_sift1m.ini @@ -0,0 +1,130 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_base.u8bin +VectorType=DEFAULT +VectorSize=1000000 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/1m_trace/bigann-1M +TruthType=DEFAULT +GenerateTruth=false +IndexDirectory=/home/sosp/data/store_sift1m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=0 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 
+GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=160 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=40 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift1m/tmpdir +UseSPDK=true +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=2 +SearchTimes=1 +Update=true +SteadyState=true +Days=1 +InsertThreadNum=8 +AppendThreadNum=4 +ReassignThreadNum=0 +FullVectorPath=/home/sosp/data/sift_data/bigann2m_base.u8bin +DisableReassign=false +ReassignK=64 +LatencyLimit=100.0 +CalTruth=false +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=true +MergeThreshold=10 +UpdateFilePrefix=/home/sosp/data/sift_data/1m_trace/bigann1m_update_trace +DeleteQPS=8000 +ShowUpdateProgress=true +Sampling=2 +OnlySearchFinalBatch=true +MinInternalResultNum=16 +MaxInternalResultNum=-1 +StepInternalResultNum=16 +SearchResult=/home/sosp/result_spfresh_balance +EndVectorNum=2000000 \ No newline at end of file diff --git a/Script_AE/iniFile/store_sift_cluster/indexloader_inplace.ini b/Script_AE/iniFile/store_sift_cluster/indexloader_inplace.ini new file mode 100644 index 000000000..2f6431b49 --- /dev/null +++ b/Script_AE/iniFile/store_sift_cluster/indexloader_inplace.ini @@ -0,0 +1,133 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_update_clustering +VectorType=DEFAULT +VectorSize=1049411 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT 
+QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann1m_update_clustering_origin_truth +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=300 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir +UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=4 +SearchTimes=1 
+Update=true +Days=1 +Step=950589 +InsertThreadNum=1 +AppendThreadNum=8 +ReassignThreadNum=0 +TruthFilePrefix=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth +FullVectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +DisableReassign=true +ReassignK=0 +LatencyLimit=9.0 +CalTruth=true +SearchPostingPageLimit=300 +MaxDistRatio=1000000 +SearchDuringUpdate=false +MergeThreshold=-1 +UpdateFilePrefix=/home/sosp/data/sift_data/bigann1m_update_clustering_trace +DeleteQPS=16000 +ShowUpdateProgress=false +Sampling=2 +OnlySearchFinalBatch=true +MinInternalResultNum=16 +MaxInternalResultNum=128 +StepInternalResultNum=16 +LoadAllVectors=true \ No newline at end of file diff --git a/Script_AE/iniFile/store_sift_cluster/indexloader_nolimit.ini b/Script_AE/iniFile/store_sift_cluster/indexloader_nolimit.ini new file mode 100644 index 000000000..2c5d7ade4 --- /dev/null +++ b/Script_AE/iniFile/store_sift_cluster/indexloader_nolimit.ini @@ -0,0 +1,133 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_update_clustering +VectorType=DEFAULT +VectorSize=1049411 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann1m_update_clustering_origin_truth +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 
+RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=300 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir +UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=4 +SearchTimes=1 +Update=true +Days=1 +Step=950589 +InsertThreadNum=1 +AppendThreadNum=8 +ReassignThreadNum=0 +TruthFilePrefix=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth +FullVectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +DisableReassign=true +ReassignK=0 +LatencyLimit=100.0 +CalTruth=true +SearchPostingPageLimit=300 +MaxDistRatio=1000000 +SearchDuringUpdate=false +MergeThreshold=-1 +UpdateFilePrefix=/home/sosp/data/sift_data/bigann1m_update_clustering_trace +DeleteQPS=16000 +ShowUpdateProgress=false +Sampling=2 +OnlySearchFinalBatch=true +MinInternalResultNum=16 +MaxInternalResultNum=128 +StepInternalResultNum=16 +LoadAllVectors=true \ No newline 
at end of file diff --git a/Script_AE/iniFile/store_sift_cluster/indexloader_noreassign.ini b/Script_AE/iniFile/store_sift_cluster/indexloader_noreassign.ini new file mode 100644 index 000000000..3ae87b030 --- /dev/null +++ b/Script_AE/iniFile/store_sift_cluster/indexloader_noreassign.ini @@ -0,0 +1,133 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_update_clustering +VectorType=DEFAULT +VectorSize=1049411 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann1m_update_clustering_origin_truth +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 
+DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir +UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=4 +SearchTimes=1 +Update=true +Days=1 +Step=950589 +InsertThreadNum=1 +AppendThreadNum=8 +ReassignThreadNum=0 +TruthFilePrefix=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth +FullVectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +DisableReassign=true +ReassignK=0 +LatencyLimit=9.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=false +MergeThreshold=-1 +UpdateFilePrefix=/home/sosp/data/sift_data/bigann1m_update_clustering_trace +DeleteQPS=16000 +ShowUpdateProgress=false +Sampling=2 +OnlySearchFinalBatch=true +MinInternalResultNum=16 +MaxInternalResultNum=128 +StepInternalResultNum=16 +LoadAllVectors=true \ No newline at end of file diff --git a/Script_AE/iniFile/store_sift_cluster/indexloader_range.ini b/Script_AE/iniFile/store_sift_cluster/indexloader_range.ini new file mode 100644 index 000000000..1b71ded17 --- /dev/null +++ b/Script_AE/iniFile/store_sift_cluster/indexloader_range.ini @@ -0,0 +1,133 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_update_clustering +VectorType=DEFAULT +VectorSize=1049411 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT 
+WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann1m_update_clustering_origin_truth +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir +UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=4 +SearchTimes=1 +Update=true +Days=1 +Step=950589 +InsertThreadNum=1 
+AppendThreadNum=8 +ReassignThreadNum=0 +TruthFilePrefix=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth +FullVectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +DisableReassign=false +ReassignK=0 +LatencyLimit=9.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=false +MergeThreshold=-1 +UpdateFilePrefix=/home/sosp/data/sift_data/bigann1m_update_clustering_trace +DeleteQPS=16000 +ShowUpdateProgress=false +Sampling=2 +OnlySearchFinalBatch=true +MinInternalResultNum=16 +MaxInternalResultNum=128 +StepInternalResultNum=16 +LoadAllVectors=true diff --git a/Script_AE/iniFile/store_sift_cluster/indexloader_top64.ini b/Script_AE/iniFile/store_sift_cluster/indexloader_top64.ini new file mode 100644 index 000000000..1a5f01d8a --- /dev/null +++ b/Script_AE/iniFile/store_sift_cluster/indexloader_top64.ini @@ -0,0 +1,133 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann1m_update_clustering +VectorType=DEFAULT +VectorSize=1049411 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann1m_update_clustering_origin_truth +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin 
+GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir +UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=4 +SearchTimes=1 +Update=true +Days=1 +Step=950589 +InsertThreadNum=1 +AppendThreadNum=8 +ReassignThreadNum=0 +TruthFilePrefix=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth +FullVectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +DisableReassign=false +ReassignK=64 +LatencyLimit=9.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=false +MergeThreshold=-1 +UpdateFilePrefix=/home/sosp/data/sift_data/bigann1m_update_clustering_trace +DeleteQPS=16000 +ShowUpdateProgress=false +Sampling=2 +OnlySearchFinalBatch=true +MinInternalResultNum=16 +MaxInternalResultNum=128 +StepInternalResultNum=16 +LoadAllVectors=true \ No newline at end of file diff --git a/Script_AE/iniFile/store_sift_cluster_2m/indexloader_clustering_2m.ini 
b/Script_AE/iniFile/store_sift_cluster_2m/indexloader_clustering_2m.ini new file mode 100644 index 000000000..8eecbc8b5 --- /dev/null +++ b/Script_AE/iniFile/store_sift_cluster_2m/indexloader_clustering_2m.ini @@ -0,0 +1,134 @@ +[Index] +IndexAlgoType=SPANN +ValueType=UInt8 + +[Base] +ValueType=UInt8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=128 +VectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +VectorType=DEFAULT +VectorSize=2000000 +VectorDelimiter= +QueryPath=/home/sosp/data/sift_data/query.public.10K.u8bin +QueryType=DEFAULT +QuerySize=10000 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=10000 +WarmupDelimiter= +TruthPath=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth0 +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_sift_cluster_2m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=1 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 
+ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=10 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=3 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_sift_cluster/tmpdir +UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=4 +SearchTimes=1 +Update=true +SteadyState=true +Days=-1 +Step=950589 +InsertThreadNum=1 +AppendThreadNum=8 +ReassignThreadNum=0 +TruthFilePrefix=/home/sosp/data/sift_data/bigann2m_update_clustering_origin_truth +FullVectorPath=/home/sosp/data/sift_data/bigann2m_update_clustering +DisableReassign=false +ReassignK=0 +LatencyLimit=9.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=false +MergeThreshold=-1 +UpdateFilePrefix=/home/sosp/data/sift_data/bigann1m_update_clustering_trace +DeleteQPS=16000 +ShowUpdateProgress=false +Sampling=2 +OnlySearchFinalBatch=false +MinInternalResultNum=16 +MaxInternalResultNum=128 +StepInternalResultNum=16 +LoadAllVectors=true \ No newline at end of file diff --git a/Script_AE/iniFile/store_spacev100m/indexloader_iopslimit.ini b/Script_AE/iniFile/store_spacev100m/indexloader_iopslimit.ini new file mode 100644 index 000000000..d58e082e0 --- /dev/null +++ b/Script_AE/iniFile/store_spacev100m/indexloader_iopslimit.ini @@ -0,0 +1,129 @@ +[Index] +IndexAlgoType=SPANN +ValueType=Int8 + +[Base] +ValueType=Int8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=100 +VectorPath=/home/sosp/data/spacev_data/spacev100m_base.i8bin +VectorType=DEFAULT +VectorSize=100000000 +VectorDelimiter= +QueryPath=/home/sosp/data/spacev_data/query.i8bin +QueryType=DEFAULT +QuerySize=29316 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=29316 +WarmupDelimiter= +TruthPath=/home/sosp/data/ 
+TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_spacev100m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.1 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=0 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=4 +DataBlockSize=20000000 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +InternalResultNum=64 +NumberOfThreads=40 +ReplicaCount=8 +PostingPageLimit=4 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_spacev100m/tmpdir +UseSPDK=true +SpdkBatchSize=256 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=2 +SearchTimes=1 +Update=true +Days=-1 +SteadyState=true +InsertThreadNum=1 +AppendThreadNum=1 +ReassignThreadNum=0 +TruthFilePrefix=/home/spsp/data/ 
+FullVectorPath=/home/sosp/data/spacev_data/spacev200m_base.i8bin +DisableReassign=false +ReassignK=64 +LatencyLimit=10.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=true +UpdateFilePrefix=/home/sosp/data/spacev_data/spacev100m_update_trace +SearchResult=/home/sosp/result_overall_spacev_spfresh +DeleteQPS=800 +MergeThreshold=10 +Sampling=4 +EndVectorNum=200000000 \ No newline at end of file diff --git a/Script_AE/iniFile/store_spacev100m/indexloader_spann.ini b/Script_AE/iniFile/store_spacev100m/indexloader_spann.ini new file mode 100644 index 000000000..a7af53395 --- /dev/null +++ b/Script_AE/iniFile/store_spacev100m/indexloader_spann.ini @@ -0,0 +1,132 @@ +[Index] +IndexAlgoType=SPANN +ValueType=Int8 + +[Base] +ValueType=Int8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=100 +VectorPath=/home/sosp/data/spacev_data/spacev100m_base.i8bin +VectorType=DEFAULT +VectorSize=100000000 +VectorDelimiter= +QueryPath=/home/sosp/data/spacev_data/query.i8bin +QueryType=DEFAULT +QuerySize=29316 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=29316 +WarmupDelimiter= +TruthPath=/home/sosp/data/ +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_spacev100m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.1 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 
+NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=0 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=4 +DataBlockSize=20000000 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +InternalResultNum=64 +NumberOfThreads=40 +ReplicaCount=8 +PostingPageLimit=900 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_spacev100m/tmpdir +UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=2 +SearchTimes=1 +Update=true +Days=100 +SteadyState=true +InsertThreadNum=1 +AppendThreadNum=1 +ReassignThreadNum=0 +TruthFilePrefix=/home/spsp/data/ +FullVectorPath=/home/sosp/data/spacev_data/spacev200m_base.i8bin +DisableReassign=true +ReassignK=64 +LatencyLimit=10.0 +CalTruth=true +SearchPostingPageLimit=900 +MaxDistRatio=1000000 +SearchDuringUpdate=true +UpdateFilePrefix=/home/sosp/data/spacev_data/spacev100m_update_trace +SearchResult=/home/sosp/result_overall_spacev_spann +DeleteQPS=800 +MergeThreshold=-1 +Sampling=4 +EndVectorNum=200000000 +InPlace=true +ShowUpdateProgress=false +BufferLength=6 \ No newline at end of file diff --git a/Script_AE/iniFile/store_spacev100m/indexloader_spfresh.ini b/Script_AE/iniFile/store_spacev100m/indexloader_spfresh.ini new file mode 100644 index 000000000..e7320e6f6 --- /dev/null +++ b/Script_AE/iniFile/store_spacev100m/indexloader_spfresh.ini @@ -0,0 +1,131 @@ +[Index] +IndexAlgoType=SPANN +ValueType=Int8 + +[Base] +ValueType=Int8 +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=100 
+VectorPath=/home/sosp/data/spacev_data/spacev100m_base.i8bin +VectorType=DEFAULT +VectorSize=100000000 +VectorDelimiter= +QueryPath=/home/sosp/data/spacev_data/query.i8bin +QueryType=DEFAULT +QuerySize=29316 +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=29316 +WarmupDelimiter= +TruthPath=/home/sosp/data/ +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_UInt8_L2_base_DEFUALT.bin +HeadVectors=head_vectors_UInt8_L2_base_DEFUALT.bin +IndexDirectory=/home/sosp/data/store_spacev100m +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=80 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.1 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=0 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=80 +DistCalcMethod=L2 +DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=4 +DataBlockSize=20000000 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +InternalResultNum=64 +NumberOfThreads=40 +ReplicaCount=8 +PostingPageLimit=4 +OutputEmptyReplicaID=1 +TmpDir=/home/sosp/data/store_spacev100m/tmpdir 
+UseSPDK=true +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=true +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=2 +SearchTimes=1 +Update=true +Days=100 +SteadyState=true +InsertThreadNum=1 +AppendThreadNum=1 +ReassignThreadNum=0 +TruthFilePrefix=/home/spsp/data/ +FullVectorPath=/home/sosp/data/spacev_data/spacev200m_base.i8bin +DisableReassign=false +ReassignK=64 +LatencyLimit=10.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=true +UpdateFilePrefix=/home/sosp/data/spacev_data/spacev100m_update_trace +SearchResult=/home/sosp/result_overall_spacev_spfresh +DeleteQPS=800 +MergeThreshold=10 +Sampling=4 +EndVectorNum=200000000 +ShowUpdateProgress=false +BufferLength=6 \ No newline at end of file diff --git a/Test/CMakeLists.txt b/Test/CMakeLists.txt index ed1002763..52f4168a9 100644 --- a/Test/CMakeLists.txt +++ b/Test/CMakeLists.txt @@ -19,7 +19,7 @@ if (NOT LIBRARYONLY) message (FATAL_ERROR "Could not find Boost 1.67!") endif() - include_directories(${PROJECT_SOURCE_DIR}/AnnService ${PROJECT_SOURCE_DIR}/Test) + include_directories(${PROJECT_SOURCE_DIR}/AnnService ${PROJECT_SOURCE_DIR}/Test ${PROJECT_SOURCE_DIR}/ThirdParty/spdk/build/include) file(GLOB TEST_HDR_FILES ${PROJECT_SOURCE_DIR}/Test/inc/Test.h) file(GLOB TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/Test/src/*.cpp) diff --git a/Test/Test.vcxproj b/Test/Test.vcxproj index 8c93771e3..0db121f75 100644 --- a/Test/Test.vcxproj +++ b/Test/Test.vcxproj @@ -29,28 +29,29 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte + Spectre @@ -97,6 +98,7 @@ true /Zc:twoPhase- %(AdditionalOptions) stdcpp17 + Guard true @@ -104,6 +106,7 @@ Console + true @@ -152,46 +155,52 @@ + + + + + + + + - - - - - - - - - - + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. 
For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - - - + + + + + + + + + \ No newline at end of file diff --git a/Test/Test.vcxproj.filters b/Test/Test.vcxproj.filters index 2f5bd1ccb..c30f98344 100644 --- a/Test/Test.vcxproj.filters +++ b/Test/Test.vcxproj.filters @@ -42,6 +42,15 @@ Source Files + + Source Files + + + Source Files + + + Source Files + Source Files @@ -51,11 +60,29 @@ Source Files + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + Header Files + + Header Files + diff --git a/Test/WinRTTest/WinRTTest.cpp b/Test/WinRTTest/WinRTTest.cpp new file mode 100644 index 000000000..bff0f0264 --- /dev/null +++ b/Test/WinRTTest/WinRTTest.cpp @@ -0,0 +1,99 @@ +#include +#include +#include +#include +#include +#include +#include + +extern "C" __declspec(dllexport) winrt::SPTAG::LogLevel SPTAG_GetLoggerLevel() +{ + return winrt::SPTAG::LogLevel::Empty; +} + +using namespace winrt; +using namespace Windows::Security::Cryptography; +using namespace Windows::Data::Json; + +JsonObject CreateJsonObject(bool v) +{ + auto js = JsonObject{}; + js.Insert(L"value", JsonValue::CreateBooleanValue(v)); + return js; +} +JsonObject CreateJsonObject(std::wstring_view v) +{ + auto js = JsonObject{}; + js.Insert(L"value", JsonValue::CreateStringValue(v)); + return js; +} +template JsonObject CreateJsonObject(const wchar_t (&v)[N]) +{ + return CreateJsonObject(winrt::hstring(v)); +} +JsonObject CreateJsonObject(double v) +{ + auto js = JsonObject{}; + js.Insert(L"value", JsonValue::CreateNumberValue(v)); + return js; +} +JsonObject CreateJsonObject(nullptr_t) +{ + auto js = JsonObject{}; + js.Insert(L"value", JsonValue::CreateNullValue()); + return js; +} + +template winrt::array_view Serialize(const T &value) +{ + JsonObject json = CreateJsonObject(value); + auto str = json.Stringify(); + auto ibuffer = CryptographicBuffer::ConvertStringToBinary(str, 
BinaryStringEncoding::Utf16LE); + auto start = ibuffer.data(); + winrt::com_array _array(start, start + ibuffer.Length()); + return _array; +} + +winrt::hstring Deserialize(winrt::array_view v) +{ + std::wstring_view sv{reinterpret_cast(v.begin()), v.size() / sizeof(wchar_t)}; + std::wstring wstr{sv}; + JsonObject js; + if (JsonObject::TryParse(wstr, js)) + { + auto value = js.GetNamedValue(L"value"); + return value.Stringify(); + } + else + { + return winrt::hstring{wstr}; + } +} + +int main() +{ + SPTAG::AnnIndex idx; + using embedding_t = std::array; + auto b = CryptographicBuffer::ConvertStringToBinary(L"first one", BinaryStringEncoding::Utf16LE); + idx.AddWithMetadata(embedding_t{1, 0}, winrt::com_array(b.data(), b.data() + b.Length())); + idx.AddWithMetadata(embedding_t{0, 1, 0}, Serialize(L"second one")); + idx.AddWithMetadata(embedding_t{0, 0.5f, 0.7f, 0}, Serialize(3.14)); + idx.AddWithMetadata(embedding_t{0, 0.7f, 0.5f, 0}, Serialize(true)); + idx.AddWithMetadata(embedding_t{0, 0.7f, 0.8f, 0}, Serialize(L"fifth")); + + auto res = idx.Search(embedding_t{0.f, 0.99f, 0.01f}, 12); + for (const winrt::SPTAG::SearchResult &r : res) + { + std::wcout << Deserialize(r.Metadata()); + std::wcout << L" -- " << r.Distance() << L"\n"; + } + + auto folder = winrt::Windows::Storage::KnownFolders::DocumentsLibrary(); + auto file = + folder.CreateFileAsync(L"vector_index", winrt::Windows::Storage::CreationCollisionOption::ReplaceExisting) + .get(); + idx.Save(file); + + SPTAG::AnnIndex idx2; + idx2.Load(file); +} diff --git a/Test/WinRTTest/WinRTTest.vcxproj b/Test/WinRTTest/WinRTTest.vcxproj new file mode 100644 index 000000000..72457cda7 --- /dev/null +++ b/Test/WinRTTest/WinRTTest.vcxproj @@ -0,0 +1,166 @@ + + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 16.0 + Win32Proj + {c9df3099-a142-4aa7-b936-1541816a1f21} + WinRTTest + 10.0 + + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + 
+ Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + + + + + + + + + + + + + + + + + + + false + + + false + + + false + + + false + + + + Level3 + true + WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + Guard + ProgramDatabase + + + Console + + + + + Level3 + true + true + true + WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + Guard + + + Console + true + true + + + + + Level3 + true + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + Guard + ProgramDatabase + + + Console + + + + + Level3 + true + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + Guard + + + Console + true + true + + + + + + + + + + + {8dc74c33-6e15-43ed-9300-2a140589e3da} + + + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + \ No newline at end of file diff --git a/Test/WinRTTest/WinRTTest.vcxproj.filters b/Test/WinRTTest/WinRTTest.vcxproj.filters new file mode 100644 index 000000000..902db81f8 --- /dev/null +++ b/Test/WinRTTest/WinRTTest.vcxproj.filters @@ -0,0 +1,25 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Source Files + + + + + + \ No newline at end of file diff --git a/Test/WinRTTest/packages.config b/Test/WinRTTest/packages.config new file mode 100644 index 000000000..3375d5bb6 --- /dev/null +++ b/Test/WinRTTest/packages.config @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/Test/cuda/buildssd_test.cu b/Test/cuda/buildssd_test.cu new file mode 100644 index 000000000..4bdb5c97c --- /dev/null +++ 
b/Test/cuda/buildssd_test.cu @@ -0,0 +1,128 @@ +#include "common.hxx" +#include "inc/Core/Common/cuda/TailNeighbors.hxx" + +template +int GPUBuildSSDTest(int rows, int metric, int iters); + +int GPUBuildSSDTest_All() { + + int errors = 0; + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting Cosine BuildSSD tests\n"); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype...\n"); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int datatype...\n"); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype...\n"); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::Cosine, 10); + + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting L2 BuildSSD tests\n"); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype...\n"); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int datatype...\n"); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors 
+= GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype...\n"); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + errors += GPUBuildSSDTest(10000, (int)DistMetric::L2, 10); + + return errors; +} + +template +int GPUBuildSSDTest(int rows, int metric, int iters) { + int errors = 0; + int num_heads = rows/10; + + // Create random data for head vectors + T* head_data = create_dataset(num_heads, dim); + T* d_head_data; + CUDA_CHECK(cudaMalloc(&d_head_data, dim*num_heads*sizeof(T))); + CUDA_CHECK(cudaMemcpy(d_head_data, head_data, dim*num_heads*sizeof(T), cudaMemcpyHostToDevice)); + PointSet h_head_ps; + h_head_ps.dim = dim; + h_head_ps.data = d_head_data; + PointSet* d_head_ps; + CUDA_CHECK(cudaMalloc(&d_head_ps, sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_head_ps, &h_head_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + + // Create random data for tail vectors + T* tail_data = create_dataset(rows, dim); + T* d_tail_data; + CUDA_CHECK(cudaMalloc(&d_tail_data, dim*rows*sizeof(T))); + CUDA_CHECK(cudaMemcpy(d_tail_data, tail_data, dim*rows*sizeof(T), cudaMemcpyHostToDevice)); + PointSet h_tail_ps; + h_tail_ps.dim = dim; + h_tail_ps.data = d_tail_data; + PointSet* d_tail_ps; + CUDA_CHECK(cudaMalloc(&d_tail_ps, sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_tail_ps, &h_tail_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + + DistPair* d_results; + CUDA_CHECK(cudaMalloc(&d_results, rows*K*sizeof(DistPair))); + int TPTlevels = (int)std::log2(num_heads/100); + + TPtree* h_tree = new TPtree; + h_tree->initialize(num_heads, TPTlevels, dim); + TPtree* d_tree; + CUDA_CHECK(cudaMalloc(&d_tree, sizeof(TPtree))); + + // Alloc memory for QuerySet structure + int* d_queryMem; + CUDA_CHECK(cudaMalloc(&d_queryMem, rows*sizeof(int) + 
2*h_tree->num_leaves*sizeof(int))); + QueryGroup* d_queryGroups; + CUDA_CHECK(cudaMalloc(&d_queryGroups, sizeof(QueryGroup))); + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + for(int i=0; i(&h_tree, &d_head_ps, num_heads, TPTlevels, 1, &stream, 2, NULL); + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(d_tree, h_tree, sizeof(TPtree), cudaMemcpyHostToDevice)); + + get_query_groups(d_queryGroups, d_tree, d_tail_ps, rows, (int)h_tree->num_leaves, d_queryMem, 1024, 32, dim); + CUDA_CHECK(cudaDeviceSynchronize()); + + if(metric == (int)DistMetric::Cosine) { + findTailNeighbors_selector<<<1024, 32, sizeof(DistPair)*32*K>>>(d_head_ps, d_tail_ps, d_tree, K, d_results, rows, num_heads, d_queryGroups, dim); + } + else { + findTailNeighbors_selector<<<1024, 32, sizeof(DistPair)*32*K>>>(d_head_ps, d_tail_ps, d_tree, K, d_results, rows, num_heads, d_queryGroups, dim); + } + CUDA_CHECK(cudaDeviceSynchronize()); + } + + CUDA_CHECK(cudaFree(d_head_data)); + CUDA_CHECK(cudaFree(d_head_ps)); + CUDA_CHECK(cudaFree(d_tail_data)); + CUDA_CHECK(cudaFree(d_tail_ps)); + CUDA_CHECK(cudaFree(d_results)); + h_tree->destroy(); + CUDA_CHECK(cudaFree(d_tree)); + + return errors; +} + diff --git a/Test/cuda/common.hxx b/Test/cuda/common.hxx new file mode 100644 index 000000000..b0cee514e --- /dev/null +++ b/Test/cuda/common.hxx @@ -0,0 +1,59 @@ +//#include "inc/Core/Common/cuda/KNN.hxx" +#include +#include + +#define CHECK_ERRS(errs) \ + if(errs > 0) { \ + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "%d errors found\n", errs); \ + } + +#define CHECK_VAL(val,exp,errs) \ + if(val != exp) { \ + errs++; \ + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%s != %s\n",#val,#exp); \ + } + +#define CHECK_VAL_LT(val,exp,errs) \ + if(val > exp) { \ + errs++; \ + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%s > %s\n",#val,#exp); \ + } + +#define GPU_CHECK_VAL(val,exp,dtype,errs) \ + dtype temp; \ + CUDA_CHECK(cudaMemcpy(&temp, val, sizeof(dtype), 
cudaMemcpyDeviceToHost)); \ + float eps = 0.01; \ + if((float)temp>0.0 && ((float)temp*(1.0+eps) < (float)(exp) || (float)temp*(1.0-eps) > (float)(exp))) { \ + errs++; \ + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%s != %s\n",#val,#exp); \ + } + + +template +T* create_dataset(size_t rows, int dim) { + + srand(0); + T* h_data = new T[rows*dim]; + for(size_t i=0; i::value) { + h_data[i] = (rand()/(float)RAND_MAX); + } + else if(std::is_same::value) { + h_data[i] = static_cast((rand()%INT_MAX)); + } + else if(std::is_same::value) { + h_data[i] = static_cast((rand()%127)); + } + else if(std::is_same::value) { + h_data[i] = static_cast((rand()%127)); + } + } + return h_data; +} +/* +__global__ void count_leaf_sizes(LeafNode* leafs, int* node_ids, int N, int internal_nodes); +__global__ void assign_leaf_points_in_batch(LeafNode* leafs, int* leaf_points, int* node_ids, int N, int internal_nodes, int min_id, int max_id); +__global__ void assign_leaf_points_out_batch(LeafNode* leafs, int* leaf_points, int* node_ids, int N, int internal_nodes, int min_id, int max_id); +__global__ void compute_mean(KEYTYPE* split_keys, int* node_sizes, int num_nodes); +__global__ void initialize_rands(curandState* states, int iter); +*/ diff --git a/Test/cuda/cuda_tests.cpp b/Test/cuda/cuda_tests.cpp new file mode 100644 index 000000000..40c6cfd4c --- /dev/null +++ b/Test/cuda/cuda_tests.cpp @@ -0,0 +1,44 @@ +//#include "test_kernels.cu" + +#define BOOST_TEST_MODULE GPU + +#include +#include + +#include +#include +#include + +int GPUBuildKNNTest(); + +BOOST_AUTO_TEST_CASE(RandomTests) +{ + BOOST_CHECK(1 == 1); + + int errors = GPUBuildKNNTest(); + printf("outside\n"); + BOOST_CHECK(errors == 0); +} + +/* +int GPUTestDistance_All(); + +BOOST_AUTO_TEST_CASE(DistanceTests) { + int errs = GPUTestDistance_All(); + BOOST_CHECK(errs == 0); +} + +int GPUBuildTPTTest(); + +BOOST_AUTO_TEST_CASE(TPTreeTests) { + int errs = GPUBuildTPTTest(); + BOOST_CHECK(errs == 0); +} + +int 
GPUBuildSSDTest_All(); + +BOOST_AUTO_TEST_CASE(BuildSSDTests) { + int errs = GPUBuildSSDTest_All(); + BOOST_CHECK(errs == 0); +} +*/ diff --git a/Test/cuda/distance_tests.cu b/Test/cuda/distance_tests.cu new file mode 100644 index 000000000..2129272e6 --- /dev/null +++ b/Test/cuda/distance_tests.cu @@ -0,0 +1,202 @@ +#include "common.hxx" +#include "inc/Core/Common/cuda/KNN.hxx" + +#define GPU_CHECK_CORRECT(a,b,msg) \ + if(a != (SUMTYPE)b) { \ + printf(msg); \ + errs = 0; \ + return; \ + } + + +template +__global__ void GPUTestDistancesKernelStatic(PointSet* ps, int errs) { + SUMTYPE ab, ac, bc; +// Expected results + SUMTYPE l2_res[3] = {Dim, 4*Dim, Dim}; + SUMTYPE cosine_res[3] = {BASE(), BASE(), BASE()-2*Dim}; + + // l2 dist + ab = dist(ps->getVec(0), ps->getVec(1)); + ac = dist(ps->getVec(0), ps->getVec(2)); + bc = dist(ps->getVec(1), ps->getVec(2)); + GPU_CHECK_CORRECT(ab,l2_res[0],"Static L2 distance check failed\n"); + GPU_CHECK_CORRECT(ac,l2_res[1], "Static L2 distance check failed\n"); + GPU_CHECK_CORRECT(bc,l2_res[2], "Static L2 distance check failed\n"); + + // cosine dist + ab = dist(ps->getVec(0), ps->getVec(1)); + ac = dist(ps->getVec(0), ps->getVec(2)); + bc = dist(ps->getVec(1), ps->getVec(2)); + GPU_CHECK_CORRECT(ab,cosine_res[0],"Static Cosine distance check failed\n"); + GPU_CHECK_CORRECT(ac,cosine_res[1],"Static Cosine distance check failed\n"); + GPU_CHECK_CORRECT(bc,cosine_res[2],"Static Cosine distance check failed\n"); + +} + +template +int GPUTestDistancesSimple() { + T* h_data = new T[dim*3]; + for(int i=0; i<3; ++i) { + for(int j=0; j h_ps; + h_ps.dim = dim; + h_ps.data = d_data; + PointSet* d_ps; + CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + + + int errs=0; + GPUTestDistancesKernelStatic<<<1, 1>>>(d_ps, errs); // TODO - make sure result is correct returned + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaFree(d_data)); + 
CUDA_CHECK(cudaFree(d_ps)); + return errs; +} + + +template +__global__ void GPUTestDistancesRandomKernel(PointSet* ps, int vecs, SUMTYPE* cosine_dists, SUMTYPE* l2_dists, int errs) { + SUMTYPE cos, l2; + SUMTYPE diff; + float eps = 0.001; // exact distances are not expected, but should be within an epsilon + for(int i=blockIdx.x*blockDim.x + threadIdx.x; i(ps->getVec(i), ps->getVec(j)); + if( (((float)(cos - (SUMTYPE)cosine_dists[i*vecs+j]))/(float)cos) > eps) { + printf("Error, cosine calculations differ by too much i:%d, j:%d - GPU:%u, CPU:%u, diff:%f\n", i, j, cos, (SUMTYPE)cosine_dists[i*vecs+j], (((float)(cos - (SUMTYPE)cosine_dists[i*vecs+j]))/(float)cos)); +return; + } + + l2 = dist(ps->getVec(i), ps->getVec(j)); + if( (((float)(l2 - (SUMTYPE)l2_dists[i*vecs+j]))/(float)l2) > eps) { + printf("Error, l2 calculations differ by too much i:%d, j:%d - GPU:%d, CPU:%d\n", i, j, l2, (int)l2_dists[i*vecs+j]); +return; + } + } + } +} + +template +int GPUTestDistancesComplex(int vecs) { + + srand(time(NULL)); + T* h_data = new T[vecs*dim]; + for(int i=0; i::value) { + h_data[i*dim+j] = (rand()/(float)RAND_MAX); + } + else if(std::is_same::value) { + h_data[i*dim+j] = static_cast((rand()%INT_MAX)); + } + else if(std::is_same::value) { + h_data[i*dim+j] = static_cast((rand()%127)); + } + else if(std::is_same::value) { + h_data[i*dim+j] = static_cast((rand()%127)); + } + } + } + + // Compute CPU distances to verify with GPU metric + SUMTYPE* cpu_cosine_dists = new SUMTYPE[vecs*vecs]; + SUMTYPE* cpu_l2_dists = new SUMTYPE[vecs*vecs]; + for(int i=0; i(&h_data[i*dim], &h_data[j*dim], dim)); + cpu_l2_dists[i*vecs+j] = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeL2Distance(&h_data[i*dim], &h_data[j*dim], dim)); + } + } + int errs=0; + + T* d_data; + CUDA_CHECK(cudaMalloc(&d_data, dim*vecs*sizeof(T))); + CUDA_CHECK(cudaMemcpy(d_data, h_data, dim*vecs*sizeof(T), cudaMemcpyHostToDevice)); + + PointSet h_ps; + h_ps.dim = dim; + h_ps.data = d_data; + PointSet* d_ps; + 
CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + + + SUMTYPE* d_cosine_dists; + CUDA_CHECK(cudaMalloc(&d_cosine_dists, vecs*vecs*sizeof(SUMTYPE))); + CUDA_CHECK(cudaMemcpy(d_cosine_dists, cpu_cosine_dists, vecs*vecs*sizeof(SUMTYPE), cudaMemcpyHostToDevice)); + SUMTYPE* d_l2_dists; + CUDA_CHECK(cudaMalloc(&d_l2_dists, vecs*vecs*sizeof(SUMTYPE))); + CUDA_CHECK(cudaMemcpy(d_l2_dists, cpu_l2_dists, vecs*vecs*sizeof(SUMTYPE), cudaMemcpyHostToDevice)); + + + GPUTestDistancesRandomKernel<<<1024, 32>>>(d_ps, vecs, d_cosine_dists, d_l2_dists, errs); // TODO - make sure result is correct returned + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaFree(d_data)); + CUDA_CHECK(cudaFree(d_ps)); + CUDA_CHECK(cudaFree(d_cosine_dists)); + CUDA_CHECK(cudaFree(d_l2_dists)); + + return errs; +} + +int GPUTestDistance_All() { + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Static distance tests...\n"); + // Test Distances with float datatype + int errs = 0; + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + + // Test distances with int datatype + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + + // Test distances with int8 datatype + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + errs += GPUTestDistancesSimple(); + + CHECK_ERRS(errs) + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Randomized vector distance tests...\n"); + // Test distances between random vectors and compare with CPU calculation + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += 
GPUTestDistancesComplex(100); + + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + errs += GPUTestDistancesComplex(100); + + CHECK_ERRS(errs) + + return errs; +} diff --git a/Test/cuda/gpu_pq_perf.cu b/Test/cuda/gpu_pq_perf.cu new file mode 100644 index 000000000..1b96403d2 --- /dev/null +++ b/Test/cuda/gpu_pq_perf.cu @@ -0,0 +1,135 @@ + +#include "common.hxx" +#include "inc/Core/Common/cuda/KNN.hxx" + +template +__global__ void top1_nopq_kernel(T* data, int* res_idx, float* results, int datasize) { + float temp; + for(int i=blockIdx.x*blockDim.x + threadIdx.x; i(); + for(int j=0; j(&data[i*Dim], &data[j*Dim]); + if(temp < results[i]) { + results[i] = temp; + res_idx[i] = j; + } + } + } + } +} + +#define TEST_BLOCKS 512 +#define TEST_THREADS 32 + +template +void GPU_top1_nopq(std::shared_ptr& real_vecset, DistCalcMethod distMethod, int* res_idx, float* res_dist) { + + R* d_data; + CUDA_CHECK(cudaMalloc(&d_data, real_vecset->Count()*real_vecset->Dimension()*sizeof(R))); + CUDA_CHECK(cudaMemcpy(d_data, reinterpret_cast(real_vecset->GetVector(0)), real_vecset->Count()*real_vecset->Dimension()*sizeof(R), cudaMemcpyHostToDevice)); + +// float* nearest = new float[real_vecset->Count()]; + float* d_res_dist; + CUDA_CHECK(cudaMalloc(&d_res_dist, real_vecset->Count()*sizeof(float))); + int* d_res_idx; + CUDA_CHECK(cudaMalloc(&d_res_idx, real_vecset->Count()*sizeof(int))); + + // Run kernel that performs + // Create options for different dims and metrics + if(real_vecset->Dimension() == 256) { + top1_nopq_kernel<<>>(d_data, d_res_idx, d_res_dist, real_vecset->Count()); + } + else { + printf("Add support for testing with %d dimensions\n", real_vecset->Dimension()); + } + CUDA_CHECK(cudaDeviceSynchronize()); 
+ + CUDA_CHECK(cudaMemcpy(res_idx, d_res_idx, real_vecset->Count()*sizeof(int), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(res_dist, d_res_dist, real_vecset->Count()*sizeof(float), cudaMemcpyDeviceToHost)); + // Optional check results for debugging + +} + +void GPU_nopq_alltype(std::shared_ptr& real_vecset, DistCalcMethod distMethod, int* res_idx, float* res_dist) { + GPU_top1_nopq(real_vecset, distMethod, res_idx, res_dist); +} + + +template +__global__ void top1_pq_kernel(uint8_t* data, int* res_idx, float* res_dist, int datasize, GPU_Quantizer* quantizer) { + + if(threadIdx.x==0 && blockIdx.x==0) { + printf("Quantizer numSub:%d, KsPerSub:%ld, BlockSize:%ld, dimPerSub:%ld\n", quantizer->m_NumSubvectors, quantizer->m_KsPerSubvector, quantizer->m_BlockSize, quantizer->m_DimPerSubvector); + printf("dataSize:%d, data: %u, %u\n", datasize, data[0] & 255, data[1] & 255); + } + float temp; + for(int i=blockIdx.x*blockDim.x + threadIdx.x; i(); + for(int j=0; jdist(&data[i*Dim], &data[j*Dim]); + if(temp < res_dist[i]) { + res_dist[i] = temp; + res_idx[i] = j; + if(i==0) printf("j:%d, dist:%f\n", j, temp); + } + } + } + } +} + +template +void GPU_top1_pq(std::shared_ptr& real_vecset, std::shared_ptr& quan_vecset,DistCalcMethod distMethod, std::shared_ptr& quantizer, int* res_idx, float* res_dist) { + + printf("Running GPU PQ - PQ dims:%d\n", quan_vecset->Dimension()); + + std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType()); + vecIndex->SetQuantizer(quantizer); + vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod)); + + GPU_Quantizer* d_quantizer = NULL; + GPU_Quantizer* h_quantizer = NULL; + +printf("QTYPE:%d\n", (int)(quantizer->GetQuantizerType())); + +printf("Creating GPU_Quantizer\n"); + h_quantizer = new GPU_Quantizer(quantizer, DistMetric::L2); // TODO - add other metric option + CUDA_CHECK(cudaMalloc(&d_quantizer, sizeof(GPU_Quantizer))); + 
CUDA_CHECK(cudaMemcpy(d_quantizer, h_quantizer, sizeof(GPU_Quantizer), cudaMemcpyHostToDevice)); + + uint8_t* d_data; + CUDA_CHECK(cudaMalloc(&d_data, quan_vecset->Count()*quan_vecset->Dimension()*sizeof(uint8_t))); + CUDA_CHECK(cudaMemcpy(d_data, reinterpret_cast(quan_vecset->GetVector(0)), quan_vecset->Count()*quan_vecset->Dimension()*sizeof(uint8_t), cudaMemcpyHostToDevice)); + +// float* nearest = new float[quan_vecset->Count()]; +// float* d_nearest; +// CUDA_CHECK(cudaMalloc(&d_nearest, quan_vecset->Count()*sizeof(float))); + + float* d_res_dist; + CUDA_CHECK(cudaMalloc(&d_res_dist, quan_vecset->Count()*sizeof(float))); + int* d_res_idx; + CUDA_CHECK(cudaMalloc(&d_res_idx, quan_vecset->Count()*sizeof(int))); + + // Run kernel that performs + // Create options for different dims and metrics + if(quan_vecset->Dimension() == 128) { + top1_pq_kernel<<>>(d_data, d_res_idx, d_res_dist, quan_vecset->Count(), d_quantizer); + } + else { + printf("Add support for testing with %d PQ dimensions\n", quan_vecset->Dimension()); + } + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(res_idx, d_res_idx, quan_vecset->Count()*sizeof(int), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(res_dist, d_res_dist, quan_vecset->Count()*sizeof(float), cudaMemcpyDeviceToHost)); + // Optional check results for debugging + +} + + +void GPU_pq_alltype(std::shared_ptr& real_vecset, std::shared_ptr& quan_vecset, DistCalcMethod distMethod, std::shared_ptr& quantizer, int* res_idx, float* res_dist) { + GPU_top1_pq(real_vecset, quan_vecset, distMethod, quantizer, res_idx, res_dist); + +} + diff --git a/Test/cuda/knn_tests.cu b/Test/cuda/knn_tests.cu new file mode 100644 index 000000000..f191c7aa2 --- /dev/null +++ b/Test/cuda/knn_tests.cu @@ -0,0 +1,255 @@ +#include "common.hxx" +#include "inc/Core/Common/cuda/KNN.hxx" + +template +__global__ void test_KNN(PointSet* ps, int* results, int rows, int K) { + + extern __shared__ char sharememory[]; + DistPair* threadList = 
(&((DistPair*)sharememory)[K*threadIdx.x]); + + T query[Dim]; + T candidate_vec[Dim]; + + DistPair target; + DistPair candidate; + DistPair temp; + + bool good; + SUMTYPE max_dist = INFTY(); + int read_id, write_id; + + SUMTYPE (*dist_comp)(T*,T*) = &dist; + + for(size_t i=blockIdx.x*blockDim.x + threadIdx.x; igetVec(i)[j]; + } + for(int k=0; k(); + } + + for(size_t j=0; jgetVec(j)); + + if(max_dist > candidate.dist) { + for(read_id=0; candidate.dist > threadList[read_id].dist && good; read_id++) { + if(violatesRNG(candidate_vec, ps->getVec(threadList[read_id].idx), candidate.dist, dist_comp)) { + good = false; + } + } + if(good) { + target = threadList[read_id]; + threadList[read_id] = candidate; + read_id++; + for(write_id = read_id; read_id < K && threadList[read_id].idx != -1; read_id++) { + if(!violatesRNG(ps->getVec(threadList[read_id].idx), candidate_vec, threadList[read_id].dist, dist_comp)) { + if(read_id == write_id) { + temp = threadList[read_id]; + threadList[write_id] = target; + target = temp; + } + else { + threadList[write_id] = target; + target = threadList[read_id]; + } + write_id++; + } + } + if(write_id < K) { + threadList[write_id] = target; + write_id++; + } + for(int k=write_id; k(); + threadList[k].idx = -1; + } + max_dist = threadList[K-1].dist; + } + + } + + } + for(size_t j=0; j +int GPUBuildKNNCosineTest(int rows) { + + T* data = create_dataset(rows, dim); + T* d_data; + CUDA_CHECK(cudaMalloc(&d_data, dim*rows*sizeof(T))); + CUDA_CHECK(cudaMemcpy(d_data, data, dim*rows*sizeof(T), cudaMemcpyHostToDevice)); + + int* d_results; + CUDA_CHECK(cudaMalloc(&d_results, rows*K*sizeof(int))); + + PointSet h_ps; + h_ps.dim = dim; + h_ps.data = d_data; + + PointSet* d_ps; + + CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + + test_KNN<<<1024, 64, K*64*sizeof(DistPair)>>>(d_ps, d_results, rows, K); + CUDA_CHECK(cudaDeviceSynchronize()); + + int* h_results = new 
int[K*rows]; + CUDA_CHECK(cudaMemcpy(h_results, d_results, rows*K*sizeof(int), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaDeviceSynchronize()); + +// Verify that the neighbor list of each vector is ordered correctly + for(int i=0; i(&data[i*dim], &data[neighborId*dim], dim)); + SUMTYPE nextDist = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeCosineDistance(&data[i*dim], &data[nextNeighborId*dim], dim)); + if(neighborDist > nextDist) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Neighbor list not in ascending distance order. i:%d, neighbor:%d (dist:%f), next:%d (dist:%f)\n", i, neighborId, neighborDist, nextNeighborId, nextDist); + return 1; + } + } + } + } + + CUDA_CHECK(cudaFree(d_data)); + CUDA_CHECK(cudaFree(d_results)); + CUDA_CHECK(cudaFree(d_ps)); + + return 0; +} + +int GPUBuildKNNCosineTest() { + + int errors = 0; + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype tests...\n"); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int32 datatype tests...\n"); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype tests...\n"); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + errors += GPUBuildKNNCosineTest(1000); + + return errors; +} + +template +int GPUBuildKNNL2Test(int rows) { + + T* data = create_dataset(rows, dim); + T* d_data; + CUDA_CHECK(cudaMalloc(&d_data, dim*rows*sizeof(T))); + CUDA_CHECK(cudaMemcpy(d_data, data, dim*rows*sizeof(T), cudaMemcpyHostToDevice)); + + int* d_results; + CUDA_CHECK(cudaMalloc(&d_results, rows*K*sizeof(int))); 
+ + PointSet h_ps; + h_ps.dim = dim; + h_ps.data = d_data; + + PointSet* d_ps; + + CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + + test_KNN<<<1024, 64, K*64*sizeof(DistPair)>>>(d_ps, d_results, rows, K); + CUDA_CHECK(cudaDeviceSynchronize()); + + int* h_results = new int[K*rows]; + CUDA_CHECK(cudaMemcpy(h_results, d_results, rows*K*sizeof(int), cudaMemcpyDeviceToHost)); + + CUDA_CHECK(cudaDeviceSynchronize()); + +// Verify that the neighbor list of each vector is ordered correctly + for(int i=0; i(&data[i*dim], &data[neighborId*dim], dim)); + SUMTYPE nextDist = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeL2Distance(&data[i*dim], &data[nextNeighborId*dim], dim)); + if(neighborDist > nextDist) { + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Neighbor list not in ascending distance order. i:%d, neighbor:%d (dist:%f), next:%d (dist:%f)\n", i, neighborId, neighborDist, nextNeighborId, nextDist); + return 1; + } + } + } + } + CUDA_CHECK(cudaFree(d_data)); + CUDA_CHECK(cudaFree(d_results)); + CUDA_CHECK(cudaFree(d_ps)); + + return 0; +} + +int GPUBuildKNNL2Test() { + + int errors = 0; + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype tests...\n"); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + + CHECK_ERRS(errors) + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int32 datatype tests...\n"); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + + CHECK_ERRS(errors) + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype tests...\n"); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + errors += GPUBuildKNNL2Test(1000); + + 
CHECK_ERRS(errors) +} + +int GPUBuildKNNTest() { + int errors = 0; + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting KNN Cosine metric tests\n"); + errors += GPUBuildKNNCosineTest(); + + CHECK_ERRS(errors) + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting KNN L2 metric tests\n"); + errors += GPUBuildKNNL2Test(); + + CHECK_ERRS(errors) + + return errors; +} + +int GPUBuildTPTreeTest(); diff --git a/Test/cuda/pq_perf.cpp b/Test/cuda/pq_perf.cpp new file mode 100644 index 000000000..c256f3b12 --- /dev/null +++ b/Test/cuda/pq_perf.cpp @@ -0,0 +1,672 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//#include "../inc/Test.h" + +#define BOOST_TEST_MODULE GPU + +#include +#include +#include +#include + +#include + +#include "inc/Core/Common/cuda/params.h" + +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/Common/DistanceUtils.h" +#include "inc/Core/Common/PQQuantizer.h" +#include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/VectorSetReader.h" +#include +#include +#include +#include +#include + +using namespace SPTAG; + +/* +template +void Search(std::shared_ptr& vecIndex, std::shared_ptr& queryset, int k, +std::shared_ptr& truth) +{ + std::vector> res(queryset->Count(), SPTAG::COMMON::QueryResultSet(nullptr, k * +2)); auto t1 = std::chrono::high_resolution_clock::now(); for (SizeType i = 0; i < queryset->Count(); i++) + { + res[i].Reset(); + res[i].SetTarget((const T*)queryset->GetVector(i), vecIndex->m_pQuantizer); + vecIndex->SearchIndex(res[i]); + } + auto t2 = std::chrono::high_resolution_clock::now(); + std::cout << "Search time: " << (std::chrono::duration_cast(t2 - t1).count() / +(float)(queryset->Count())) << "us" << std::endl; + + float eps = 1e-6f, recall = 0; + + int truthDimension = min(k, truth->Dimension()); + for (SizeType i = 0; i < queryset->Count(); i++) { + SizeType* nn = (SizeType*)(truth->GetVector(i)); + std::vector 
visited(2 * k, false); + for (int j = 0; j < truthDimension; j++) { + float truthdist = vecIndex->ComputeDistance(res[i].GetQuantizedTarget(), vecIndex->GetSample(nn[j])); + for (int l = 0; l < k*2; l++) { + if (visited[l]) continue; + + //std::cout << res[i].GetResult(l)->Dist << " " << truthdist << std::endl; + if (res[i].GetResult(l)->VID == nn[j]) { + recall += 1.0; + visited[l] = true; + break; + } + else if (fabs(res[i].GetResult(l)->Dist - truthdist) <= eps) { + recall += 1.0; + visited[l] = true; + break; + } + } + } + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall %d@%d: %f\n", truthDimension, k*2, recall / queryset->Count() / +truthDimension); +} +*/ + +/* +template +std::shared_ptr PerfBuild(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, +std::shared_ptr& meta, std::shared_ptr& queryset, int k, std::shared_ptr& truth, +std::string out, std::shared_ptr quantizer) +{ + std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + vecIndex->SetQuantizer(quantizer); + BOOST_CHECK(nullptr != vecIndex); + + if (algo == IndexAlgoType::KDT) vecIndex->SetParameter("KDTNumber", "2"); + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + vecIndex->SetParameter("NumberOfThreads", "12"); + vecIndex->SetParameter("RefineIterations", "3"); + vecIndex->SetParameter("MaxCheck", "4096"); + vecIndex->SetParameter("MaxCheckForRefineGraph", "8192"); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta, true)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); + //Search(vecIndex, queryset, k, truth); + return vecIndex; +} +*/ + +template +void GenerateReconstructData(std::shared_ptr &real_vecset, std::shared_ptr &rec_vecset, + std::shared_ptr &quan_vecset, std::shared_ptr &metaset, + std::shared_ptr &queryset, std::shared_ptr &truth, + DistCalcMethod distCalcMethod, int k, std::shared_ptr &quantizer) +{ + std::random_device rd; + std::mt19937 gen(rd()); + 
std::uniform_real_distribution dist(-1000, 1000); + int n = 20000, q = 200; + int m = 256; + int M = 128; + int Ks = 256; + int QuanDim = m / M; + std::string CODEBOOK_FILE = "quantest_quantizer.bin"; + + if (fileexists("quantest_vector.bin") && fileexists("quantest_query.bin")) + { + std::shared_ptr options( + new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(options); + if (ErrorCode::Success != vectorReader->LoadFile("quantest_vector.bin")) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + exit(1); + } + real_vecset = vectorReader->GetVectorSet(); + + if (ErrorCode::Success != vectorReader->LoadFile("quantest_query.bin")) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + exit(1); + } + queryset = vectorReader->GetVectorSet(); + } + else + { + ByteArray real_vec = ByteArray::Alloc(sizeof(R) * n * m); + for (int i = 0; i < n * m; i++) + { + ((R *)real_vec.Data())[i] = (R)(dist(gen)); + } + real_vecset.reset(new BasicVectorSet(real_vec, GetEnumValueType(), m, n)); + // real_vecset->Save("quantest_vector.bin"); + + ByteArray real_query = ByteArray::Alloc(sizeof(R) * q * m); + for (int i = 0; i < q * m; i++) + { + ((R *)real_query.Data())[i] = (R)(dist(gen)); + } + queryset.reset(new BasicVectorSet(real_query, GetEnumValueType(), m, q)); + // queryset->Save("quantest_query.bin"); + } + + if (fileexists(("quantest_truth." + SPTAG::Helper::Convert::ConvertToString(distCalcMethod)).c_str())) + { + std::shared_ptr options( + new Helper::ReaderOptions(GetEnumValueType(), k, VectorFileType::DEFAULT)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(options); + if (ErrorCode::Success != + vectorReader->LoadFile("quantest_truth." 
+ SPTAG::Helper::Convert::ConvertToString(distCalcMethod))) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read truth file.\n"); + exit(1); + } + truth = vectorReader->GetVectorSet(); + } + else + { + omp_set_num_threads(5); + + ByteArray tru = ByteArray::Alloc(sizeof(SizeType) * queryset->Count() * k); + +#pragma omp parallel for + for (SizeType i = 0; i < queryset->Count(); ++i) + { + SizeType *neighbors = ((SizeType *)tru.Data()) + i * k; + + COMMON::QueryResultSet res((const R *)queryset->GetVector(i), k); + for (SizeType j = 0; j < real_vecset->Count(); j++) + { + float dist = COMMON::DistanceUtils::ComputeDistance(res.GetTarget(), + reinterpret_cast(real_vecset->GetVector(j)), + queryset->Dimension(), distCalcMethod); + res.AddPoint(j, dist); + } + res.SortResult(); + for (int j = 0; j < k; j++) + neighbors[j] = res.GetResult(j)->VID; + } + truth.reset(new BasicVectorSet(tru, GetEnumValueType(), k, queryset->Count())); + // truth->Save("quantest_truth." + SPTAG::Helper::Convert::ConvertToString(distCalcMethod)); + } + + if (fileexists(CODEBOOK_FILE.c_str()) && fileexists("quantest_quan_vector.bin") && + fileexists("quantest_rec_vector.bin")) + { + auto ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) + { + BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" 
== "Error"); + } + quantizer->LoadIQuantizer(ptr); + BOOST_ASSERT(quantizer); + + std::shared_ptr options( + new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); + auto vectorReader = Helper::VectorSetReader::CreateInstance(options); + if (ErrorCode::Success != vectorReader->LoadFile("quantest_rec_vector.bin")) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + exit(1); + } + rec_vecset = vectorReader->GetVectorSet(); + + std::shared_ptr quanOptions( + new Helper::ReaderOptions(GetEnumValueType(), M, VectorFileType::DEFAULT)); + vectorReader = Helper::VectorSetReader::CreateInstance(quanOptions); + if (ErrorCode::Success != vectorReader->LoadFile("quantest_quan_vector.bin")) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + exit(1); + } + quan_vecset = vectorReader->GetVectorSet(); + } + else + { + omp_set_num_threads(16); + + std::cout << "Building codebooks!" << std::endl; + R *vecs = (R *)(real_vecset->GetData()); + + std::unique_ptr codebooks = std::make_unique(M * Ks * QuanDim); + std::unique_ptr belong(new int[n]); + for (int i = 0; i < M; i++) + { + R *kmeans = codebooks.get() + i * Ks * QuanDim; + for (int j = 0; j < Ks; j++) + { + std::memcpy(kmeans + j * QuanDim, vecs + j * m + i * QuanDim, sizeof(R) * QuanDim); + } + int cnt = 100; + while (cnt--) + { + // calculate cluster +#pragma omp parallel for + for (int ii = 0; ii < n; ii++) + { + double min_dis = 1e9; + int min_id = 0; + for (int jj = 0; jj < Ks; jj++) + { + double now_dis = COMMON::DistanceUtils::ComputeDistance( + vecs + ii * m + i * QuanDim, kmeans + jj * QuanDim, QuanDim, DistCalcMethod::L2); + if (now_dis < min_dis) + { + min_dis = now_dis; + min_id = jj; + } + } + belong[ii] = min_id; + } + // recalculate kmeans + std::memset(kmeans, 0, sizeof(R) * Ks * QuanDim); +#pragma omp parallel for + for (int ii = 0; ii < Ks; ii++) + { + int num = 0; + for (int jj = 0; jj < n; jj++) + { + if (belong[jj] == ii) + { 
+ num++; + for (int kk = 0; kk < QuanDim; kk++) + { + kmeans[ii * QuanDim + kk] += vecs[jj * m + i * QuanDim + kk]; + } + } + } + for (int jj = 0; jj < QuanDim; jj++) + { + kmeans[ii * QuanDim + jj] /= num; + } + } + } + } + + std::cout << "Building Finish!" << std::endl; + quantizer = std::make_shared>(M, Ks, QuanDim, false, std::move(codebooks)); + + printf("After built, pq type:%d\n", (int)quantizer->GetQuantizerType()); + auto ptr = SPTAG::f_createIO(); + if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::out)) + { + BOOST_ASSERT("Canot Open CODEBOOK_FILE to write!" == "Error"); + } + // quantizer->SaveQuantizer(ptr); + ptr->ShutDown(); + + if (!ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) + { + BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error"); + } + quantizer->LoadIQuantizer(ptr); + BOOST_ASSERT(quantizer); + + rec_vecset.reset(new BasicVectorSet(ByteArray::Alloc(sizeof(R) * n * m), GetEnumValueType(), m, n)); + quan_vecset.reset( + new BasicVectorSet(ByteArray::Alloc(sizeof(std::uint8_t) * n * M), GetEnumValueType(), M, n)); + for (int i = 0; i < n; i++) + { + auto nvec = &vecs[i * m]; + quantizer->QuantizeVector(nvec, (uint8_t *)quan_vecset->GetVector(i)); + quantizer->ReconstructVector((uint8_t *)quan_vecset->GetVector(i), rec_vecset->GetVector(i)); + } + // quan_vecset->Save("quantest_quan_vector.bin"); + // rec_vecset->Save("quantest_rec_vector.bin"); + } +} + +void GPU_nopq_alltype(std::shared_ptr &real_vecset, DistCalcMethod distMethod, int *res_idx, + float *res_dist); + +void GPU_pq_alltype(std::shared_ptr &real_vecset, std::shared_ptr &quan_vecset, + DistCalcMethod distMethod, std::shared_ptr &quantizer, int *res_idx, + float *res_dist); + +bool DEBUG_REPORT = false; + +template +void CPU_top1_nopq(std::shared_ptr &real_vecset, DistCalcMethod distMethod, int numThreads, int *result, + float *res_dist, bool randomized) +{ + + float testDist; + // float* nearest = new 
float[real_vecset->Count()]; + + printf("CPU top1 nopq - real dim:%d\n", real_vecset->Dimension()); + + if (randomized) + { + srand(time(NULL)); +#pragma omp parallel for num_threads(numThreads) + for (int i = 0; i < real_vecset->Count(); ++i) + { + int idx; + res_dist[i] = std::numeric_limits::max(); + + for (int j = 0; j < real_vecset->Count(); ++j) + { + idx = rand() % real_vecset->Count(); + if (i == idx) + continue; + + testDist = COMMON::DistanceUtils::ComputeDistance(reinterpret_cast(real_vecset->GetVector(i)), + reinterpret_cast(real_vecset->GetVector(idx)), + real_vecset->Dimension(), distMethod); + if (testDist < res_dist[i]) + { + res_dist[i] = testDist; + result[i] = idx; + } + } + } + } + else + { +#pragma omp parallel for num_threads(numThreads) + for (int i = 0; i < real_vecset->Count(); ++i) + { + res_dist[i] = std::numeric_limits::max(); + + for (int j = 0; j < real_vecset->Count(); ++j) + { + if (i == j) + continue; + + testDist = COMMON::DistanceUtils::ComputeDistance(reinterpret_cast(real_vecset->GetVector(i)), + reinterpret_cast(real_vecset->GetVector(j)), + real_vecset->Dimension(), distMethod); + if (testDist < res_dist[i]) + { + res_dist[i] = testDist; + result[i] = j; + } + } + } + } + + /* + if(DEBUG_REPORT) { + for(int i=0; iCount(); ++i) { + printf("%f\n", nearest[i]); + } + } + */ +} + +template +void CPU_top1_pq(std::shared_ptr &real_vecset, std::shared_ptr &quan_vecset, + DistCalcMethod distMethod, std::shared_ptr &quantizer, int numThreads, + int *res_idx, float *res_dist, bool randomized) +{ + + float testDist; + // float* nearest = new float[real_vecset->Count()]; + + printf("CPU top1 PQ - PQ dim:%d, count:%d\n", quan_vecset->Dimension(), quan_vecset->Count()); + + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType()); + vecIndex->SetQuantizer(quantizer); + vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod)); + + if (randomized) + { + 
int idx; +#pragma omp parallel for num_threads(numThreads) + for (int i = 0; i < quan_vecset->Count(); ++i) + { + res_dist[i] = std::numeric_limits::max(); + + for (int j = 0; j < quan_vecset->Count(); ++j) + { + idx = rand() % (quan_vecset->Count()); + if (i == idx) + continue; + testDist = vecIndex->ComputeDistance(quan_vecset->GetVector(i), quan_vecset->GetVector(idx)); + if (testDist < res_dist[i]) + { + res_dist[i] = testDist; + res_idx[i] = idx; + } + } + } + } + else + { +#pragma omp parallel for num_threads(numThreads) + for (int i = 0; i < quan_vecset->Count(); ++i) + { + res_dist[i] = std::numeric_limits::max(); + + for (int j = 0; j < quan_vecset->Count(); ++j) + { + if (i == j) + continue; + testDist = vecIndex->ComputeDistance(quan_vecset->GetVector(i), quan_vecset->GetVector(j)); + if (testDist < res_dist[i]) + { + res_dist[i] = testDist; + res_idx[i] = j; + } + } + } + } +} + +#define EPS 0.01 + +void verify_results(int *gt_idx, float *gt_dist, int *res_idx, float *res_dist, int N) +{ + // Verify both CPU baselines have same result + for (int i = 0; i < N; ++i) + { + if (gt_idx[i] != res_idx[i] || gt_dist[i] > res_dist[i] * (1 + EPS)) + { + printf("mismatch! gt:%d (%f), result:%d (%f)\n", gt_idx[i], gt_dist[i], res_idx[i], res_dist[i]); + } + } +} + +void compute_accuracy(int *gt_idx, float *gt_dist, int *res_idx, float *res_dist, int N) +{ + float matches = 0.0; + float dist_sum = 0.0; + float total_dists = 0.0; + for (int i = 0; i < N; ++i) + { + if (gt_idx[i] == res_idx[i]) + { + matches++; + } + total_dists += gt_dist[i]; + dist_sum += abs(gt_dist[i] - res_dist[i]); + } + printf("KNN accuracy:%0.3f, avg. 
distance error:%4.3e\n", matches / (float)N, dist_sum / total_dists); +} + +template void DistancePerfSuite(IndexAlgoType algo, DistCalcMethod distMethod) +{ + std::shared_ptr real_vecset, rec_vecset, quan_vecset, queryset, truth; + std::shared_ptr metaset; + std::shared_ptr quantizer; + GenerateReconstructData(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10, + quantizer); + + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType()); + vecIndex->SetQuantizer(quantizer); + vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod)); + + // printf("Truth dimension:%d, count:%d\n", truth->Dimension(), truth->Count()); + + printf("numSubvec:%d\n", quantizer->GetNumSubvectors()); + printf("quantizer type:%d\n", (int)(quantizer->GetQuantizerType())); + + int *gt_idx = new int[real_vecset->Count()]; + float *gt_dist = new float[real_vecset->Count()]; + + int *res_idx = new int[real_vecset->Count()]; + float *res_dist = new float[real_vecset->Count()]; + + // BASELINE non-PQ perf timing: + // CPU distance comparison timing with non-quantized + auto start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_nopq(real_vecset, distMethod, 1, gt_idx, gt_dist, false); + auto end_t = std::chrono::high_resolution_clock::now(); + double CPU_baseline_t = GET_CHRONO_TIME(start_t, end_t); + start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_nopq(real_vecset, distMethod, 16, res_idx, res_dist, false); + end_t = std::chrono::high_resolution_clock::now(); + double CPU_parallel_t = GET_CHRONO_TIME(start_t, end_t); + + printf("Verifying correctness of CPU non-pq baseline results...\n"); + verify_results(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + // GPU distance comparisons using non-quantized + start_t = std::chrono::high_resolution_clock::now(); + GPU_nopq_alltype(real_vecset, distMethod, res_idx, res_dist); + end_t = 
std::chrono::high_resolution_clock::now(); + double GPU_baseline_t = GET_CHRONO_TIME(start_t, end_t); + + printf("Verifying correctness of GPU non-pq baseline results...\n"); + verify_results(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + // Time CPU all-to-all distance comparisons between quan_vecset + start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_pq(real_vecset, quan_vecset, distMethod, quantizer, 1, res_idx, res_dist, false); + end_t = std::chrono::high_resolution_clock::now(); + double CPU_PQ_t = GET_CHRONO_TIME(start_t, end_t); + printf("CPU PQ single-threaded\n"); + compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_pq(real_vecset, quan_vecset, distMethod, quantizer, 16, res_idx, res_dist, false); + end_t = std::chrono::high_resolution_clock::now(); + double CPU_PQ_parallel_t = GET_CHRONO_TIME(start_t, end_t); + printf("CPU PQ 16 threads\n"); + compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + // Time each GPU method of all-to-all distance between quan_vecset + start_t = std::chrono::high_resolution_clock::now(); + GPU_pq_alltype(real_vecset, quan_vecset, distMethod, quantizer, res_idx, res_dist); + end_t = std::chrono::high_resolution_clock::now(); + double GPU_PQ_t = GET_CHRONO_TIME(start_t, end_t); + printf("GPU PQ threads\n"); + compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 1-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_baseline_t, + CPU_PQ_t); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 16-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_parallel_t, + CPU_PQ_parallel_t); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU time - baseline:%0.3lf, PQ:%0.3lf\n", GPU_baseline_t, GPU_PQ_t); + + // LoadReconstructData(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10); + + /* + 
auto real_idx = PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), real_vecset, + metaset, queryset, 10, truth, "real_idx", nullptr); Search(real_idx, queryset, 10, truth); auto rec_idx = + PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), rec_vecset, metaset, queryset, + 10, truth, "rec_idx", nullptr); Search(rec_idx, queryset, 10, truth); auto quan_idx = + PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), quan_vecset, metaset, + queryset, 10, truth, "quan_idx", quantizer); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Test search with SDC"); + Search(quan_idx, queryset, 10, truth); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Test search with ADC"); + quan_idx->SetQuantizerADC(true); + Search(quan_idx, queryset, 10, truth); + */ +} + +template void DistancePerfRandomized(IndexAlgoType algo, DistCalcMethod distMethod) +{ + std::shared_ptr real_vecset, rec_vecset, quan_vecset, queryset, truth; + std::shared_ptr metaset; + std::shared_ptr quantizer; + GenerateReconstructData(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10, + quantizer); + + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType()); + vecIndex->SetQuantizer(quantizer); + vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod)); + + // printf("Truth dimension:%d, count:%d\n", truth->Dimension(), truth->Count()); + + printf("numSubvec:%d\n", quantizer->GetNumSubvectors()); + printf("quantizer type:%d\n", (int)(quantizer->GetQuantizerType())); + + int *gt_idx = new int[real_vecset->Count()]; + float *gt_dist = new float[real_vecset->Count()]; + + int *res_idx = new int[real_vecset->Count()]; + float *res_dist = new float[real_vecset->Count()]; + + // BASELINE non-PQ perf timing: + // CPU distance comparison timing with non-quantized + auto start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_nopq(real_vecset, distMethod, 1, gt_idx, gt_dist, 
true); + auto end_t = std::chrono::high_resolution_clock::now(); + double CPU_baseline_t = GET_CHRONO_TIME(start_t, end_t); + start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_nopq(real_vecset, distMethod, 16, res_idx, res_dist, true); + end_t = std::chrono::high_resolution_clock::now(); + double CPU_parallel_t = GET_CHRONO_TIME(start_t, end_t); + + printf("Randomized accuracy between CPU runs...\n"); + compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + // GPU distance comparisons using non-quantized + start_t = std::chrono::high_resolution_clock::now(); + GPU_nopq_alltype(real_vecset, distMethod, res_idx, res_dist); + end_t = std::chrono::high_resolution_clock::now(); + double GPU_baseline_t = GET_CHRONO_TIME(start_t, end_t); + + printf("Accuracy of randomized GPU non-pq run\n"); + compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + // Time CPU all-to-all distance comparisons between quan_vecset + start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_pq(real_vecset, quan_vecset, distMethod, quantizer, 1, res_idx, res_dist, true); + end_t = std::chrono::high_resolution_clock::now(); + double CPU_PQ_t = GET_CHRONO_TIME(start_t, end_t); + printf("CPU PQ single-threaded\n"); + + start_t = std::chrono::high_resolution_clock::now(); + CPU_top1_pq(real_vecset, quan_vecset, distMethod, quantizer, 16, res_idx, res_dist, true); + end_t = std::chrono::high_resolution_clock::now(); + double CPU_PQ_parallel_t = GET_CHRONO_TIME(start_t, end_t); + printf("CPU PQ 16 threads\n"); + compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + // Time each GPU method of all-to-all distance between quan_vecset + start_t = std::chrono::high_resolution_clock::now(); + GPU_pq_alltype(real_vecset, quan_vecset, distMethod, quantizer, res_idx, res_dist); + end_t = std::chrono::high_resolution_clock::now(); + double GPU_PQ_t = GET_CHRONO_TIME(start_t, end_t); + printf("GPU PQ threads\n"); + 
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count()); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 1-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_baseline_t, + CPU_PQ_t); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 16-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_parallel_t, + CPU_PQ_parallel_t); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU time - baseline:%0.3lf, PQ:%0.3lf\n", GPU_baseline_t, GPU_PQ_t); +} + +BOOST_AUTO_TEST_SUITE(GPUPQPerfTest) + +BOOST_AUTO_TEST_CASE(GPUPQCosineTest) +{ + + DistancePerfSuite(IndexAlgoType::BKT, DistCalcMethod::L2); + + DistancePerfRandomized(IndexAlgoType::BKT, DistCalcMethod::L2); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/cuda/tptree_tests.cu b/Test/cuda/tptree_tests.cu new file mode 100644 index 000000000..fadbca8f5 --- /dev/null +++ b/Test/cuda/tptree_tests.cu @@ -0,0 +1,127 @@ +#include "common.hxx" +#include "inc/Core/Common/cuda/TPtree.hxx" + +template +int TPTKernelsTest(int rows) { + + + int errs = 0; + + T* data = create_dataset(rows, dim); + T* d_data; + CUDA_CHECK(cudaMalloc(&d_data, dim*rows*sizeof(T))); + CUDA_CHECK(cudaMemcpy(d_data, data, dim*rows*sizeof(T), cudaMemcpyHostToDevice)); + + PointSet h_ps; + h_ps.dim = dim; + h_ps.data = d_data; + + PointSet* d_ps; + + CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet))); + CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet), cudaMemcpyHostToDevice)); + + int levels = (int)std::log2(rows/100); // TPT levels + TPtree* tptree = new TPtree; + tptree->initialize(rows, levels, dim); + + // Check that tptree structure properly initialized + CHECK_VAL(tptree->Dim,dim,errs) + CHECK_VAL(tptree->levels,levels,errs) + CHECK_VAL(tptree->N,rows,errs) + CHECK_VAL(tptree->num_leaves,pow(2,levels),errs) + + // Create TPT structure and random weights + KEYTYPE* h_weights = new KEYTYPE[tptree->levels*tptree->Dim]; + for(int i=0; ilevels*tptree->Dim; ++i) { + h_weights[i] = ((rand()%2)*2)-1; + } + + tptree->reset(); + 
CUDA_CHECK(cudaMemcpy(tptree->weight_list, h_weights, tptree->levels*tptree->Dim*sizeof(KEYTYPE), cudaMemcpyHostToDevice)); + + + curandState* states; + CUDA_CHECK(cudaMalloc(&states, 1024*32*sizeof(curandState))); + initialize_rands<<<1024,32>>>(states, 0); + + + int nodes_on_level=1; + for(int i=0; ilevels; ++i) { + + find_level_sum<<<1024,32>>>(d_ps, tptree->weight_list, dim, tptree->node_ids, tptree->split_keys, tptree->node_sizes, rows, nodes_on_level, i, rows); + CUDA_CHECK(cudaDeviceSynchronize()); + + float* split_key_sum = new float[nodes_on_level]; + CUDA_CHECK(cudaMemcpy(split_key_sum, &tptree->split_keys[nodes_on_level-1], nodes_on_level*sizeof(float), cudaMemcpyDeviceToHost)); // Copy the sum to compare with mean computed later + int* node_sizes = new int[nodes_on_level]; + CUDA_CHECK(cudaMemcpy(node_sizes, &tptree->node_sizes[nodes_on_level-1], nodes_on_level*sizeof(int), cudaMemcpyDeviceToHost)); + + compute_mean<<<1024, 32>>>(tptree->split_keys, tptree->node_sizes, tptree->num_nodes); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check the mean values for each node + for(int j=0; jsplit_keys[nodes_on_level-1+j],split_key_sum[j]/(float)(node_sizes[j]),float,errs) + } + + update_node_assignments<<<1024, 32>>>(d_ps, tptree->weight_list, tptree->node_ids, tptree->split_keys, tptree->node_sizes, rows, i, dim); + CUDA_CHECK(cudaDeviceSynchronize()); + + nodes_on_level *= 2; + } + count_leaf_sizes<<<1024, 32>>>(tptree->leafs, tptree->node_ids, rows, tptree->num_nodes - tptree->num_leaves); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check that total leaf node sizes equals total vectors + int total_leaf_sizes=0; + for(int j=0; jnum_leaves; ++j) { + LeafNode temp_leaf; + CUDA_CHECK(cudaMemcpy(&temp_leaf, &tptree->leafs[j], sizeof(LeafNode), cudaMemcpyDeviceToHost)); + total_leaf_sizes+=temp_leaf.size; + } + CHECK_VAL_LT(total_leaf_sizes,rows,errs) + + assign_leaf_points_out_batch<<<1024, 32>>>(tptree->leafs, tptree->leaf_points, tptree->node_ids, rows, 
tptree->num_nodes - tptree->num_leaves, 0, rows); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check that points were correctly assigned to leaf nodes + int* h_leaf_points = new int[rows]; + CUDA_CHECK(cudaMemcpy(h_leaf_points, tptree->leaf_points, rows*sizeof(int), cudaMemcpyDeviceToHost)); + for(int j=0; jnum_leaves,errs) + } + + CUDA_CHECK(cudaFree(d_data)); + CUDA_CHECK(cudaFree(d_ps)); + CUDA_CHECK(cudaFree(states)); + tptree->destroy(); + + return errs; +} + +int GPUBuildTPTTest() { + + int errors = 0; + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting TPTree Kernel tests\n"); + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype...\n"); + errors += TPTKernelsTest(1000); + errors += TPTKernelsTest(1000); + errors += TPTKernelsTest(1000); + errors += TPTKernelsTest(1000); + +// SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "int32 datatype...\n"); +// errors += TPTKernelsTest(1000); +// errors += TPTKernelsTest(1000); +// errors += TPTKernelsTest(1000); +// errors += TPTKernelsTest(1000); + + SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "int8 datatype...\n"); + errors += TPTKernelsTest(1000); + errors += TPTKernelsTest(1000); + errors += TPTKernelsTest(1000); + errors += TPTKernelsTest(1000); + + return errors; +} diff --git a/Test/inc/TestDataGenerator.h b/Test/inc/TestDataGenerator.h new file mode 100644 index 000000000..883df381c --- /dev/null +++ b/Test/inc/TestDataGenerator.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include "inc/Core/VectorIndex.h" +#include "inc/Core/SPANN/Index.h" +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/Common/DistanceUtils.h" +#include "inc/Helper/VectorSetReader.h" +#include "inc/Helper/DiskIO.h" + +namespace TestUtils { + + template + class TestDataGenerator { + public: + TestDataGenerator(int n, int q, int m, int k, std::string distMethod); + + void Run(std::shared_ptr& vecset, + std::shared_ptr& metaset, + std::shared_ptr& 
queryset, + std::shared_ptr& truth, + std::shared_ptr& addvecset, + std::shared_ptr& addmetaset, + std::shared_ptr& addtruth); + + static std::shared_ptr GenerateRandomVectorSet(SPTAG::SizeType count, SPTAG::DimensionType dim); + + static std::shared_ptr GenerateMetadataSet(SPTAG::SizeType count, SPTAG::SizeType offsetStart); + + void RunBatches(std::shared_ptr &vecset, std::shared_ptr &metaset, + std::shared_ptr &addvecset, std::shared_ptr &addmetaset, + std::shared_ptr &queryset, int base, int batchinsert, int batchdelete, int batches, + std::shared_ptr &truths); + private: + int m_n, m_q, m_m, m_k; + std::string m_distMethod; + + bool FileExists(const std::string& filename); + + std::shared_ptr LoadReader(const std::string& filename); + + void LoadOrGenerateBase(std::shared_ptr& vecset, std::shared_ptr& metaset); + + void LoadOrGenerateQuery(std::shared_ptr& queryset); + + void LoadOrGenerateAdd(std::shared_ptr& addvecset, std::shared_ptr& addmetaset); + + void LoadOrGenerateTruth(const std::string& filename, + std::shared_ptr vecset, + std::shared_ptr queryset, + std::shared_ptr& truth, + bool normalize); + + void LoadOrGenerateBatchTruth(const std::string &filename, std::shared_ptr vecset, + std::shared_ptr queryset, + std::shared_ptr &truths, int base, int batchinsert, + int batchdelete, int batches, bool normalize); + + std::shared_ptr CombineVectorSets(std::shared_ptr base, + std::shared_ptr additional); + }; + +} // namespace TestUtils \ No newline at end of file diff --git a/Test/packages.config b/Test/packages.config index 8adb5ddbb..3c07fed8b 100644 --- a/Test/packages.config +++ b/Test/packages.config @@ -1,13 +1,12 @@  - - - - - - - - - - + + + + + + + + + \ No newline at end of file diff --git a/Test/src/AlgoTest.cpp b/Test/src/AlgoTest.cpp index a656c04da..6adcea67c 100644 --- a/Test/src/AlgoTest.cpp +++ b/Test/src/AlgoTest.cpp @@ -1,26 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "inc/Test.h" -#include "inc/Helper/SimpleIniReader.h" -#include "inc/Core/VectorIndex.h" #include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Test.h" -#include #include +#include template -void Build(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +void Build(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, const std::string out) { - std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); BOOST_CHECK(nullptr != vecIndex); - if (algo != SPTAG::IndexAlgoType::SPANN) { + if (algo != SPTAG::IndexAlgoType::SPANN) + { vecIndex->SetParameter("DistCalcMethod", distCalcMethod); vecIndex->SetParameter("NumberOfThreads", "16"); } - else { + else + { vecIndex->SetParameter("IndexAlgoType", "BKT", "Base"); vecIndex->SetParameter("DistCalcMethod", distCalcMethod, "Base"); @@ -46,17 +50,21 @@ void Build(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_pt } template -void BuildWithMetaMapping(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +void BuildWithMetaMapping(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, const std::string out) { - std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); BOOST_CHECK(nullptr != vecIndex); - if (algo != SPTAG::IndexAlgoType::SPANN) { + if (algo != SPTAG::IndexAlgoType::SPANN) + { vecIndex->SetParameter("DistCalcMethod", distCalcMethod); vecIndex->SetParameter("NumberOfThreads", "16"); } - else { + else + { 
vecIndex->SetParameter("IndexAlgoType", "BKT", "Base"); vecIndex->SetParameter("DistCalcMethod", distCalcMethod, "Base"); @@ -81,22 +89,22 @@ void BuildWithMetaMapping(SPTAG::IndexAlgoType algo, std::string distCalcMethod, BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); } -template -void Search(const std::string folder, T* vec, SPTAG::SizeType n, int k, std::string* truthmeta) +template void Search(const std::string folder, T *vec, SPTAG::SizeType n, int k, std::string *truthmeta) { std::shared_ptr vecIndex; BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); BOOST_CHECK(nullptr != vecIndex); - for (SPTAG::SizeType i = 0; i < n; i++) + for (SPTAG::SizeType i = 0; i < n; i++) { SPTAG::QueryResult res(vec, k, true); vecIndex->SearchIndex(res); std::unordered_set resmeta; for (int j = 0; j < k; j++) { - resmeta.insert(std::string((char*)res.GetMetadata(j).Data(), res.GetMetadata(j).Length())); - std::cout << res.GetResult(j)->Dist << "@(" << res.GetResult(j)->VID << "," << std::string((char*)res.GetMetadata(j).Data(), res.GetMetadata(j).Length()) << ") "; + resmeta.insert(std::string((char *)res.GetMetadata(j).Data(), res.GetMetadata(j).Length())); + std::cout << res.GetResult(j)->Dist << "@(" << res.GetResult(j)->VID << "," + << std::string((char *)res.GetMetadata(j).Data(), res.GetMetadata(j).Length()) << ") "; } std::cout << std::endl; for (int j = 0; j < k; j++) @@ -109,7 +117,8 @@ void Search(const std::string folder, T* vec, SPTAG::SizeType n, int k, std::str } template -void Add(const std::string folder, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +void Add(const std::string folder, std::shared_ptr &vec, std::shared_ptr &meta, + const std::string out) { std::shared_ptr vecIndex; BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); @@ -121,65 +130,75 @@ void Add(const std::string folder, std::shared_ptr& vec, std:: } template -void 
AddOneByOne(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +void AddOneByOne(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, const std::string out) { - std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); BOOST_CHECK(nullptr != vecIndex); vecIndex->SetParameter("DistCalcMethod", distCalcMethod); vecIndex->SetParameter("NumberOfThreads", "16"); - + auto t1 = std::chrono::high_resolution_clock::now(); - for (SPTAG::SizeType i = 0; i < vec->Count(); i++) { + for (SPTAG::SizeType i = 0; i < vec->Count(); i++) + { SPTAG::ByteArray metaarr = meta->GetMetadata(i); - std::uint64_t offset[2] = { 0, metaarr.Length() }; - std::shared_ptr metaset(new SPTAG::MemMetadataSet(metaarr, SPTAG::ByteArray((std::uint8_t*)offset, 2 * sizeof(std::uint64_t), false), 1)); + std::uint64_t offset[2] = {0, metaarr.Length()}; + std::shared_ptr metaset(new SPTAG::MemMetadataSet( + metaarr, SPTAG::ByteArray((std::uint8_t *)offset, 2 * sizeof(std::uint64_t), false), 1)); SPTAG::ErrorCode ret = vecIndex->AddIndex(vec->GetVector(i), 1, vec->Dimension(), metaset, true); - if (SPTAG::ErrorCode::Success != ret) std::cerr << "Error AddIndex(" << (int)(ret) << ") for vector " << i << std::endl; + if (SPTAG::ErrorCode::Success != ret) + std::cerr << "Error AddIndex(" << (int)(ret) << ") for vector " << i << std::endl; } auto t2 = std::chrono::high_resolution_clock::now(); - std::cout << "AddIndex time: " << (std::chrono::duration_cast(t2 - t1).count() / (float)(vec->Count())) << "us" << std::endl; - + std::cout << "AddIndex time: " + << (std::chrono::duration_cast(t2 - t1).count() / (float)(vec->Count())) + << "us" << std::endl; + Sleep(10000); BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); } -template -void 
Delete(const std::string folder, T* vec, SPTAG::SizeType n, const std::string out) +template void Delete(const std::string folder, T *vec, SPTAG::SizeType n, const std::string out) { std::shared_ptr vecIndex; BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); BOOST_CHECK(nullptr != vecIndex); - BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->DeleteIndex((const void*)vec, n)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->DeleteIndex((const void *)vec, n)); BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); vecIndex.reset(); } -template -void Test(SPTAG::IndexAlgoType algo, std::string distCalcMethod) +template void Test(SPTAG::IndexAlgoType algo, std::string distCalcMethod) { SPTAG::SizeType n = 2000, q = 3; SPTAG::DimensionType m = 10; int k = 3; std::vector vec; - for (SPTAG::SizeType i = 0; i < n; i++) { - for (SPTAG::DimensionType j = 0; j < m; j++) { + for (SPTAG::SizeType i = 0; i < n; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { vec.push_back((T)i); } } - + std::vector query; - for (SPTAG::SizeType i = 0; i < q; i++) { - for (SPTAG::DimensionType j = 0; j < m; j++) { - query.push_back((T)i*2); + for (SPTAG::SizeType i = 0; i < q; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { + query.push_back((T)i * 2); } } - + std::vector meta; std::vector metaoffset; - for (SPTAG::SizeType i = 0; i < n; i++) { + for (SPTAG::SizeType i = 0; i < n; i++) + { metaoffset.push_back((std::uint64_t)meta.size()); std::string a = std::to_string(i); for (size_t j = 0; j < a.length(); j++) @@ -188,44 +207,44 @@ void Test(SPTAG::IndexAlgoType algo, std::string distCalcMethod) metaoffset.push_back((std::uint64_t)meta.size()); std::shared_ptr vecset(new SPTAG::BasicVectorSet( - SPTAG::ByteArray((std::uint8_t*)vec.data(), sizeof(T) * n * m, false), - SPTAG::GetEnumValueType(), m, n)); + SPTAG::ByteArray((std::uint8_t *)vec.data(), sizeof(T) * n * m, false), SPTAG::GetEnumValueType(), m, n)); 
std::shared_ptr metaset(new SPTAG::MemMetadataSet( - SPTAG::ByteArray((std::uint8_t*)meta.data(), meta.size() * sizeof(char), false), - SPTAG::ByteArray((std::uint8_t*)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), - n)); - + SPTAG::ByteArray((std::uint8_t *)meta.data(), meta.size() * sizeof(char), false), + SPTAG::ByteArray((std::uint8_t *)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), n)); + Build(algo, distCalcMethod, vecset, metaset, "testindices"); - std::string truthmeta1[] = { "0", "1", "2", "2", "1", "3", "4", "3", "5" }; + std::string truthmeta1[] = {"0", "1", "2", "2", "1", "3", "4", "3", "5"}; Search("testindices", query.data(), q, k, truthmeta1); - if (algo != SPTAG::IndexAlgoType::SPANN) { + if (algo != SPTAG::IndexAlgoType::SPANN) + { Add("testindices", vecset, metaset, "testindices"); - std::string truthmeta2[] = { "0", "0", "1", "2", "2", "1", "4", "4", "3" }; + std::string truthmeta2[] = {"0", "0", "1", "2", "2", "1", "4", "4", "3"}; Search("testindices", query.data(), q, k, truthmeta2); Delete("testindices", query.data(), q, "testindices"); - std::string truthmeta3[] = { "1", "1", "3", "1", "3", "1", "3", "5", "3" }; + std::string truthmeta3[] = {"1", "1", "3", "1", "3", "1", "3", "5", "3"}; Search("testindices", query.data(), q, k, truthmeta3); } BuildWithMetaMapping(algo, distCalcMethod, vecset, metaset, "testindices"); - std::string truthmeta4[] = { "0", "1", "2", "2", "1", "3", "4", "3", "5" }; + std::string truthmeta4[] = {"0", "1", "2", "2", "1", "3", "4", "3", "5"}; Search("testindices", query.data(), q, k, truthmeta4); - if (algo != SPTAG::IndexAlgoType::SPANN) { + if (algo != SPTAG::IndexAlgoType::SPANN) + { Add("testindices", vecset, metaset, "testindices"); - std::string truthmeta5[] = { "0", "1", "2", "2", "1", "3", "4", "3", "5" }; + std::string truthmeta5[] = {"0", "1", "2", "2", "1", "3", "4", "3", "5"}; Search("testindices", query.data(), q, k, truthmeta5); AddOneByOne(algo, 
distCalcMethod, vecset, metaset, "testindices"); - std::string truthmeta6[] = { "0", "1", "2", "2", "1", "3", "4", "3", "5" }; - Search("testindices", query.data(), q, k, truthmeta6); + std::string truthmeta6[] = {"0", "1", "2", "2", "1", "3", "4", "3", "5"}; + Search("testindices", query.data(), q, k, truthmeta6); } } -BOOST_AUTO_TEST_SUITE (AlgoTest) +BOOST_AUTO_TEST_SUITE(AlgoTest) BOOST_AUTO_TEST_CASE(KDTTest) { diff --git a/Test/src/Base64HelperTest.cpp b/Test/src/Base64HelperTest.cpp index 2ead4753e..1878ed011 100644 --- a/Test/src/Base64HelperTest.cpp +++ b/Test/src/Base64HelperTest.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "inc/Test.h" #include "inc/Helper/Base64Encode.h" +#include "inc/Test.h" #include diff --git a/Test/src/CommonHelperTest.cpp b/Test/src/CommonHelperTest.cpp index 17015642c..281bf0521 100644 --- a/Test/src/CommonHelperTest.cpp +++ b/Test/src/CommonHelperTest.cpp @@ -1,18 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "inc/Test.h" #include "inc/Helper/CommonHelper.h" +#include "inc/Test.h" #include BOOST_AUTO_TEST_SUITE(CommonHelperTest) - BOOST_AUTO_TEST_CASE(ToLowerInPlaceTest) { - auto runTestCase = [](std::string p_input, const std::string& p_expected) - { + auto runTestCase = [](std::string p_input, const std::string &p_expected) { SPTAG::Helper::StrUtils::ToLowerInPlace(p_input); BOOST_CHECK(p_input == p_expected); }; @@ -24,12 +22,11 @@ BOOST_AUTO_TEST_CASE(ToLowerInPlaceTest) runTestCase("123!-=aBc", "123!-=abc"); } - BOOST_AUTO_TEST_CASE(SplitStringTest) { std::string input("seg1 seg2 seg3 seg4"); - const auto& segs = SPTAG::Helper::StrUtils::SplitString(input, " "); + const auto &segs = SPTAG::Helper::StrUtils::SplitString(input, " "); BOOST_CHECK(segs.size() == 4); BOOST_CHECK(segs[0] == "seg1"); BOOST_CHECK(segs[1] == "seg2"); @@ -37,24 +34,18 @@ BOOST_AUTO_TEST_CASE(SplitStringTest) BOOST_CHECK(segs[3] == "seg4"); } - BOOST_AUTO_TEST_CASE(FindTrimmedSegmentTest) { using namespace SPTAG::Helper::StrUtils; std::string input("\t Space End \r\n\t"); - const auto& pos = FindTrimmedSegment(input.c_str(), - input.c_str() + input.size(), - [](char p_val)->bool - { - return std::isspace(p_val) > 0; - }); + const auto &pos = FindTrimmedSegment(input.c_str(), input.c_str() + input.size(), + [](char p_val) -> bool { return std::isspace(p_val) > 0; }); BOOST_CHECK(pos.first == input.c_str() + 2); BOOST_CHECK(pos.second == input.c_str() + 13); } - BOOST_AUTO_TEST_CASE(StartsWithTest) { using namespace SPTAG::Helper::StrUtils; @@ -72,7 +63,6 @@ BOOST_AUTO_TEST_CASE(StartsWithTest) BOOST_CHECK(!StartsWith("Abcd", "Abcde")); } - BOOST_AUTO_TEST_CASE(StrEqualIgnoreCaseTest) { using namespace SPTAG::Helper::StrUtils; @@ -88,6 +78,4 @@ BOOST_AUTO_TEST_CASE(StrEqualIgnoreCaseTest) BOOST_CHECK(!StrEqualIgnoreCase("000", "OOO")); } - - BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/Test/src/ConcurrentTest.cpp b/Test/src/ConcurrentTest.cpp index 
0e742e96f..b137037d5 100644 --- a/Test/src/ConcurrentTest.cpp +++ b/Test/src/ConcurrentTest.cpp @@ -1,19 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "inc/Test.h" -#include "inc/Helper/SimpleIniReader.h" -#include "inc/Core/VectorIndex.h" #include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Test.h" +#include #include #include -#include template -void ConcurrentAddSearchSave(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, const std::string out) +void ConcurrentAddSearchSave(SPTAG::IndexAlgoType algo, std::string distCalcMethod, + std::shared_ptr &vec, std::shared_ptr &meta, + const std::string out) { - std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); BOOST_CHECK(nullptr != vecIndex); vecIndex->SetParameter("DistCalcMethod", distCalcMethod); @@ -28,10 +31,12 @@ void ConcurrentAddSearchSave(SPTAG::IndexAlgoType algo, std::string distCalcMeth while (!stop) { SPTAG::ByteArray metaarr = meta->GetMetadata(i); - std::uint64_t offset[2] = { 0, metaarr.Length() }; - std::shared_ptr metaset(new SPTAG::MemMetadataSet(metaarr, SPTAG::ByteArray((std::uint8_t*)offset, 2 * sizeof(std::uint64_t), false), 1)); + std::uint64_t offset[2] = {0, metaarr.Length()}; + std::shared_ptr metaset(new SPTAG::MemMetadataSet( + metaarr, SPTAG::ByteArray((std::uint8_t *)offset, 2 * sizeof(std::uint64_t), false), 1)); SPTAG::ErrorCode ret = vecIndex->AddIndex(vec->GetVector(i), 1, vec->Dimension(), metaset, true); - if (SPTAG::ErrorCode::Success != ret) std::cerr << "Error AddIndex(" << (int)(ret) << ") for vector " << i << std::endl; + if (SPTAG::ErrorCode::Success != ret) + std::cerr << "Error AddIndex(" << (int)(ret) << ") for vector " << i << std::endl; 
i = (i + 1) % vec->Count(); std::this_thread::sleep_for(std::chrono::milliseconds(10)); } @@ -51,7 +56,8 @@ void ConcurrentAddSearchSave(SPTAG::IndexAlgoType algo, std::string distCalcMeth }; auto SearchThread = [&stop, &vecIndex, &vec]() { - while (!stop) { + while (!stop) + { SPTAG::QueryResult res(vec->GetVector(SPTAG::COMMON::Utils::rand(vec->Count())), 5, true); vecIndex->SearchIndex(res); std::this_thread::sleep_for(std::chrono::milliseconds(10)); @@ -77,32 +83,39 @@ void ConcurrentAddSearchSave(SPTAG::IndexAlgoType algo, std::string distCalcMeth std::this_thread::sleep_for(std::chrono::seconds(30)); stop = true; - for (auto& thread : threads) { thread.join(); } + for (auto &thread : threads) + { + thread.join(); + } std::cout << "Main Thread quit!" << std::endl; } -template -void CTest(SPTAG::IndexAlgoType algo, std::string distCalcMethod) +template void CTest(SPTAG::IndexAlgoType algo, std::string distCalcMethod) { SPTAG::SizeType n = 2000, q = 3; SPTAG::DimensionType m = 10; std::vector vec; - for (SPTAG::SizeType i = 0; i < n; i++) { - for (SPTAG::DimensionType j = 0; j < m; j++) { + for (SPTAG::SizeType i = 0; i < n; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { vec.push_back((T)i); } } std::vector query; - for (SPTAG::SizeType i = 0; i < q; i++) { - for (SPTAG::DimensionType j = 0; j < m; j++) { + for (SPTAG::SizeType i = 0; i < q; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { query.push_back((T)i * 2); } } std::vector meta; std::vector metaoffset; - for (SPTAG::SizeType i = 0; i < n; i++) { + for (SPTAG::SizeType i = 0; i < n; i++) + { metaoffset.push_back((std::uint64_t)meta.size()); std::string a = std::to_string(i); for (size_t j = 0; j < a.length(); j++) @@ -111,13 +124,11 @@ void CTest(SPTAG::IndexAlgoType algo, std::string distCalcMethod) metaoffset.push_back((std::uint64_t)meta.size()); std::shared_ptr vecset(new SPTAG::BasicVectorSet( - SPTAG::ByteArray((std::uint8_t*)vec.data(), sizeof(T) * n * m, false), - 
SPTAG::GetEnumValueType(), m, n)); + SPTAG::ByteArray((std::uint8_t *)vec.data(), sizeof(T) * n * m, false), SPTAG::GetEnumValueType(), m, n)); std::shared_ptr metaset(new SPTAG::MemMetadataSet( - SPTAG::ByteArray((std::uint8_t*)meta.data(), meta.size() * sizeof(char), false), - SPTAG::ByteArray((std::uint8_t*)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), - n)); + SPTAG::ByteArray((std::uint8_t *)meta.data(), meta.size() * sizeof(char), false), + SPTAG::ByteArray((std::uint8_t *)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), n)); ConcurrentAddSearchSave(algo, distCalcMethod, vecset, metaset, "testindices"); } diff --git a/Test/src/DistanceTest.cpp b/Test/src/DistanceTest.cpp index 872fb01fe..d95dd46dc 100644 --- a/Test/src/DistanceTest.cpp +++ b/Test/src/DistanceTest.cpp @@ -1,60 +1,63 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "inc/Core/Common/DistanceUtils.h" +#include "inc/Test.h" #include #include #include #include -#include "inc/Test.h" -#include "inc/Core/Common/DistanceUtils.h" -template -static float ComputeCosineDistance(const T *pX, const T *pY, SPTAG::DimensionType length) { +template static float ComputeCosineDistance(const T *pX, const T *pY, SPTAG::DimensionType length) +{ float diff = 0; - const T* pEnd1 = pX + length; - while (pX < pEnd1) diff += (*pX++) * (*pY++); + const T *pEnd1 = pX + length; + while (pX < pEnd1) + diff += (*pX++) * (*pY++); return diff; } -template -static float ComputeL2Distance(const T *pX, const T *pY, SPTAG::DimensionType length) +template static float ComputeL2Distance(const T *pX, const T *pY, SPTAG::DimensionType length) { float diff = 0; - const T* pEnd1 = pX + length; - while (pX < pEnd1) { - float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1; + const T *pEnd1 = pX + length; + while (pX < pEnd1) + { + float c1 = ((float)(*pX++) - (float)(*pY++)); + diff += c1 * c1; } return diff; } -template -T 
random(int high = RAND_MAX, int low = 0) // Generates a random value. +template T random(int high = RAND_MAX, int low = 0) // Generates a random value. { - return (T)(low + float(high - low)*(std::rand()/static_cast(RAND_MAX + 1.0))); + return (T)(low + float(high - low) * (std::rand() / static_cast(RAND_MAX + 1.0))); } -template -void test(int high) { +template void test(int high) +{ SPTAG::DimensionType dimension = random(256, 2); T *X = new T[dimension], *Y = new T[dimension]; BOOST_ASSERT(X != nullptr && Y != nullptr); - for (SPTAG::DimensionType i = 0; i < dimension; i++) { + for (SPTAG::DimensionType i = 0; i < dimension; i++) + { X[i] = random(high, -high); Y[i] = random(high, -high); } - BOOST_CHECK_CLOSE_FRACTION(ComputeL2Distance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeDistance(X, Y, dimension, SPTAG::DistCalcMethod::L2), 1e-5); - BOOST_CHECK_CLOSE_FRACTION(high * high - ComputeCosineDistance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeDistance(X, Y, dimension, SPTAG::DistCalcMethod::Cosine), 1e-5); + BOOST_CHECK_CLOSE_FRACTION( + ComputeL2Distance(X, Y, dimension), + SPTAG::COMMON::DistanceUtils::ComputeDistance(X, Y, dimension, SPTAG::DistCalcMethod::L2), 1e-5); + BOOST_CHECK_CLOSE_FRACTION( + high * high - ComputeCosineDistance(X, Y, dimension), + SPTAG::COMMON::DistanceUtils::ComputeDistance(X, Y, dimension, SPTAG::DistCalcMethod::Cosine), 1e-5); delete[] X; delete[] Y; } template -void test_dist_calc_performance( - int high, - SPTAG::DimensionType dimension = 256, - SPTAG::SizeType size = 100, - SPTAG::DistCalcMethod calc_method = SPTAG::DistCalcMethod::L2) +void test_dist_calc_performance(int high, SPTAG::DimensionType dimension = 256, SPTAG::SizeType size = 100, + SPTAG::DistCalcMethod calc_method = SPTAG::DistCalcMethod::L2, int threads = 1) { T **X = new T *[size]; T **Y = new T *[size]; @@ -69,15 +72,38 @@ void test_dist_calc_performance( } } - double start, end; - start = omp_get_wtime(); -#pragma omp parallel for - 
for (SPTAG::SizeType i = 0; i < size; i++) + auto start = std::chrono::steady_clock::now(); + std::vector mythreads; + mythreads.reserve(threads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < threads; tid++) { - SPTAG::COMMON::DistanceUtils::ComputeDistance(X[i], Y[i], dimension, calc_method); + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < size) + { + SPTAG::COMMON::DistanceUtils::ComputeDistance(X[i], Y[i], dimension, calc_method); + } + else + { + return; + } + } + }); } - end = omp_get_wtime(); - std::cout << "Time to calculate distance (ms): " << (end - start) * 1000 << std::endl; + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + + auto end = std::chrono::steady_clock::now(); + std::cout << "Time to calculate distance (ms): " + << (std::chrono::duration_cast(end - start).count() * 1.0) / 1000.0 + << std::endl; delete[] X; delete[] Y; @@ -102,14 +128,13 @@ BOOST_AUTO_TEST_CASE(TestDistanceComputationPerformance) for (int num_threads : nums_threads) { std::cout << "num_thread: " << num_threads << std::endl; - omp_set_num_threads(num_threads); for (SPTAG::DistCalcMethod calc_method : calc_methods) { std::cout << "calc_method: " << (calc_method == SPTAG::DistCalcMethod::L2 ? "L2" : "Cosine") << std::endl; for (auto dimension : dimensions) { std::cout << "type: int8, dimension: " << dimension << ", size: " << size << std::endl; - test_dist_calc_performance(127, dimension, size, calc_method); + test_dist_calc_performance(127, dimension, size, calc_method, num_threads); } } } diff --git a/Test/src/FilterTest.cpp b/Test/src/FilterTest.cpp new file mode 100644 index 000000000..285df5b42 --- /dev/null +++ b/Test/src/FilterTest.cpp @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Test.h" + +#include +#include + +template +void BuildWithMetaMapping(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, const std::string out) +{ + + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + BOOST_CHECK(nullptr != vecIndex); + + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + vecIndex->SetParameter("NumberOfThreads", "16"); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta, true)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); +} + +template void SearchWithFilter(const std::string folder, T *vec, SPTAG::SizeType n, int k) +{ + std::cout << "start search with filter" << std::endl; + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + std::string value = "2"; + std::function filterFunction = [value](const SPTAG::ByteArray &meta) -> bool { + std::string metaValue((char *)meta.Data(), meta.Length()); + std::cout << metaValue << std::endl; + return metaValue != value; + }; + for (SPTAG::SizeType i = 0; i < n; i++) + { + SPTAG::QueryResult res(vec, k, true); + vecIndex->SearchIndexWithFilter(res, filterFunction); + std::unordered_set resmeta; + for (int j = 0; j < k; j++) + { + resmeta.insert(std::string((char *)res.GetMetadata(j).Data(), res.GetMetadata(j).Length())); + std::cout << res.GetResult(j)->Dist << "@(" << res.GetResult(j)->VID << "," + << std::string((char *)res.GetMetadata(j).Data(), res.GetMetadata(j).Length()) << ") "; + } + std::cout << std::endl; + for (int j = 0; j < k; j++) + { + BOOST_CHECK(resmeta.find("2") == resmeta.end()); + } + vec += vecIndex->GetFeatureDim(); + } + vecIndex.reset(); +} + +template void FTest(SPTAG::IndexAlgoType algo, 
std::string distCalcMethod) +{ + SPTAG::SizeType n = 2000, q = 3; + SPTAG::DimensionType m = 10; + int k = 3; + std::vector vec; + for (SPTAG::SizeType i = 0; i < n; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { + vec.push_back((T)i); + } + } + + std::vector query; + for (SPTAG::SizeType i = 0; i < q; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { + query.push_back((T)i * 2); + } + } + + std::vector meta; + std::vector metaoffset; + for (SPTAG::SizeType i = 0; i < n; i++) + { + metaoffset.push_back((std::uint64_t)meta.size()); + std::string a = std::to_string(i); + for (size_t j = 0; j < a.length(); j++) + meta.push_back(a[j]); + } + metaoffset.push_back((std::uint64_t)meta.size()); + + std::shared_ptr vecset(new SPTAG::BasicVectorSet( + SPTAG::ByteArray((std::uint8_t *)vec.data(), sizeof(T) * n * m, false), SPTAG::GetEnumValueType(), m, n)); + + std::shared_ptr metaset(new SPTAG::MemMetadataSet( + SPTAG::ByteArray((std::uint8_t *)meta.data(), meta.size() * sizeof(char), false), + SPTAG::ByteArray((std::uint8_t *)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), n)); + + BuildWithMetaMapping(algo, distCalcMethod, vecset, metaset, "testindices"); + + SearchWithFilter("testindices", query.data(), q, k); +} + +BOOST_AUTO_TEST_SUITE(FilterTest) + +BOOST_AUTO_TEST_CASE(BKTTest) +{ + FTest(SPTAG::IndexAlgoType::BKT, "L2"); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/src/IniReaderTest.cpp b/Test/src/IniReaderTest.cpp index c5dd0baaf..9386a35a6 100644 --- a/Test/src/IniReaderTest.cpp +++ b/Test/src/IniReaderTest.cpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "inc/Test.h" #include "inc/Helper/SimpleIniReader.h" +#include "inc/Test.h" #include diff --git a/Test/src/IterativeScanTest.cpp b/Test/src/IterativeScanTest.cpp new file mode 100644 index 000000000..b6dc109e2 --- /dev/null +++ b/Test/src/IterativeScanTest.cpp @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/ResultIterator.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Test.h" +#include "inc/TestDataGenerator.h" + +#include +#include + +using namespace SPTAG; + +template +void BuildIndex(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, const std::string out) +{ + std::shared_ptr vecIndex = + VectorIndex::CreateInstance(algo, GetEnumValueType()); + BOOST_CHECK(nullptr != vecIndex); + + if (algo != IndexAlgoType::SPANN) + { + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + vecIndex->SetParameter("NumberOfThreads", "16"); + } + else + { + vecIndex->SetParameter("IndexAlgoType", "BKT", "Base"); + vecIndex->SetParameter("DistCalcMethod", distCalcMethod, "Base"); + + vecIndex->SetParameter("isExecute", "true", "SelectHead"); + vecIndex->SetParameter("NumberOfThreads", "4", "SelectHead"); + vecIndex->SetParameter("Ratio", "0.2", "SelectHead"); // vecIndex->SetParameter("Count", "200", "SelectHead"); + + vecIndex->SetParameter("isExecute", "true", "BuildHead"); + vecIndex->SetParameter("RefineIterations", "3", "BuildHead"); + vecIndex->SetParameter("NumberOfThreads", "4", "BuildHead"); + + vecIndex->SetParameter("isExecute", "true", "BuildSSDIndex"); + vecIndex->SetParameter("BuildSsdIndex", "true", "BuildSSDIndex"); + vecIndex->SetParameter("NumberOfThreads", "4", "BuildSSDIndex"); + vecIndex->SetParameter("PostingPageLimit", std::to_string(4 * sizeof(T)), "BuildSSDIndex"); + vecIndex->SetParameter("SearchPostingPageLimit", 
std::to_string(4 * sizeof(T)), "BuildSSDIndex"); + vecIndex->SetParameter("InternalResultNum", "64", "BuildSSDIndex"); + vecIndex->SetParameter("SearchInternalResultNum", "64", "BuildSSDIndex"); + vecIndex->SetParameter("MaxCheck", "8192", "BuildSSDIndex"); + } + + BOOST_CHECK(ErrorCode::Success == vecIndex->BuildIndex(vec, meta)); + BOOST_CHECK(ErrorCode::Success == vecIndex->SaveIndex(out)); +} + +template +void SearchIterativeBatch(const std::string folder, T *vec, SizeType n, std::string *truthmeta) +{ + std::shared_ptr vecIndex; + + BOOST_CHECK(ErrorCode::Success == VectorIndex::LoadIndex(folder, vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + vecIndex->SetParameter("MaxCheck", "5", "BuildSSDIndex"); + vecIndex->UpdateIndex(); + + std::shared_ptr resultIterator = vecIndex->GetIterator(vec); + // std::cout << "relaxedMono:" << resultIterator->GetRelaxedMono() << std::endl; + int batch = 5; + int ri = 0; + for (int i = 0; i < 2; i++) + { + auto results = resultIterator->Next(batch); + int resultCount = results->GetResultNum(); + if (resultCount <= 0) + break; + for (int j = 0; j < resultCount; j++) + { + + BOOST_CHECK(std::string((char *)((results->GetMetadata(j)).Data()), (results->GetMetadata(j)).Length()) == + truthmeta[ri]); + BOOST_CHECK(results->GetResult(j)->RelaxedMono == true); + std::cout << "Result[" << ri << "] VID:" << results->GetResult(j)->VID + << " Dist:" << results->GetResult(j)->Dist + << " RelaxedMono:" << results->GetResult(j)->RelaxedMono << std::endl; + ri++; + } + } + resultIterator->Close(); +} + +template void TestIterativeScan(IndexAlgoType algo, std::string distCalcMethod) +{ + SizeType n = 6000, q = 1; + DimensionType m = 10; + std::vector vec; + for (SizeType i = 0; i < n; i++) + { + for (DimensionType j = 0; j < m; j++) + { + vec.push_back((T)i); + } + } + + std::vector query; + for (SizeType i = 0; i < q; i++) + { + for (DimensionType j = 0; j < m; j++) + { + query.push_back((T)i * 2); + } + } + + std::vector meta; + std::vector 
metaoffset; + for (SizeType i = 0; i < n; i++) + { + metaoffset.push_back((std::uint64_t)meta.size()); + std::string a = std::to_string(i); + for (size_t j = 0; j < a.length(); j++) + meta.push_back(a[j]); + } + metaoffset.push_back((std::uint64_t)meta.size()); + + std::shared_ptr vecset(new BasicVectorSet( + ByteArray((std::uint8_t *)vec.data(), sizeof(T) * n * m, false), GetEnumValueType(), m, n)); + + std::shared_ptr metaset(new MemMetadataSet( + ByteArray((std::uint8_t *)meta.data(), meta.size() * sizeof(char), false), + ByteArray((std::uint8_t *)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), n)); + + BuildIndex(algo, distCalcMethod, vecset, metaset, "testindices"); + std::string truthmeta1[] = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"}; + SearchIterativeBatch("testindices", query.data(), q, truthmeta1); +} + +template +float EvaluateRecall(const std::vector &res, std::shared_ptr &vecIndex, + std::shared_ptr &queryset, std::shared_ptr &truth, + std::shared_ptr &baseVec, int k) +{ + if (!truth) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth data is null. 
Cannot compute recall.\n"); + return 0.0f; + } + + const SizeType recallK = min(k, static_cast(truth->Dimension())); + float totalRecall = 0.0f; + float eps = 1e-4f; + + for (SizeType i = 0; i < queryset->Count(); ++i) + { + const SizeType *truthNN = reinterpret_cast(truth->GetVector(i)); + for (int j = 0; j < recallK; ++j) + { + SizeType truthVid = truthNN[j]; + float truthDist = vecIndex->ComputeDistance(queryset->GetVector(i), baseVec->GetVector(truthVid)); + + for (int l = 0; l < k; ++l) + { + const auto result = res[i].GetResult(l); + if (truthVid == result->VID || + std::fabs(truthDist - result->Dist) <= eps * (std::fabs(truthDist) + eps)) + { + totalRecall += 1.0f; + break; + } + } + } + } + + float avgRecall = totalRecall / (queryset->Count() * recallK); + return avgRecall; +} + +template void TestIterativeScanRandom(IndexAlgoType algo, std::string distCalcMethod) +{ + SizeType n = 50000, q = 3, k = 5; + DimensionType m = 128; + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(n, q, m, k, distCalcMethod); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + BuildIndex(algo, distCalcMethod, vecset, metaset, "testindices"); + std::shared_ptr vecIndex; + BOOST_CHECK(ErrorCode::Success == VectorIndex::LoadIndex("testindices", vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + + std::vector res(queryset->Count(), QueryResult(nullptr, k, true)); + std::vector resiter(queryset->Count(), QueryResult(nullptr, k, true)); + for (int i = 0; i < q; i++) + { + res[i].SetTarget(queryset->GetVector(i)); + vecIndex->SearchIndex(res[i]); + int scanned = res[i].GetScanned(); + + std::shared_ptr resultIterator = vecIndex->GetIterator((T *)(queryset->GetVector(i))); + int batch = 1; + int ri = 0; + int iterscanned = 0; + bool relaxMono = false; + while (!relaxMono) + { + auto results = resultIterator->Next(batch); + int resultCount = 
results->GetResultNum(); + if (resultCount <= 0) break; + for (int j = 0; j < resultCount; j++) + { + relaxMono = results->GetResult(j)->RelaxedMono; + ((COMMON::QueryResultSet *)(&resiter[i]))->AddPoint(results->GetResult(j)->VID, results->GetResult(j)->Dist); + } + ri += resultCount; + iterscanned = results->GetScanned(); + } + resultIterator->Close(); + + std::cout << "TopK scanned:" << scanned << " Iterator scanned:" << iterscanned << std::endl; + } + std::cout << "TopK Recall:" << EvaluateRecall(res, vecIndex, queryset, truth, vecset, k) + << " Iterator Recall:" << EvaluateRecall(resiter, vecIndex, queryset, truth, vecset, k) << std::endl; +} + +BOOST_AUTO_TEST_SUITE(IterativeScanTest) + +BOOST_AUTO_TEST_CASE(BKTTest) +{ + TestIterativeScan(IndexAlgoType::BKT, "L2"); +} + +BOOST_AUTO_TEST_CASE(BKTRandomTest) +{ + TestIterativeScanRandom(IndexAlgoType::BKT, "L2"); +} + +BOOST_AUTO_TEST_CASE(SPANNRandomTest) +{ + TestIterativeScanRandom(IndexAlgoType::SPANN, "L2"); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/src/KVTest.cpp b/Test/src/KVTest.cpp new file mode 100644 index 000000000..9c49d4ae1 --- /dev/null +++ b/Test/src/KVTest.cpp @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "inc/Core/SPANN/ExtraFileController.h" +#include "inc/Core/SPANN/IExtraSearcher.h" +#include "inc/Test.h" +#include +#include + +// enable rocksdb io_uring + +#ifdef ROCKSDB +#include "inc/Core/SPANN/ExtraRocksDBController.h" +// extern "C" bool RocksDbIOUringEnable() { return true; } +#endif + +#ifdef SPDK +#include "inc/Core/SPANN/ExtraSPDKController.h" +#endif + +using namespace SPTAG; +using namespace SPTAG::SPANN; + +void Search(std::shared_ptr db, int internalResultNum, int totalSize, int times, bool debug, + SPTAG::SPANN::ExtraWorkSpace &workspace) +{ + std::vector headIDs(internalResultNum, 0); + + std::vector values; + double latency = 0; + for (int i = 0; i < times; i++) + { + values.clear(); + for (int j = 0; j < internalResultNum; j++) + headIDs[j] = (j + i * internalResultNum) % totalSize; + auto t1 = std::chrono::high_resolution_clock::now(); + db->MultiGet(headIDs, &values, MaxTimeout, &(workspace.m_diskRequests)); + auto t2 = std::chrono::high_resolution_clock::now(); + latency += std::chrono::duration_cast(t2 - t1).count(); + + if (debug) + { + for (int j = 0; j < internalResultNum; j++) + { + std::cout << values[j].substr(PageSize) << std::endl; + } + } + } + std::cout << "avg get time: " << (latency / (float)(times)) << "us" << std::endl; +} + +void Test(std::string path, std::string type, bool debug = false) +{ + int internalResultNum = 64; + int totalNum = 1024; + int mergeIters = 3; + std::shared_ptr db; + SPTAG::SPANN::ExtraWorkSpace workspace; + SPTAG::SPANN::Options opt; + auto idx = path.find_last_of(FolderSep); + opt.m_indexDirectory = path.substr(0, idx); + opt.m_ssdMappingFile = path.substr(idx + 1); + workspace.Initialize(4096, 2, internalResultNum, 4 * PageSize, true, false); + + if (type == "RocksDB") + { +#ifdef ROCKSDB + db.reset(new RocksDBIO(path.c_str(), true)); +#else + { + std::cerr << "RocksDB is not supported in this build." 
<< std::endl; + return; + } +#endif + } + else if (type == "SPDK") + { +#ifdef SPDK + db.reset(new SPDKIO(opt)); +#else + { + std::cerr << "SPDK is not supported in this build." << std::endl; + return; + } +#endif + } + else + { + db.reset(new FileIO(opt)); + } + + auto t1 = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < totalNum; i++) + { + int len = std::to_string(i).length(); + std::string val(PageSize - len, '0'); + db->Put(i, val, MaxTimeout, &(workspace.m_diskRequests)); + } + auto t2 = std::chrono::high_resolution_clock::now(); + std::cout << "avg put time: " + << (std::chrono::duration_cast(t2 - t1).count() / (float)(totalNum)) << "us" + << std::endl; + + db->ForceCompaction(); + + t1 = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < totalNum; i++) + { + for (int j = 0; j < mergeIters; j++) + { + db->Merge(i, std::to_string(i), MaxTimeout, &(workspace.m_diskRequests)); + } + } + t2 = std::chrono::high_resolution_clock::now(); + std::cout << "avg merge time: " + << (std::chrono::duration_cast(t2 - t1).count() / + (float)(totalNum * mergeIters)) + << "us" << std::endl; + + Search(db, internalResultNum, totalNum, 10, debug, workspace); + + db->ForceCompaction(); + db->ShutDown(); + + // if (type == "RocksDB") { + // db.reset(new RocksDBIO(path.c_str(), true)); + // Search(db, internalResultNum, totalNum, 10, debug); + // db->ForceCompaction(); + // db->ShutDown(); + // } +} + +BOOST_AUTO_TEST_SUITE(KVTest) + +BOOST_AUTO_TEST_CASE(RocksDBTest) +{ + if (!direxists("tmp_rocksdb")) + mkdir("tmp_rocksdb"); + Test(std::string("tmp_rocksdb") + FolderSep + "test", "RocksDB", false); +} + +BOOST_AUTO_TEST_CASE(SPDKTest) +{ + if (!direxists("tmp_spdk")) + mkdir("tmp_spdk"); + Test(std::string("tmp_spdk") + FolderSep + "test", "SPDK", false); +} + +BOOST_AUTO_TEST_CASE(FileTest) +{ + if (!direxists("tmp_file")) + mkdir("tmp_file"); + Test(std::string("tmp_file") + FolderSep + "test", "File", false); +} + +BOOST_AUTO_TEST_SUITE_END() 
diff --git a/Test/src/MultiIndexExperiment.cpp b/Test/src/MultiIndexExperiment.cpp new file mode 100644 index 000000000..60b3c87c2 --- /dev/null +++ b/Test/src/MultiIndexExperiment.cpp @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/MultiIndexScan.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Test.h" + +#include +#include +#include +std::vector VID2rid; + +// top number to query +int topk = 50; +// vector dimension +SPTAG::DimensionType m = 1024; +std::string pathEmbeddings = "/embeddings/"; +std::string outputResult = "/result/"; +std::vector testSearchLimit = {256, 512, 1024}; + +struct FinalData +{ + double Dist; + int rid, vid; + FinalData(double x, int y, int z) : Dist(x), rid(y), vid(z) + { + } + bool operator<(const FinalData x) const + { + return Dist < x.Dist; + } +}; + +static std::string indexName(unsigned int id) +{ + return "testindices" + std::to_string(id); +} + +template +void BuildIndex(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, const std::string out) +{ + + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + BOOST_CHECK(nullptr != vecIndex); + + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + vecIndex->SetParameter("NumberOfThreads", "16"); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); +} + +float sumFunc(std::vector in) +{ + return std::accumulate(in.begin(), in.end(), (float)0.0); +} + +template +void GenerateVectorDataSet(SPTAG::IndexAlgoType algo, std::string distCalcMethod, unsigned int id, + std::vector &query, std::vector &query_id, std::string File_name) +{ + SPTAG::SizeType n = 0, q = 0; + + std::string fileName; + fileName = pathEmbeddings + 
File_name + "_embeds_collection.tsv"; + std::ifstream in1(fileName.c_str()); + if (!in1) + { + std::cout << "Fail to read " << fileName << "." << std::endl; + return; + } + std::vector vec; + std::string line; + std::vector meta; + + for (; getline(in1, line); n++) + { + std::uint64_t number = 0; + std::string a = line.substr(0, line.find("\t")); + for (size_t j = 0; j < a.length(); j++) + number = number * 10 + a[j] - '0'; + VID2rid.push_back((int)number); + meta.push_back(number); + + int l = (int)vec.size(); + line.erase(0, line.find("[") + 1); + for (SPTAG::DimensionType j = 0; j < m - 1; j++) + { + vec.push_back((T)std::stof(line.substr(0, line.find(",")))); + line.erase(0, line.find(",") + 2); + } + vec.push_back((T)std::stof(line.substr(0, line.find("]")))); + // normalization; + int r = (int)vec.size(); + double sum = 0.0; + for (int i = l; i < r; i++) + sum += vec[i] * vec[i]; + sum = sqrt(sum); + for (int i = l; i < r; i++) + vec[i] = (T)(vec[i] / sum); + } + fileName = pathEmbeddings + File_name + "_embeds_query.tsv"; + std::ifstream in2(fileName.c_str()); + if (!in2) + { + std::cout << "Fail to read " << fileName << "." 
<< std::endl; + return; + } + for (; getline(in2, line); q++) + { + query_id.push_back((T)std::stof(line.substr(0, line.find("\t")))); + + int l = (int)query.size(); + line.erase(0, line.find("[") + 1); + for (SPTAG::DimensionType j = 0; j < m - 1; j++) + { + query.push_back((T)std::stof(line.substr(0, line.find(",")))); + line.erase(0, line.find(",") + 2); + } + query.push_back((T)std::stof(line.substr(0, line.find("]")))); + // normalization; + int r = (int)query.size(); + double sum = 0.0; + for (int i = l; i < r; i++) + sum += query[i] * query[i]; + sum = sqrt(sum); + for (int i = l; i < r; i++) + query[i] = (T)(query[i] / sum); + } + + std::shared_ptr vecset(new SPTAG::BasicVectorSet( + SPTAG::ByteArray((std::uint8_t *)vec.data(), sizeof(T) * n * m, false), SPTAG::GetEnumValueType(), m, n)); + + std::shared_ptr metaset( + new SPTAG::MemMetadataSet(n * sizeof(std::uint64_t), n * sizeof(std::uint64_t), n)); + for (auto x : meta) + { + std::uint8_t result[sizeof(x)]; + std::memcpy(result, &x, sizeof(x)); + metaset->Add(SPTAG::ByteArray(result, 8, false)); + } + + BuildIndex(algo, distCalcMethod, vecset, metaset, indexName(id).c_str()); +} + +template void TestMultiIndexScanN(SPTAG::IndexAlgoType algo, std::string distCalcMethod, unsigned int n) +{ + + std::vector> queries(n, std::vector()); + std::vector> query_id(n, std::vector()); + + GenerateVectorDataSet(algo, distCalcMethod, 0, queries[0], query_id[0], "img"); + GenerateVectorDataSet(algo, distCalcMethod, 1, queries[1], query_id[1], "rec"); + + for (auto KK : testSearchLimit) + { + std::string output_result_path = outputResult + std::to_string(KK) + "_qrels.txt"; + std::string output_lantency_path = outputResult + std::to_string(KK) + "_latency.txt"; + + std::ofstream out1; + out1.open(output_result_path); + + if (!out1.is_open()) + { + std::cout << "Cannot open file out1" << std::endl; + return; + } + + std::ofstream out2; + out2.open(output_lantency_path); + if (!out2.is_open()) + { + std::cout << "Cannot 
open file out2" << std::endl; + return; + } + + for (int i = 0; i < query_id[0].size(); i++) + { + std::vector> query_current(n, std::vector()); + for (unsigned int j = 0; j < n; j++) + { + for (int k = i * m; k < i * m + m; k++) + { + query_current[j].push_back(queries[j][k]); + } + } + int idquery = (int)query_id[0][i]; + std::vector> vecIndices; + std::vector p_targets; + for (unsigned int i = 0; i < n; i++) + { + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(indexName(i).c_str(), vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + vecIndices.push_back(vecIndex); + p_targets.push_back(query_current[i].data()); + } + + double Begin_time = clock(); + SPTAG::MultiIndexScan scan(vecIndices, p_targets, topk, &sumFunc, false, 10000000, KK); + SPTAG::BasicResult result; + + std::set FinalResult; + + for (int i = 0; i < topk; i++) + { + bool hasResult = scan.Next(result); + if (!hasResult) + break; + + FinalResult.insert(FinalData(result.Dist, VID2rid[result.VID], result.VID)); + auto Finaltmp = FinalResult.end(); + Finaltmp--; + if (FinalResult.size() > topk) + FinalResult.erase(Finaltmp); + } + + double End_time = clock(); + out2 << idquery << "\t" << (End_time - Begin_time) / CLOCKS_PER_SEC << std::endl; + int Rank = 0; + for (auto i = FinalResult.begin(); i != FinalResult.end(); i++) + out1 << idquery << " " << i->rid << " " << ++Rank << " " << i->vid << " " << i->Dist << std::endl; + + for (int i = (int)FinalResult.size(); i < topk; i++) + out1 << idquery << " " << -1 << " " << ++Rank << " " << -1 << " " << -1 << std::endl; + + scan.Close(); + } + } +} + +BOOST_AUTO_TEST_SUITE(MultiIndexExperiment) + +BOOST_AUTO_TEST_CASE(BKTTest) +{ + TestMultiIndexScanN(SPTAG::IndexAlgoType::BKT, "InnerProduct", 2); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/src/MultiIndexScanTest.cpp b/Test/src/MultiIndexScanTest.cpp new file mode 100644 index 000000000..3fca5ea01 --- /dev/null +++ b/Test/src/MultiIndexScanTest.cpp @@ 
-0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/MultiIndexScan.h" +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Test.h" + +#include + +#include + +static std::string indexName(unsigned int id) +{ + return "testindices" + std::to_string(id); +} +template +void BuildIndex(SPTAG::IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, const std::string out) +{ + + std::shared_ptr vecIndex = + SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + BOOST_CHECK(nullptr != vecIndex); + + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + vecIndex->SetParameter("NumberOfThreads", "16"); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); +} + +float rankFunc(std::vector in) +{ + return std::accumulate(in.begin(), in.end(), 0.0f); +} + +template +void MultiIndexSearch(unsigned int n, std::vector> &queries, int k, bool userTimer, int termCond) +{ + + std::vector> vecIndices; + std::vector p_targets; + for (unsigned int i = 0; i < n; i++) + { + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(indexName(i).c_str(), vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + vecIndices.push_back(vecIndex); + p_targets.push_back(queries[i].data()); + } + + SPTAG::MultiIndexScan scan(vecIndices, p_targets, k, &rankFunc, false, 10, 10); + SPTAG::BasicResult result; + + std::cout << "Start Scanning!!! 
" << std::endl; + + for (int i = 0; i < 100; i++) + { + bool hasResult = scan.Next(result); + if (!hasResult) + break; + std::cout << "hasResult: " << hasResult << std::endl; + std::cout << "result: " << result.VID << std::endl; + } + scan.Close(); +} + +template +void GenerateVectorDataSet(SPTAG::IndexAlgoType algo, std::string distCalcMethod, unsigned int id, + std::vector &query) +{ + SPTAG::SizeType n = 2000, q = 3; + SPTAG::DimensionType m = 10; + + std::vector vec; + for (SPTAG::SizeType i = 0; i < n; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { + vec.push_back((T)i + id * n); + } + } + + for (SPTAG::SizeType i = 0; i < q; i++) + { + for (SPTAG::DimensionType j = 0; j < m; j++) + { + query.push_back((T)i * 2 + id * n); + } + } + + std::vector meta; + std::vector metaoffset; + for (SPTAG::SizeType i = 0; i < n; i++) + { + metaoffset.push_back((std::uint64_t)meta.size()); + std::string a = std::to_string(i + id * n); + for (size_t j = 0; j < a.length(); j++) + meta.push_back(a[j]); + } + metaoffset.push_back((std::uint64_t)meta.size()); + + std::shared_ptr vecset(new SPTAG::BasicVectorSet( + SPTAG::ByteArray((std::uint8_t *)vec.data(), sizeof(T) * n * m, false), SPTAG::GetEnumValueType(), m, n)); + + std::shared_ptr metaset(new SPTAG::MemMetadataSet( + SPTAG::ByteArray((std::uint8_t *)meta.data(), meta.size() * sizeof(char), false), + SPTAG::ByteArray((std::uint8_t *)metaoffset.data(), metaoffset.size() * sizeof(std::uint64_t), false), n)); + + BuildIndex(algo, distCalcMethod, vecset, metaset, indexName(id).c_str()); +} + +template void TestMultiIndexScanN(SPTAG::IndexAlgoType algo, std::string distCalcMethod, unsigned int n) +{ + int k = 3; + std::vector> queries(n, std::vector()); + + for (unsigned int i = 0; i < n; i++) + { + GenerateVectorDataSet(algo, distCalcMethod, i, queries[i]); + } + + MultiIndexSearch(n, queries, k, false, 10); +} + +BOOST_AUTO_TEST_SUITE(MultiIndexScanTest) + +BOOST_AUTO_TEST_CASE(BKTTest) +{ + 
TestMultiIndexScanN(SPTAG::IndexAlgoType::BKT, "L2", 2); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/src/PerfTest.cpp b/Test/src/PerfTest.cpp index bd62ea2dc..4a69830a7 100644 --- a/Test/src/PerfTest.cpp +++ b/Test/src/PerfTest.cpp @@ -1,20 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "inc/Test.h" -#include "inc/Helper/VectorSetReader.h" -#include "inc/Core/VectorIndex.h" #include "inc/Core/Common/CommonUtils.h" -#include "inc/Core/Common/QueryResultSet.h" #include "inc/Core/Common/DistanceUtils.h" +#include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/VectorSetReader.h" +#include "inc/Test.h" +#include #include #include -#include using namespace SPTAG; template -void Search(std::shared_ptr& vecIndex, std::shared_ptr& queryset, int k, std::shared_ptr& truth) +void Search(std::shared_ptr &vecIndex, std::shared_ptr &queryset, int k, + std::shared_ptr &truth) { std::vector res(queryset->Count(), QueryResult(nullptr, k, true)); auto t1 = std::chrono::high_resolution_clock::now(); @@ -24,38 +25,46 @@ void Search(std::shared_ptr& vecIndex, std::shared_ptr& vecIndex->SearchIndex(res[i]); } auto t2 = std::chrono::high_resolution_clock::now(); - std::cout << "Search time: " << (std::chrono::duration_cast(t2 - t1).count() / (float)(queryset->Count())) << "us" << std::endl; + std::cout << "Search time: " + << (std::chrono::duration_cast(t2 - t1).count() / (float)(queryset->Count())) + << "us" << std::endl; float eps = 1e-6f, recall = 0; bool deleted; int truthDimension = min(k, truth->Dimension()); for (SizeType i = 0; i < queryset->Count(); i++) { - SizeType* nn = (SizeType*)(truth->GetVector(i)); + SizeType *nn = (SizeType *)(truth->GetVector(i)); for (int j = 0; j < truthDimension; j++) { std::string truthstr = std::to_string(nn[j]); - ByteArray truthmeta = ByteArray((std::uint8_t*)(truthstr.c_str()), truthstr.length(), false); - float truthdist = 
vecIndex->ComputeDistance(queryset->GetVector(i), vecIndex->GetSample(truthmeta, deleted)); + ByteArray truthmeta = ByteArray((std::uint8_t *)(truthstr.c_str()), truthstr.length(), false); + float truthdist = + vecIndex->ComputeDistance(queryset->GetVector(i), vecIndex->GetSample(truthmeta, deleted)); for (int l = 0; l < k; l++) { - if (fabs(truthdist - res[i].GetResult(l)->Dist) <= eps * (fabs(truthdist) + eps)) { + if (fabs(truthdist - res[i].GetResult(l)->Dist) <= eps * (fabs(truthdist) + eps)) + { recall += 1.0; break; } } } } - LOG(Helper::LogLevel::LL_Info, "Recall %d@%d: %f\n", k, truthDimension, recall / queryset->Count() / truthDimension); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall %d@%d: %f\n", k, truthDimension, + recall / queryset->Count() / truthDimension); } template -void PerfAdd(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, std::shared_ptr& queryset, int k, std::shared_ptr& truth, std::string out) +void PerfAdd(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, std::shared_ptr &queryset, int k, + std::shared_ptr &truth, std::string out) { std::shared_ptr vecIndex = VectorIndex::CreateInstance(algo, GetEnumValueType()); BOOST_CHECK(nullptr != vecIndex); - if (algo == IndexAlgoType::KDT) vecIndex->SetParameter("KDTNumber", "2"); + if (algo == IndexAlgoType::KDT) + vecIndex->SetParameter("KDTNumber", "2"); vecIndex->SetParameter("DistCalcMethod", distCalcMethod); vecIndex->SetParameter("NumberOfThreads", "5"); vecIndex->SetParameter("AddCEF", "200"); @@ -64,15 +73,20 @@ void PerfAdd(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptrSetParameter("MaxCheckForRefineGraph", "4096"); auto t1 = std::chrono::high_resolution_clock::now(); - for (SizeType i = 0; i < vec->Count(); i++) { + for (SizeType i = 0; i < vec->Count(); i++) + { ByteArray metaarr = meta->GetMetadata(i); - std::uint64_t offset[2] = { 0, metaarr.Length() }; - std::shared_ptr 
metaset(new MemMetadataSet(metaarr, ByteArray((std::uint8_t*)offset, 2 * sizeof(std::uint64_t), false), 1)); + std::uint64_t offset[2] = {0, metaarr.Length()}; + std::shared_ptr metaset( + new MemMetadataSet(metaarr, ByteArray((std::uint8_t *)offset, 2 * sizeof(std::uint64_t), false), 1)); ErrorCode ret = vecIndex->AddIndex(vec->GetVector(i), 1, vec->Dimension(), metaset, true); - if (ErrorCode::Success != ret) std::cerr << "Error AddIndex(" << (int)(ret) << ") for vector " << i << std::endl; + if (ErrorCode::Success != ret) + std::cerr << "Error AddIndex(" << (int)(ret) << ") for vector " << i << std::endl; } auto t2 = std::chrono::high_resolution_clock::now(); - std::cout << "Add time: " << (std::chrono::duration_cast(t2 - t1).count() / (float)(vec->Count())) << "us" << std::endl; + std::cout << "Add time: " + << (std::chrono::duration_cast(t2 - t1).count() / (float)(vec->Count())) + << "us" << std::endl; Sleep(10000); @@ -84,13 +98,16 @@ void PerfAdd(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr(vecIndex, queryset, k, truth); } -template -void PerfBuild(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, std::shared_ptr& queryset, int k, std::shared_ptr& truth, std::string out) +template +void PerfBuild(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, std::shared_ptr &queryset, int k, + std::shared_ptr &truth, std::string out) { std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); BOOST_CHECK(nullptr != vecIndex); - if (algo == IndexAlgoType::KDT) vecIndex->SetParameter("KDTNumber", "2"); + if (algo == IndexAlgoType::KDT) + vecIndex->SetParameter("KDTNumber", "2"); vecIndex->SetParameter("DistCalcMethod", distCalcMethod); vecIndex->SetParameter("NumberOfThreads", "5"); vecIndex->SetParameter("RefineIterations", "3"); @@ -103,36 +120,45 @@ void PerfBuild(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr 
-void GenerateData(std::shared_ptr& vecset, std::shared_ptr& metaset, std::shared_ptr& queryset, std::shared_ptr& truth, std::string distCalcMethod, int k) +void GenerateData(std::shared_ptr &vecset, std::shared_ptr &metaset, + std::shared_ptr &queryset, std::shared_ptr &truth, std::string distCalcMethod, + int k) { SizeType n = 2000, q = 2000; DimensionType m = 128; - - if (fileexists("perftest_vector.bin") && fileexists("perftest_meta.bin") && fileexists("perftest_metaidx.bin") && fileexists("perftest_query.bin")) { - std::shared_ptr options(new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); + + if (fileexists("perftest_vector.bin") && fileexists("perftest_meta.bin") && fileexists("perftest_metaidx.bin") && + fileexists("perftest_query.bin")) + { + std::shared_ptr options( + new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); auto vectorReader = Helper::VectorSetReader::CreateInstance(options); if (ErrorCode::Success != vectorReader->LoadFile("perftest_vector.bin")) { - LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); exit(1); } vecset = vectorReader->GetVectorSet(); - metaset.reset(new MemMetadataSet("perftest_meta.bin", "perftest_metaidx.bin", vecset->Count() * 2, vecset->Count() * 2, 10)); + metaset.reset(new MemMetadataSet("perftest_meta.bin", "perftest_metaidx.bin", vecset->Count() * 2, + vecset->Count() * 2, 10)); if (ErrorCode::Success != vectorReader->LoadFile("perftest_query.bin")) { - LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); exit(1); } queryset = vectorReader->GetVectorSet(); } - else { + else + { ByteArray vec = ByteArray::Alloc(sizeof(T) * n * m); - for (SizeType i = 0; i < n; i++) { - for (DimensionType j = 0; j < m; j++) { - ((T*)vec.Data())[i * m + j] = (T)COMMON::Utils::rand(127, -127); + for (SizeType i = 0; 
i < n; i++) + { + for (DimensionType j = 0; j < m; j++) + { + ((T *)vec.Data())[i * m + j] = (T)COMMON::Utils::rand(127, -127); } } vecset.reset(new BasicVectorSet(vec, GetEnumValueType(), m, n)); @@ -141,69 +167,100 @@ void GenerateData(std::shared_ptr& vecset, std::shared_ptrSaveMetadata("perftest_meta.bin", "perftest_metaidx.bin"); ByteArray query = ByteArray::Alloc(sizeof(T) * q * m); - for (SizeType i = 0; i < q; i++) { - for (DimensionType j = 0; j < m; j++) { - ((T*)query.Data())[i * m + j] = (T)COMMON::Utils::rand(127, -127); + for (SizeType i = 0; i < q; i++) + { + for (DimensionType j = 0; j < m; j++) + { + ((T *)query.Data())[i * m + j] = (T)COMMON::Utils::rand(127, -127); } } queryset.reset(new BasicVectorSet(query, GetEnumValueType(), m, q)); queryset->Save("perftest_query.bin"); } - - if (fileexists(("perftest_truth." + distCalcMethod).c_str())) { - std::shared_ptr options(new Helper::ReaderOptions(GetEnumValueType(), k, VectorFileType::DEFAULT)); + + if (fileexists(("perftest_truth." + distCalcMethod).c_str())) + { + std::shared_ptr options( + new Helper::ReaderOptions(GetEnumValueType(), k, VectorFileType::DEFAULT)); auto vectorReader = Helper::VectorSetReader::CreateInstance(options); if (ErrorCode::Success != vectorReader->LoadFile("perftest_truth." + distCalcMethod)) { - LOG(Helper::LogLevel::LL_Error, "Failed to read truth file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read truth file.\n"); exit(1); } truth = vectorReader->GetVectorSet(); } - else { - omp_set_num_threads(5); - + else + { DistCalcMethod distMethod; Helper::Convert::ConvertStringTo(distCalcMethod.c_str(), distMethod); - if (distMethod == DistCalcMethod::Cosine) { + if (distMethod == DistCalcMethod::Cosine) + { std::cout << "Normalize vecset!" 
<< std::endl; - COMMON::Utils::BatchNormalize((T*)(vecset->GetData()), vecset->Count(), vecset->Dimension(), COMMON::Utils::GetBase(), 5); + COMMON::Utils::BatchNormalize((T *)(vecset->GetData()), vecset->Count(), vecset->Dimension(), + COMMON::Utils::GetBase(), 5); } ByteArray tru = ByteArray::Alloc(sizeof(float) * queryset->Count() * k); - #pragma omp parallel for - for (SizeType i = 0; i < queryset->Count(); ++i) + int testThreads = 5; + std::vector mythreads; + mythreads.reserve(testThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < testThreads; tid++) { - SizeType* neighbors = ((SizeType*)tru.Data()) + i * k; + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < queryset->Count()) + { + SizeType *neighbors = ((SizeType *)tru.Data()) + i * k; - COMMON::QueryResultSet res((const T*)queryset->GetVector(i), k); - for (SizeType j = 0; j < vecset->Count(); j++) - { - float dist = COMMON::DistanceUtils::ComputeDistance(res.GetTarget(), reinterpret_cast(vecset->GetVector(j)), queryset->Dimension(), distMethod); - res.AddPoint(j, dist); - } - res.SortResult(); - for (int j = 0; j < k; j++) neighbors[j] = res.GetResult(j)->VID; + COMMON::QueryResultSet res((const T *)queryset->GetVector(i), k); + for (SizeType j = 0; j < vecset->Count(); j++) + { + float dist = COMMON::DistanceUtils::ComputeDistance( + res.GetTarget(), reinterpret_cast(vecset->GetVector(j)), queryset->Dimension(), + distMethod); + res.AddPoint(j, dist); + } + res.SortResult(); + for (int j = 0; j < k; j++) + neighbors[j] = res.GetResult(j)->VID; + } + else + { + return; + } + } + }); } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + truth.reset(new BasicVectorSet(tru, GetEnumValueType(), k, queryset->Count())); truth->Save("perftest_truth." 
+ distCalcMethod); } } -template -void PTest(IndexAlgoType algo, std::string distCalcMethod) +template void PTest(IndexAlgoType algo, std::string distCalcMethod) { std::shared_ptr vecset, queryset, truth; std::shared_ptr metaset; diff --git a/Test/src/ReconstructIndexSimilarityTest.cpp b/Test/src/ReconstructIndexSimilarityTest.cpp index 2da00493b..e7ca78066 100644 --- a/Test/src/ReconstructIndexSimilarityTest.cpp +++ b/Test/src/ReconstructIndexSimilarityTest.cpp @@ -1,88 +1,133 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "inc/Test.h" -#include -#include "inc/Helper/VectorSetReader.h" -#include "inc/Core/Common/PQQuantizer.h" -#include "inc/Core/VectorIndex.h" #include "inc/Core/Common/CommonUtils.h" -#include "inc/Core/Common/QueryResultSet.h" #include "inc/Core/Common/DistanceUtils.h" -#include +#include "inc/Core/Common/PQQuantizer.h" +#include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/VectorSetReader.h" +#include "inc/Test.h" +#include +#include #include +#include +#include #include -#include using namespace SPTAG; template -void Search(std::shared_ptr& vecIndex, std::shared_ptr& queryset, int k, std::shared_ptr& truth) +void Search(std::shared_ptr &vecIndex, std::shared_ptr &vecset, + std::shared_ptr &queryset, int k, std::shared_ptr &truth) { - std::vector> res(queryset->Count(), SPTAG::COMMON::QueryResultSet(nullptr, k * 2)); + std::vector> res(queryset->Count(), + COMMON::QueryResultSet(nullptr, k * 2)); auto t1 = std::chrono::high_resolution_clock::now(); for (SizeType i = 0; i < queryset->Count(); i++) { res[i].Reset(); - res[i].SetTarget((const T*)queryset->GetVector(i), vecIndex->m_pQuantizer); + res[i].SetTarget((const T *)queryset->GetVector(i), vecIndex->m_pQuantizer); vecIndex->SearchIndex(res[i]); } auto t2 = std::chrono::high_resolution_clock::now(); - std::cout << "Search time: " << (std::chrono::duration_cast(t2 - t1).count() / 
(float)(queryset->Count())) << "us" << std::endl; + std::cout << "Search time: " + << (std::chrono::duration_cast(t2 - t1).count() / (float)(queryset->Count())) + << "us" << std::endl; float eps = 1e-6f, recall = 0; int truthDimension = min(k, truth->Dimension()); - for (SizeType i = 0; i < queryset->Count(); i++) { - SizeType* nn = (SizeType*)(truth->GetVector(i)); + for (SizeType i = 0; i < queryset->Count(); i++) + { + SizeType *nn = (SizeType *)(truth->GetVector(i)); std::vector visited(2 * k, false); - for (int j = 0; j < truthDimension; j++) { - float truthdist = vecIndex->ComputeDistance(res[i].GetQuantizedTarget(), vecIndex->GetSample(nn[j])); - for (int l = 0; l < k*2; l++) { - if (visited[l]) continue; + for (int j = 0; j < truthDimension; j++) + { + float truthdist = vecIndex->ComputeDistance(res[i].GetQuantizedTarget(), vecset->GetVector(nn[j])); + for (int l = 0; l < k * 2; l++) + { + if (visited[l]) + continue; - //std::cout << res[i].GetResult(l)->Dist << " " << truthdist << std::endl; - if (res[i].GetResult(l)->VID == nn[j]) { - recall += 1.0; - visited[l] = true; - break; - } - else if (fabs(res[i].GetResult(l)->Dist - truthdist) <= eps) { + // std::cout << res[i].GetResult(l)->Dist << " " << truthdist << std::endl; + if (res[i].GetResult(l)->VID == nn[j]) + { recall += 1.0; visited[l] = true; break; } + // else if (fabs(res[i].GetResult(l)->Dist - truthdist) <= eps) { + // recall += 1.0; + // visited[l] = true; + // break; + // } } } } - LOG(Helper::LogLevel::LL_Info, "Recall %d@%d: %f\n", truthDimension, k*2, recall / queryset->Count() / truthDimension); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall %d@%d: %f\n", truthDimension, k * 2, + recall / queryset->Count() / truthDimension); } - - -template -std::shared_ptr PerfBuild(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr& vec, std::shared_ptr& meta, std::shared_ptr& queryset, int k, std::shared_ptr& truth, std::string out, std::shared_ptr quantizer) +template +std::shared_ptr 
PerfBuild(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr &vec, + std::shared_ptr &meta, std::shared_ptr &queryset, int k, + std::shared_ptr &truth, std::string out, + std::shared_ptr quantizer) { - std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + std::shared_ptr vecIndex = VectorIndex::CreateInstance(algo, GetEnumValueType()); vecIndex->SetQuantizer(quantizer); BOOST_CHECK(nullptr != vecIndex); - if (algo == IndexAlgoType::KDT) vecIndex->SetParameter("KDTNumber", "2"); - vecIndex->SetParameter("DistCalcMethod", distCalcMethod); - vecIndex->SetParameter("NumberOfThreads", "12"); - vecIndex->SetParameter("RefineIterations", "3"); - vecIndex->SetParameter("MaxCheck", "4096"); - vecIndex->SetParameter("MaxCheckForRefineGraph", "8192"); + if (algo != IndexAlgoType::SPANN) + { + if (algo == IndexAlgoType::KDT) + vecIndex->SetParameter("KDTNumber", "2"); + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + vecIndex->SetParameter("NumberOfThreads", "16"); + vecIndex->SetParameter("RefineIterations", "3"); + vecIndex->SetParameter("MaxCheck", "4096"); + vecIndex->SetParameter("MaxCheckForRefineGraph", "8192"); + } + else + { + vecIndex->SetParameter("IndexAlgoType", "BKT", "Base"); + vecIndex->SetParameter("DistCalcMethod", distCalcMethod, "Base"); + + vecIndex->SetParameter("isExecute", "true", "SelectHead"); + vecIndex->SetParameter("NumberOfThreads", "4", "SelectHead"); + vecIndex->SetParameter("Ratio", "0.2", "SelectHead"); // vecIndex->SetParameter("Count", "200", "SelectHead"); + + vecIndex->SetParameter("isExecute", "true", "BuildHead"); + vecIndex->SetParameter("RefineIterations", "3", "BuildHead"); + vecIndex->SetParameter("MaxCheck", "4096"); + vecIndex->SetParameter("MaxCheckForRefineGraph", "8192"); + vecIndex->SetParameter("NumberOfThreads", "4", "BuildHead"); + + vecIndex->SetParameter("Storage", "FILEIO", "BuildSSDIndex"); + vecIndex->SetParameter("Update", "true", "BuildSSDIndex"); + 
vecIndex->SetParameter("StartFileSizeGB", "1", "BuildSSDIndex"); + vecIndex->SetParameter("isExecute", "true", "BuildSSDIndex"); + vecIndex->SetParameter("BuildSsdIndex", "true", "BuildSSDIndex"); + vecIndex->SetParameter("NumberOfThreads", "4", "BuildSSDIndex"); + vecIndex->SetParameter("PostingPageLimit", "12", "BuildSSDIndex"); + vecIndex->SetParameter("SearchPostingPageLimit", "12", "BuildSSDIndex"); + vecIndex->SetParameter("InternalResultNum", "64", "BuildSSDIndex"); + vecIndex->SetParameter("SearchInternalResultNum", "64", "BuildSSDIndex"); + } - BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta, true)); - BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out)); - //Search(vecIndex, queryset, k, truth); + BOOST_CHECK(ErrorCode::Success == vecIndex->BuildIndex(vec, meta, true)); + BOOST_CHECK(ErrorCode::Success == vecIndex->SaveIndex(out)); + // Search(vecIndex, queryset, k, truth); return vecIndex; } template -void GenerateReconstructData(std::shared_ptr& real_vecset, std::shared_ptr& rec_vecset, std::shared_ptr& quan_vecset, std::shared_ptr& metaset, std::shared_ptr& queryset, std::shared_ptr& truth, DistCalcMethod distCalcMethod, int k, std::shared_ptr& quantizer) +void GenerateReconstructData(std::shared_ptr &real_vecset, std::shared_ptr &rec_vecset, + std::shared_ptr &quan_vecset, std::shared_ptr &metaset, + std::shared_ptr &queryset, std::shared_ptr &truth, + DistCalcMethod distCalcMethod, int k, std::shared_ptr &quantizer) { std::random_device rd; std::mt19937 gen(rd()); @@ -94,197 +139,310 @@ void GenerateReconstructData(std::shared_ptr& real_vecset, std::share int QuanDim = m / M; std::string CODEBOOK_FILE = "quantest_quantizer.bin"; - if (fileexists("quantest_vector.bin") && fileexists("quantest_query.bin")) { - std::shared_ptr options(new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); + if (fileexists("quantest_vector.bin") && fileexists("quantest_query.bin")) + { + std::shared_ptr options( 
+ new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); auto vectorReader = Helper::VectorSetReader::CreateInstance(options); if (ErrorCode::Success != vectorReader->LoadFile("quantest_vector.bin")) { - LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); exit(1); } real_vecset = vectorReader->GetVectorSet(); if (ErrorCode::Success != vectorReader->LoadFile("quantest_query.bin")) { - LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n"); exit(1); } queryset = vectorReader->GetVectorSet(); } - else { + else + { ByteArray real_vec = ByteArray::Alloc(sizeof(R) * n * m); - for (int i = 0; i < n * m; i++) { - ((R*)real_vec.Data())[i] = (R)(dist(gen)); + for (int i = 0; i < n * m; i++) + { + ((R *)real_vec.Data())[i] = (R)(dist(gen)); } real_vecset.reset(new BasicVectorSet(real_vec, GetEnumValueType(), m, n)); real_vecset->Save("quantest_vector.bin"); ByteArray real_query = ByteArray::Alloc(sizeof(R) * q * m); - for (int i = 0; i < q * m; i++) { - ((R*)real_query.Data())[i] = (R)(dist(gen)); + for (int i = 0; i < q * m; i++) + { + ((R *)real_query.Data())[i] = (R)(dist(gen)); } queryset.reset(new BasicVectorSet(real_query, GetEnumValueType(), m, q)); queryset->Save("quantest_query.bin"); } - if (fileexists(("quantest_truth." + SPTAG::Helper::Convert::ConvertToString(distCalcMethod)).c_str())) { - std::shared_ptr options(new Helper::ReaderOptions(GetEnumValueType(), k, VectorFileType::DEFAULT)); + if (fileexists(("quantest_truth." + Helper::Convert::ConvertToString(distCalcMethod)).c_str())) + { + std::shared_ptr options( + new Helper::ReaderOptions(GetEnumValueType(), k, VectorFileType::DEFAULT)); auto vectorReader = Helper::VectorSetReader::CreateInstance(options); - if (ErrorCode::Success != vectorReader->LoadFile("quantest_truth." 
+ SPTAG::Helper::Convert::ConvertToString(distCalcMethod))) + if (ErrorCode::Success != + vectorReader->LoadFile("quantest_truth." + Helper::Convert::ConvertToString(distCalcMethod))) { - LOG(Helper::LogLevel::LL_Error, "Failed to read truth file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read truth file.\n"); exit(1); } truth = vectorReader->GetVectorSet(); } - else { - omp_set_num_threads(5); - + else + { ByteArray tru = ByteArray::Alloc(sizeof(SizeType) * queryset->Count() * k); -#pragma omp parallel for - for (SizeType i = 0; i < queryset->Count(); ++i) + int testThreads = 5; + std::vector mythreads; + mythreads.reserve(testThreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < testThreads; tid++) { - SizeType* neighbors = ((SizeType*)tru.Data()) + i * k; - - COMMON::QueryResultSet res((const R*)queryset->GetVector(i), k); - for (SizeType j = 0; j < real_vecset->Count(); j++) - { - float dist = COMMON::DistanceUtils::ComputeDistance(res.GetTarget(), reinterpret_cast(real_vecset->GetVector(j)), queryset->Dimension(), distCalcMethod); - res.AddPoint(j, dist); - } - res.SortResult(); - for (int j = 0; j < k; j++) neighbors[j] = res.GetResult(j)->VID; + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < queryset->Count()) + { + SizeType *neighbors = ((SizeType *)tru.Data()) + i * k; + + COMMON::QueryResultSet res((const R *)queryset->GetVector(i), k); + for (SizeType j = 0; j < real_vecset->Count(); j++) + { + float dist = COMMON::DistanceUtils::ComputeDistance( + res.GetTarget(), reinterpret_cast(real_vecset->GetVector(j)), + queryset->Dimension(), distCalcMethod); + res.AddPoint(j, dist); + } + res.SortResult(); + for (int j = 0; j < k; j++) + neighbors[j] = res.GetResult(j)->VID; + } + else + { + return; + } + } + }); } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + truth.reset(new BasicVectorSet(tru, GetEnumValueType(), k, queryset->Count())); - 
truth->Save("quantest_truth." + SPTAG::Helper::Convert::ConvertToString(distCalcMethod)); + truth->Save("quantest_truth." + Helper::Convert::ConvertToString(distCalcMethod)); } - if (fileexists(CODEBOOK_FILE.c_str()) && fileexists("quantest_quan_vector.bin") && fileexists("quantest_rec_vector.bin")) { - auto ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) { + if (fileexists(CODEBOOK_FILE.c_str()) && fileexists("quantest_quan_vector.bin") && + fileexists("quantest_rec_vector.bin")) + { + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) + { BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error"); } - quantizer->LoadIQuantizer(ptr); + quantizer = COMMON::IQuantizer::LoadIQuantizer(ptr); BOOST_ASSERT(quantizer); - std::shared_ptr options(new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); + std::shared_ptr options( + new Helper::ReaderOptions(GetEnumValueType(), m, VectorFileType::DEFAULT)); auto vectorReader = Helper::VectorSetReader::CreateInstance(options); if (ErrorCode::Success != vectorReader->LoadFile("quantest_rec_vector.bin")) { - LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); exit(1); } rec_vecset = vectorReader->GetVectorSet(); - std::shared_ptr quanOptions(new Helper::ReaderOptions(GetEnumValueType(), M, VectorFileType::DEFAULT)); + std::shared_ptr quanOptions( + new Helper::ReaderOptions(GetEnumValueType(), M, VectorFileType::DEFAULT)); vectorReader = Helper::VectorSetReader::CreateInstance(quanOptions); if (ErrorCode::Success != vectorReader->LoadFile("quantest_quan_vector.bin")) { - LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n"); exit(1); } quan_vecset = vectorReader->GetVectorSet(); } - 
else { - omp_set_num_threads(16); - + else + { + int testThreads = 16; std::cout << "Building codebooks!" << std::endl; - R* vecs = (R*)(real_vecset->GetData()); + R *vecs = (R *)(real_vecset->GetData()); std::unique_ptr codebooks = std::make_unique(M * Ks * QuanDim); std::unique_ptr belong(new int[n]); - for (int i = 0; i < M; i++) { - R* kmeans = codebooks.get() + i * Ks * QuanDim; - for (int j = 0; j < Ks; j++) { + for (int i = 0; i < M; i++) + { + R *kmeans = codebooks.get() + i * Ks * QuanDim; + for (int j = 0; j < Ks; j++) + { std::memcpy(kmeans + j * QuanDim, vecs + j * m + i * QuanDim, sizeof(R) * QuanDim); } int cnt = 100; - while (cnt--) { - //calculate cluster -#pragma omp parallel for - for (int ii = 0; ii < n; ii++) { - double min_dis = 1e9; - int min_id = 0; - for (int jj = 0; jj < Ks; jj++) { - double now_dis = COMMON::DistanceUtils::ComputeDistance(vecs + ii * m + i * QuanDim, kmeans + jj * QuanDim, QuanDim, DistCalcMethod::L2); - if (now_dis < min_dis) { - min_dis = now_dis; - min_id = jj; - } + while (cnt--) + { + // calculate cluster + std::vector mythreads; + mythreads.reserve(testThreads); + { + std::atomic_size_t sent(0); + for (int tid = 0; tid < testThreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t ii = 0; + while (true) + { + ii = sent.fetch_add(1); + if (ii < n) + { + double min_dis = 1e9; + int min_id = 0; + for (int jj = 0; jj < Ks; jj++) + { + double now_dis = COMMON::DistanceUtils::ComputeDistance( + vecs + ii * m + i * QuanDim, kmeans + jj * QuanDim, QuanDim, + DistCalcMethod::L2); + if (now_dis < min_dis) + { + min_dis = now_dis; + min_id = jj; + } + } + belong[ii] = min_id; + } + else + { + return; + } + } + }); } - belong[ii] = min_id; + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); } - //recalculate kmeans + + // recalculate kmeans std::memset(kmeans, 0, sizeof(R) * Ks * QuanDim); -#pragma omp parallel for - for (int ii = 0; ii < Ks; ii++) { - int num = 0; - for (int jj = 0; jj < n; jj++) { - 
if (belong[jj] == ii) { - num++; - for (int kk = 0; kk < QuanDim; kk++) { - kmeans[ii * QuanDim + kk] += vecs[jj * m + i * QuanDim + kk]; + + std::atomic_size_t sent(0); + for (int tid = 0; tid < testThreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t ii = 0; + while (true) + { + ii = sent.fetch_add(1); + if (ii < Ks) + { + int num = 0; + for (int jj = 0; jj < n; jj++) + { + if (belong[jj] == ii) + { + num++; + for (int kk = 0; kk < QuanDim; kk++) + { + kmeans[ii * QuanDim + kk] += vecs[jj * m + i * QuanDim + kk]; + } + } + } + for (int jj = 0; jj < QuanDim; jj++) + { + kmeans[ii * QuanDim + jj] /= num; + } + } + else + { + return; } } - } - for (int jj = 0; jj < QuanDim; jj++) { - kmeans[ii * QuanDim + jj] /= num; - } + }); + } + for (auto &t : mythreads) + { + t.join(); } + mythreads.clear(); } } std::cout << "Building Finish!" << std::endl; - quantizer = std::make_shared>(M, Ks, QuanDim, false, std::move(codebooks)); - auto ptr = SPTAG::f_createIO(); - if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::out)) { + quantizer = std::make_shared>(M, Ks, QuanDim, false, std::move(codebooks)); + auto ptr = f_createIO(); + if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::out)) + { BOOST_ASSERT("Canot Open CODEBOOK_FILE to write!" == "Error"); } quantizer->SaveQuantizer(ptr); ptr->ShutDown(); - if (!ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) { + if (!ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) + { BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" 
== "Error"); } quantizer->LoadIQuantizer(ptr); BOOST_ASSERT(quantizer); rec_vecset.reset(new BasicVectorSet(ByteArray::Alloc(sizeof(R) * n * m), GetEnumValueType(), m, n)); - quan_vecset.reset(new BasicVectorSet(ByteArray::Alloc(sizeof(std::uint8_t) * n * M), GetEnumValueType(), M, n)); - for (int i = 0; i < n; i++) { + quan_vecset.reset( + new BasicVectorSet(ByteArray::Alloc(sizeof(std::uint8_t) * n * M), GetEnumValueType(), M, n)); + for (int i = 0; i < n; i++) + { auto nvec = &vecs[i * m]; - quantizer->QuantizeVector(nvec, (uint8_t*)quan_vecset->GetVector(i)); - quantizer->ReconstructVector((uint8_t*)quan_vecset->GetVector(i), rec_vecset->GetVector(i)); + quantizer->QuantizeVector(nvec, (uint8_t *)quan_vecset->GetVector(i)); + quantizer->ReconstructVector((uint8_t *)quan_vecset->GetVector(i), rec_vecset->GetVector(i)); } quan_vecset->Save("quantest_quan_vector.bin"); rec_vecset->Save("quantest_rec_vector.bin"); } } -template -void ReconstructTest(IndexAlgoType algo, DistCalcMethod distMethod) +template void ReconstructTest(IndexAlgoType algo, DistCalcMethod distMethod) { std::shared_ptr real_vecset, rec_vecset, quan_vecset, queryset, truth; std::shared_ptr metaset; std::shared_ptr quantizer; - GenerateReconstructData(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10, quantizer); - //LoadReconstructData(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10); - - auto real_idx = PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), real_vecset, metaset, queryset, 10, truth, "real_idx", nullptr); - Search(real_idx, queryset, 10, truth); - auto rec_idx = PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), rec_vecset, metaset, queryset, 10, truth, "rec_idx", nullptr); - Search(rec_idx, queryset, 10, truth); - auto quan_idx = PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), quan_vecset, metaset, queryset, 10, truth, "quan_idx", quantizer); - - LOG(Helper::LogLevel::LL_Info, "Test 
search with SDC"); - Search(quan_idx, queryset, 10, truth); - - LOG(Helper::LogLevel::LL_Info, "Test search with ADC"); + GenerateReconstructData(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10, + quantizer); + // LoadReconstructData(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10); + + auto real_idx = PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), real_vecset, + metaset, queryset, 10, truth, "real_idx", nullptr); + Search(real_idx, real_vecset, queryset, 10, truth); + real_idx = nullptr; + std::filesystem::remove_all("SPANN"); + + auto rec_idx = PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), rec_vecset, metaset, + queryset, 10, truth, "rec_idx", nullptr); + Search(rec_idx, real_vecset, queryset, 10, truth); + rec_idx = nullptr; + std::filesystem::remove_all("SPANN"); + + auto quan_idx = PerfBuild(algo, Helper::Convert::ConvertToString(distMethod), + quan_vecset, metaset, queryset, 10, truth, "quan_idx", quantizer); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Test search with SDC"); + Search(quan_idx, real_vecset, queryset, 10, truth); + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Test search with ADC"); quan_idx->SetQuantizerADC(true); - Search(quan_idx, queryset, 10, truth); -} + Search(quan_idx, real_vecset, queryset, 10, truth); + quan_idx = nullptr; + std::filesystem::remove_all("SPANN"); + std::filesystem::remove_all("real_idx"); + std::filesystem::remove_all("rec_idx"); + std::filesystem::remove_all("quan_idx"); +} BOOST_AUTO_TEST_SUITE(ReconstructIndexSimilarityTest) @@ -292,14 +450,18 @@ BOOST_AUTO_TEST_CASE(BKTReconstructTest) { ReconstructTest(IndexAlgoType::BKT, DistCalcMethod::L2); - } BOOST_AUTO_TEST_CASE(KDTReconstructTest) { ReconstructTest(IndexAlgoType::KDT, DistCalcMethod::L2); +} + +BOOST_AUTO_TEST_CASE(SPANNReconstructTest) +{ + ReconstructTest(IndexAlgoType::SPANN, DistCalcMethod::L2); } BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/src/SIMDTest.cpp 
b/Test/src/SIMDTest.cpp index d58f9298b..b47594929 100644 --- a/Test/src/SIMDTest.cpp +++ b/Test/src/SIMDTest.cpp @@ -2,42 +2,43 @@ // Licensed under the MIT License. // #include -#include -#include "inc/Test.h" #include "inc/Core/Common/SIMDUtils.h" +#include "inc/Test.h" +#include - -template -static void ComputeSum(T *pX, const T *pY, SPTAG::DimensionType length) +template static void ComputeSum(T *pX, const T *pY, SPTAG::DimensionType length) { - const T* pEnd1 = pX + length; - while (pX < pEnd1) { + const T *pEnd1 = pX + length; + while (pX < pEnd1) + { *pX++ += *pY++; } } -template -T random(int high = RAND_MAX, int low = 0) // Generates a random value. +template T random(int high = RAND_MAX, int low = 0) // Generates a random value. { - return (T)(low + float(high - low)*(std::rand()/static_cast(RAND_MAX + 1.0))); + return (T)(low + float(high - low) * (std::rand() / static_cast(RAND_MAX + 1.0))); } -template -void test(int high) { +template void test(int high) +{ SPTAG::DimensionType dimension = random(256, 2); T *X = new T[dimension], *Y = new T[dimension]; BOOST_ASSERT(X != nullptr && Y != nullptr); - for (SPTAG::DimensionType i = 0; i < dimension; i++) { + for (SPTAG::DimensionType i = 0; i < dimension; i++) + { X[i] = random(high, -high); Y[i] = random(high, -high); } T *X_copy = new T[dimension]; - for (SPTAG::DimensionType i = 0; i < dimension; i++) { + for (SPTAG::DimensionType i = 0; i < dimension; i++) + { X_copy[i] = X[i]; } ComputeSum(X, Y, dimension); SPTAG::COMMON::SIMDUtils::ComputeSum(X_copy, Y, dimension); - for (SPTAG::DimensionType i = 0; i < dimension; i++) { + for (SPTAG::DimensionType i = 0; i < dimension; i++) + { BOOST_CHECK_CLOSE_FRACTION(double(X[i]), double(X_copy[i]), 1e-5); } @@ -55,5 +56,4 @@ BOOST_AUTO_TEST_CASE(TestDistanceComputation) test(32767); } - BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/src/SPFreshTest.cpp b/Test/src/SPFreshTest.cpp new file mode 100644 index 000000000..639f1c709 --- /dev/null +++ 
b/Test/src/SPFreshTest.cpp @@ -0,0 +1,1283 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/Common/DistanceUtils.h" +#include "inc/Core/Common/QueryResultSet.h" +#include "inc/Core/SPANN/Index.h" +#include "inc/Core/SPANN/SPANNResultIterator.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Helper/DiskIO.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/VectorSetReader.h" +#include "inc/Helper/StringConvert.h" +#include "inc/Test.h" +#include "inc/TestDataGenerator.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace SPTAG; + +namespace SPFreshTest +{ +SizeType N = 10000; +DimensionType M = 100; +int K = 10; +int queries = 10; + +template +std::shared_ptr BuildIndex(const std::string &outDirectory, std::shared_ptr vecset, + std::shared_ptr metaset, const std::string &distMethod = "L2") +{ + auto vecIndex = VectorIndex::CreateInstance(IndexAlgoType::SPANN, GetEnumValueType()); + + std::string configuration = R"( + [Base] + DistCalcMethod=L2 + IndexAlgoType=BKT + ValueType=)" + Helper::Convert::ConvertToString(GetEnumValueType()) + + R"( + Dim=)" + std::to_string(M) + + R"( + IndexDirectory=)" + outDirectory + + R"( + + [SelectHead] + isExecute=true + NumberOfThreads=16 + SelectThreshold=0 + SplitFactor=0 + SplitThreshold=0 + Ratio=0.2 + + [BuildHead] + isExecute=true + NumberOfThreads=16 + + [BuildSSDIndex] + isExecute=true + BuildSsdIndex=true + InternalResultNum=64 + SearchInternalResultNum=64 + NumberOfThreads=16 + PostingPageLimit=)" + std::to_string(4 * sizeof(T)) + + R"( + SearchPostingPageLimit=)" + + std::to_string(4 * sizeof(T)) + R"( + TmpDir=tmpdir + Storage=FILEIO + SpdkBatchSize=64 + ExcludeHead=false + ResultNum=10 + SearchThreadNum=2 + Update=true + SteadyState=true + InsertThreadNum=1 + AppendThreadNum=1 + ReassignThreadNum=0 + DisableReassign=false + ReassignK=64 + 
LatencyLimit=50.0 + SearchDuringUpdate=true + MergeThreshold=10 + Sampling=4 + BufferLength=6 + InPlace=true + StartFileSizeGB=1 + OneClusterCutMax=true + ConsistencyCheck=true + ChecksumCheck=true + ChecksumInRead=false + DeletePercentageForRefine=0.4 + AsyncAppendQueueSize=0 + AllowZeroReplica=false + )"; + + std::shared_ptr buffer(new Helper::SimpleBufferIO()); + Helper::IniReader reader; + if (!buffer->Initialize(configuration.data(), std::ios::in, configuration.size())) + return nullptr; + if (ErrorCode::Success != reader.LoadIni(buffer)) + return nullptr; + + std::string sections[] = {"Base", "SelectHead", "BuildHead", "BuildSSDIndex"}; + for (const auto &sec : sections) + { + auto params = reader.GetParameters(sec.c_str()); + for (const auto &[key, val] : params) + { + vecIndex->SetParameter(key.c_str(), val.c_str(), sec.c_str()); + } + } + + auto buildStatus = vecIndex->BuildIndex(vecset, metaset, true, false, false); + if (buildStatus != ErrorCode::Success) + return nullptr; + + return vecIndex; +} + +template +std::vector SearchOnly(std::shared_ptr &vecIndex, std::shared_ptr &queryset, int k) +{ + std::vector res(queryset->Count(), QueryResult(nullptr, k, true)); + + auto t1 = std::chrono::high_resolution_clock::now(); + for (SizeType i = 0; i < queryset->Count(); i++) + { + res[i].SetTarget(queryset->GetVector(i)); + vecIndex->SearchIndex(res[i]); + } + auto t2 = std::chrono::high_resolution_clock::now(); + + float avgUs = + std::chrono::duration_cast(t2 - t1).count() / static_cast(queryset->Count()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Avg search time: %.2fus/query\n", avgUs); + + return res; +} + +template +float EvaluateRecall(const std::vector &res, std::shared_ptr &vecIndex, + std::shared_ptr &queryset, std::shared_ptr &truth, + std::shared_ptr &baseVec, std::shared_ptr &addVec, SizeType baseCount, int k, int batch) +{ + if (!truth) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Truth data is null. 
Cannot compute recall.\n"); + return 0.0f; + } + + const SizeType recallK = min(k, static_cast(truth->Dimension())); + float totalRecall = 0.0f; + float eps = 1e-4f; + + for (SizeType i = 0; i < queryset->Count(); ++i) + { + const SizeType *truthNN = reinterpret_cast(truth->GetVector(i + batch * queryset->Count())); + for (int j = 0; j < recallK; ++j) + { + SizeType truthVid = truthNN[j]; + float truthDist = + (truthVid < baseCount) + ? vecIndex->ComputeDistance(queryset->GetVector(i), baseVec->GetVector(truthVid)) + : vecIndex->ComputeDistance(queryset->GetVector(i), addVec->GetVector(truthVid - baseCount)); + + for (int l = 0; l < k; ++l) + { + const auto result = res[i].GetResult(l); + if (truthVid == result->VID || + std::fabs(truthDist - result->Dist) <= eps * (std::fabs(truthDist) + eps)) + { + totalRecall += 1.0f; + break; + } + } + } + } + + float avgRecall = totalRecall / (queryset->Count() * recallK); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall %d@%d = %.4f\n", k, recallK, avgRecall); + return avgRecall; +} + +template +float Search(std::shared_ptr &vecIndex, std::shared_ptr &queryset, + std::shared_ptr &baseVec, std::shared_ptr &addVec, int k, + std::shared_ptr &truth, SizeType baseCount, int batch = 0) +{ + auto results = SearchOnly(vecIndex, queryset, k); + return EvaluateRecall(results, vecIndex, queryset, truth, baseVec, addVec, baseCount, k, batch); +} + +template +void InsertVectors(SPANN::Index *p_index, int insertThreads, int step, + std::shared_ptr addset, std::shared_ptr &metaset, int start = 0) +{ + SPANN::Options &p_opts = *(p_index->GetOptions()); + p_index->ForceCompaction(); + p_index->GetDBStat(); + + std::vector threads; + + std::atomic_size_t vectorsSent(start); + auto func = [&]() { + size_t index = start; + while (true) + { + index = vectorsSent.fetch_add(1); + if (index < start + step) + { + if ((index & ((1 << 5) - 1)) == 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Sent %.2lf%%...\n", index * 100.0 / step); + } + 
ByteArray p_meta = metaset->GetMetadata((SizeType)index); + std::uint64_t *offsets = new std::uint64_t[2]{0, p_meta.Length()}; + std::shared_ptr meta(new MemMetadataSet( + p_meta, ByteArray((std::uint8_t *)offsets, 2 * sizeof(std::uint64_t), true), 1)); + auto status = p_index->AddIndex(addset->GetVector((SizeType)index), 1, p_opts.m_dim, meta, true); + BOOST_REQUIRE(status == ErrorCode::Success); + } + else + { + return; + } + } + }; + for (int j = 0; j < insertThreads; j++) + { + threads.emplace_back(func); + } + for (auto &thread : threads) + { + thread.join(); + } + + while (!p_index->AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } +} +} // namespace SPFreshTest + +bool CompareFilesWithLogging(const std::filesystem::path &file1, const std::filesystem::path &file2) +{ + std::ifstream f1(file1, std::ios::binary); + std::ifstream f2(file2, std::ios::binary); + + if (!f1.is_open() || !f2.is_open()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to open one of the files:\n %s\n %s\n", + file1.string().c_str(), file2.string().c_str()); + return false; + } + + // Check file sizes first + f1.seekg(0, std::ios::end); + f2.seekg(0, std::ios::end); + if (f1.tellg() != f2.tellg()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File size differs: %s\n", file1.filename().string().c_str()); + return false; + } + + f1.seekg(0, std::ios::beg); + f2.seekg(0, std::ios::beg); + + const int bufferSize = 4096; // Adjust buffer size as needed + std::vector buffer1(bufferSize); + std::vector buffer2(bufferSize); + + while (f1.read(buffer1.data(), bufferSize) && f2.read(buffer2.data(), bufferSize)) + { + if (std::memcmp(buffer1.data(), buffer2.data(), f1.gcount()) != 0) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File mismatch at: %s\n", file1.filename().string().c_str()); + return false; // Mismatch found + } + } + + return true; +} + +bool CompareDirectoriesWithLogging(const std::filesystem::path &dir1, const std::filesystem::path 
&dir2, + const std::unordered_set &exceptions = {}) +{ + std::map files1, files2; + + for (const auto &entry : std::filesystem::recursive_directory_iterator(dir1)) + { + if (entry.is_regular_file()) + { + files1[std::filesystem::relative(entry.path(), dir1).string()] = entry.path(); + } + } + + for (const auto &entry : std::filesystem::recursive_directory_iterator(dir2)) + { + if (entry.is_regular_file()) + { + files2[std::filesystem::relative(entry.path(), dir2).string()] = entry.path(); + } + } + + bool matched = true; + + for (const auto &[relPath, filePath1] : files1) + { + if (exceptions.count(relPath)) + continue; + + auto it = files2.find(relPath); + if (it == files2.end()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Missing in %s: %s\n", dir2.string().c_str(), relPath.c_str()); + matched = false; + continue; + } + if (!CompareFilesWithLogging(filePath1, it->second)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "File end differs: %s\n", filePath1.filename().string().c_str()); + matched = false; + } + } + + for (const auto &[relPath, _] : files2) + { + if (exceptions.count(relPath)) + continue; + if (files1.find(relPath) == files1.end()) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Extra in %s: %s\n", dir2.string().c_str(), relPath.c_str()); + matched = false; + } + } + + return matched; +} + +void NormalizeVector(float *embedding, int dimension) +{ + // get magnitude + float magnitude = 0.0f; + { + float sum = 0.0; + for (int i = 0; i < dimension; i++) + { + sum += embedding[i] * embedding[i]; + } + magnitude = std::sqrt(sum); + } + + // normalized target vector + for (int i = 0; i < dimension; i++) + { + embedding[i] /= magnitude; + } +} + +template +std::shared_ptr get_embeddings(uint32_t row_id, uint32_t end_id, uint32_t embedding_dim, + uint32_t array_index) +{ + uint32_t count = end_id - row_id; + ByteArray vec = ByteArray::Alloc(sizeof(T) * count * embedding_dim); + for (uint32_t rid = 0; rid < count; rid++) + { + for (int idx = 0; idx < 
embedding_dim; ++idx) + { + ((T *)vec.Data())[rid * embedding_dim + idx] = (row_id + rid) * 17 + idx * 19 + (array_index + 1) * 23; + } + NormalizeVector(((T *)vec.Data()) + rid * embedding_dim, embedding_dim); + } + return std::make_shared(vec, GetEnumValueType(), embedding_dim, count); +} + +BOOST_AUTO_TEST_SUITE(SPFreshTest) + +BOOST_AUTO_TEST_CASE(TestLoadAndSave) +{ + using namespace SPFreshTest; + + // Prepare test data using TestDataGenerator + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + auto originalIndex = BuildIndex("original_index", vecset, metaset); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + originalIndex = nullptr; + + std::shared_ptr loadedIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex("original_index", loadedIndex) == ErrorCode::Success); + BOOST_REQUIRE(loadedIndex != nullptr); + BOOST_REQUIRE(loadedIndex->SaveIndex("loaded_and_saved_index") == ErrorCode::Success); + loadedIndex = nullptr; + + std::unordered_set exceptions = {"indexloader.ini"}; + + // Compare files in both directories + BOOST_REQUIRE_MESSAGE(CompareDirectoriesWithLogging("original_index", "loaded_and_saved_index", exceptions), + "Saved index does not match loaded-then-saved index"); + + std::filesystem::remove_all("original_index"); + std::filesystem::remove_all("loaded_and_saved_index"); +} + +BOOST_AUTO_TEST_CASE(TestReopenIndexRecall) +{ + using namespace SPFreshTest; + + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + auto originalIndex = BuildIndex("original_index", vecset, 
metaset); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + float recall1 = Search(originalIndex, queryset, vecset, addvecset, K, truth, N); + originalIndex = nullptr; + + std::shared_ptr loadedOnce; + BOOST_REQUIRE(VectorIndex::LoadIndex("original_index", loadedOnce) == ErrorCode::Success); + BOOST_REQUIRE(loadedOnce != nullptr); + BOOST_REQUIRE(loadedOnce->SaveIndex("reopened_index") == ErrorCode::Success); + loadedOnce = nullptr; + + std::shared_ptr loadedTwice; + BOOST_REQUIRE(VectorIndex::LoadIndex("reopened_index", loadedTwice) == ErrorCode::Success); + BOOST_REQUIRE(loadedTwice != nullptr); + float recall2 = Search(loadedTwice, queryset, vecset, addvecset, K, truth, N); + loadedTwice = nullptr; + + BOOST_REQUIRE_MESSAGE(std::fabs(recall1 - recall2) < 1e-5, "Recall mismatch between original and reopened index"); + + std::filesystem::remove_all("original_index"); + std::filesystem::remove_all("reopened_index"); +} + +BOOST_AUTO_TEST_CASE(TestInsertAndSearch) +{ + using namespace SPFreshTest; + + // Prepare test data using TestDataGenerator + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + // Build base index + auto index = BuildIndex("insert_test_index", vecset, metaset); + BOOST_REQUIRE(index != nullptr); + BOOST_REQUIRE(index->SaveIndex("insert_test_index") == ErrorCode::Success); + index = nullptr; + + std::shared_ptr loadedOnce; + BOOST_REQUIRE(VectorIndex::LoadIndex("insert_test_index", loadedOnce) == ErrorCode::Success); + BOOST_REQUIRE(loadedOnce != nullptr); + + InsertVectors(static_cast *>(loadedOnce.get()), 2, 1000, addvecset, addmetaset); + SearchOnly(loadedOnce, queryset, K); + loadedOnce = nullptr; + + std::filesystem::remove_all("insert_test_index"); +} + 
+BOOST_AUTO_TEST_CASE(TestClone) +{ + using namespace SPFreshTest; + + // Prepare test data using TestDataGenerator + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + auto originalIndex = BuildIndex("original_index", vecset, metaset); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + + auto clonedIndex = originalIndex->Clone("cloned_index"); + BOOST_REQUIRE(clonedIndex != nullptr); + BOOST_REQUIRE(clonedIndex->SaveIndex("cloned_index") == ErrorCode::Success); + originalIndex.reset(); + clonedIndex = nullptr; + + std::unordered_set exceptions = {"indexloader.ini"}; + + // Compare files in both directories + BOOST_REQUIRE_MESSAGE(CompareDirectoriesWithLogging("original_index", "cloned_index", exceptions), + "Saved index does not match loaded-then-saved index"); + + std::filesystem::remove_all("original_index"); + std::filesystem::remove_all("cloned_index"); +} + +BOOST_AUTO_TEST_CASE(TestCloneRecall) +{ + using namespace SPFreshTest; + + // Prepare test data using TestDataGenerator + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + auto originalIndex = BuildIndex("original_index", vecset, metaset); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + float originalRecall = Search(originalIndex, queryset, vecset, addvecset, K, truth, N); + + auto clonedIndex = originalIndex->Clone("cloned_index"); + BOOST_REQUIRE(clonedIndex != nullptr); + originalIndex.reset(); + clonedIndex = nullptr; + + std::shared_ptr 
loadedClonedIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex("cloned_index", loadedClonedIndex) == ErrorCode::Success); + BOOST_REQUIRE(loadedClonedIndex != nullptr); + float clonedRecall = Search(loadedClonedIndex, queryset, vecset, addvecset, K, truth, N); + loadedClonedIndex = nullptr; + + BOOST_REQUIRE_MESSAGE(std::fabs(originalRecall - clonedRecall) < 1e-5, + "Recall mismatch between original and cloned index: " + << "original=" << originalRecall << ", cloned=" << clonedRecall); + + std::filesystem::remove_all("original_index"); + std::filesystem::remove_all("cloned_index"); +} + +BOOST_AUTO_TEST_CASE(IndexPersistenceAndInsertSanity) +{ + using namespace SPFreshTest; + + // Prepare test data + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + // Build and save base index + auto baseIndex = BuildIndex("insert_test_index", vecset, metaset); + BOOST_REQUIRE(baseIndex != nullptr); + BOOST_REQUIRE(baseIndex->SaveIndex("insert_test_index") == ErrorCode::Success); + baseIndex = nullptr; + + // Load the saved index + std::shared_ptr loadedOnce; + BOOST_REQUIRE(VectorIndex::LoadIndex("insert_test_index", loadedOnce) == ErrorCode::Success); + BOOST_REQUIRE(loadedOnce != nullptr); + + // Search sanity check + SearchOnly(loadedOnce, queryset, K); + + // Clone the loaded index + auto clonedIndex = loadedOnce->Clone("insert_cloned_index"); + BOOST_REQUIRE(clonedIndex != nullptr); + + // Save and reload the cloned index + BOOST_REQUIRE(clonedIndex->SaveIndex("insert_cloned_index") == ErrorCode::Success); + loadedOnce.reset(); + clonedIndex = nullptr; + + std::shared_ptr loadedClone; + BOOST_REQUIRE(VectorIndex::LoadIndex("insert_cloned_index", loadedClone) == ErrorCode::Success); + BOOST_REQUIRE(loadedClone != nullptr); + + // Insert new vectors + InsertVectors(static_cast 
*>(loadedClone.get()), 1, + static_cast(addvecset->Count()), addvecset, addmetaset); + + // Final save and reload after insert + BOOST_REQUIRE(loadedClone->SaveIndex("insert_final_index") == ErrorCode::Success); + loadedClone = nullptr; + + std::shared_ptr reloadedFinal; + BOOST_REQUIRE(VectorIndex::LoadIndex("insert_final_index", reloadedFinal) == ErrorCode::Success); + + // Final search sanity + SearchOnly(reloadedFinal, queryset, K); + reloadedFinal = nullptr; + + // Cleanup + std::filesystem::remove_all("insert_test_index"); + std::filesystem::remove_all("insert_cloned_index"); + std::filesystem::remove_all("insert_final_index"); +} + +BOOST_AUTO_TEST_CASE(IndexPersistenceAndInsertMultipleThreads) +{ + using namespace SPFreshTest; + + // Prepare test data + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + // Build and save base index + auto baseIndex = BuildIndex("insert_test_index_multi", vecset, metaset); + BOOST_REQUIRE(baseIndex != nullptr); + BOOST_REQUIRE(baseIndex->SaveIndex("insert_test_index_multi") == ErrorCode::Success); + baseIndex = nullptr; + + // Load the saved index + std::shared_ptr loadedOnce; + BOOST_REQUIRE(VectorIndex::LoadIndex("insert_test_index_multi", loadedOnce) == ErrorCode::Success); + BOOST_REQUIRE(loadedOnce != nullptr); + + // Search sanity check + SearchOnly(loadedOnce, queryset, K); + + // Clone the loaded index + auto clonedIndex = loadedOnce->Clone("insert_cloned_index_multi"); + BOOST_REQUIRE(clonedIndex != nullptr); + + // Save and reload the cloned index + BOOST_REQUIRE(clonedIndex->SaveIndex("insert_cloned_index_multi") == ErrorCode::Success); + loadedOnce.reset(); + clonedIndex = nullptr; + + std::shared_ptr loadedClone; + BOOST_REQUIRE(VectorIndex::LoadIndex("insert_cloned_index_multi", loadedClone) == 
ErrorCode::Success); + BOOST_REQUIRE(loadedClone != nullptr); + + // Insert new vectors + InsertVectors(static_cast *>(loadedClone.get()), 2, + static_cast(addvecset->Count()), addvecset, addmetaset); + + // Final save and reload after insert + BOOST_REQUIRE(loadedClone->SaveIndex("insert_final_index_multi") == ErrorCode::Success); + loadedClone = nullptr; + + std::shared_ptr reloadedFinal; + BOOST_REQUIRE(VectorIndex::LoadIndex("insert_final_index_multi", reloadedFinal) == ErrorCode::Success); + BOOST_REQUIRE(reloadedFinal != nullptr); + // Final search sanity + SearchOnly(reloadedFinal, queryset, K); + reloadedFinal = nullptr; + + // Cleanup + std::filesystem::remove_all("insert_test_index_multi"); + std::filesystem::remove_all("insert_cloned_index_multi"); + std::filesystem::remove_all("insert_final_index_multi"); +} + +BOOST_AUTO_TEST_CASE(IndexSaveDuringQuery) +{ + using namespace SPFreshTest; + + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + auto index = BuildIndex("save_during_query_index", vecset, metaset); + BOOST_REQUIRE(index != nullptr); + + std::atomic keepQuerying(true); + std::thread queryThread([&]() { + while (keepQuerying) + { + for (int q = 0; q < queryset->Count(); ++q) + { + QueryResult result(queryset->GetVector(q), K, true); + index->SearchIndex(result); + } + } + }); + + // Wait a bit before saving + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + ErrorCode saveStatus = index->SaveIndex("save_during_query_index"); + BOOST_REQUIRE(saveStatus == ErrorCode::Success); + + keepQuerying = false; + queryThread.join(); + + index = nullptr; + + std::shared_ptr reloaded; + BOOST_REQUIRE(VectorIndex::LoadIndex("save_during_query_index", reloaded) == ErrorCode::Success); + BOOST_REQUIRE(reloaded != nullptr); + + 
SearchOnly(reloaded, queryset, K); + reloaded = nullptr; + + std::filesystem::remove_all("save_during_query_index"); +} + +BOOST_AUTO_TEST_CASE(IndexMultiThreadedQuerySanity) +{ + using namespace SPFreshTest; + + // Generate test data + std::shared_ptr vecset, queryset, truth, addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + // Build and save index + auto index = BuildIndex("multi_query_index", vecset, metaset); + BOOST_REQUIRE(index != nullptr); + BOOST_REQUIRE(index->SaveIndex("multi_query_index") == ErrorCode::Success); + index = nullptr; + + // Reload the index + std::shared_ptr loaded; + BOOST_REQUIRE(VectorIndex::LoadIndex("multi_query_index", loaded) == ErrorCode::Success); + BOOST_REQUIRE(loaded != nullptr); + + // Insert additional vectors + InsertVectors(static_cast *>(loaded.get()), 2, + static_cast(addvecset->Count()), addvecset, addmetaset); + + // Perform multithreaded query + const int threadCount = 4; + std::vector threads; + std::atomic nextQuery(0); + std::atomic completedQueries(0); + + for (int t = 0; t < threadCount; ++t) + { + threads.emplace_back([&, t]() { + QueryResult result(nullptr, K, true); + while (true) + { + int i = nextQuery.fetch_add(1); + if (i >= queryset->Count()) + break; + + result.SetTarget(queryset->GetVector(static_cast(i))); + loaded->SearchIndex(result); + + ++completedQueries; + } + }); + } + + for (auto &thread : threads) + { + thread.join(); + } + + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Multithreaded query completed: %d queries\n", completedQueries.load()); + loaded = nullptr; + + // Cleanup + std::filesystem::remove_all("multi_query_index"); +} + +BOOST_AUTO_TEST_CASE(IndexShadowCloneLifecycleKeepLast) +{ + using namespace SPFreshTest; + + constexpr int iterations = 5; + constexpr int insertBatchSize = 100; + + std::shared_ptr vecset, queryset, truth, 
addvecset, addtruth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.Run(vecset, metaset, queryset, truth, addvecset, addmetaset, addtruth); + + const std::string baseIndexName = "base_index"; + BOOST_REQUIRE(BuildIndex(baseIndexName, vecset, metaset)->SaveIndex(baseIndexName) == ErrorCode::Success); + + std::string previousIndexName = baseIndexName; + + for (int iter = 0; iter < iterations; ++iter) + { + std::string shadowIndexName = "shadow_index_" + std::to_string(iter); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "[%d] Loading index: %s\n", iter, previousIndexName.c_str()); + + // Load previous index + std::shared_ptr loaded; + BOOST_REQUIRE(VectorIndex::LoadIndex(previousIndexName, loaded) == ErrorCode::Success); + BOOST_REQUIRE(loaded != nullptr); + + // Query check + for (int i = 0; i < std::min(queryset->Count(), 5); ++i) + { + QueryResult result(queryset->GetVector(i), K, true); + loaded->SearchIndex(result); + } + + // Cleanup previous base index after first iteration + if (iter == 1) + { + std::filesystem::remove_all(baseIndexName); + } + + // Clone to shadow + BOOST_REQUIRE(loaded->Clone(shadowIndexName) != nullptr); + loaded.reset(); + + std::shared_ptr shadowLoaded; + BOOST_REQUIRE(VectorIndex::LoadIndex(shadowIndexName, shadowLoaded) == ErrorCode::Success); + BOOST_REQUIRE(shadowLoaded != nullptr); + auto *shadowIndex = static_cast *>(shadowLoaded.get()); + + // Prepare insert batch + const int insertOffset = (iter * insertBatchSize) % static_cast(addvecset->Count()); + const int insertCount = min(insertBatchSize, static_cast(addvecset->Count()) - insertOffset); + + std::vector metaBytes; + std::vector offsetTable(insertCount + 1); + std::uint64_t offset = 0; + for (int i = 0; i < insertCount; ++i) + { + ByteArray meta = addmetaset->GetMetadata(insertOffset + i); + offsetTable[i] = offset; + metaBytes.insert(metaBytes.end(), meta.Data(), meta.Data() + meta.Length()); + offset += 
meta.Length(); + } + offsetTable[insertCount] = offset; + + ByteArray metaBuf(new std::uint8_t[metaBytes.size()], metaBytes.size(), true); + std::memcpy(metaBuf.Data(), metaBytes.data(), metaBytes.size()); + + ByteArray offsetBuf(new std::uint8_t[offsetTable.size() * sizeof(std::uint64_t)], + offsetTable.size() * sizeof(std::uint64_t), true); + std::memcpy(offsetBuf.Data(), offsetTable.data(), offsetTable.size() * sizeof(std::uint64_t)); + + auto batchMeta = std::make_shared(metaBuf, offsetBuf, insertCount); + const void *vectorStart = addvecset->GetVector(insertOffset); + + shadowIndex->AddIndex(vectorStart, insertCount, shadowIndex->GetOptions()->m_dim, batchMeta, true); + + while (!shadowIndex->AllFinished()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + + BOOST_REQUIRE(shadowLoaded->SaveIndex(shadowIndexName) == ErrorCode::Success); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "[%d] Created new shadow index: %s\n", iter, shadowIndexName.c_str()); + shadowLoaded = nullptr; + + previousIndexName = shadowIndexName; + } + + // Keep the final shadow index directory for debugging/inspection + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Kept final index: %s\n", previousIndexName.c_str()); + + // Cleanup all created indexes after test + std::filesystem::remove_all(baseIndexName); + for (int iter = 0; iter < iterations; ++iter) + { + std::string shadow = "shadow_index_" + std::to_string(iter); + std::filesystem::remove_all(shadow); + } +} + +BOOST_AUTO_TEST_CASE(IterativeSearch) +{ + using namespace SPFreshTest; + + constexpr int insertIterations = 5; + constexpr int insertBatchSize = 1000; + constexpr int dimension = 1024; + std::shared_ptr vecset = get_embeddings(0, insertBatchSize, dimension, -1); + std::shared_ptr metaset = + TestUtils::TestDataGenerator::GenerateMetadataSet(insertBatchSize, 0); + + auto originalIndex = BuildIndex("original_index", vecset, metaset); + BOOST_REQUIRE(originalIndex != nullptr); + 
BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + originalIndex = nullptr; + + std::string prevPath = "original_index"; + for (int iter = 0; iter < insertIterations; iter++) + { + std::string clone_path = "clone_index_" + std::to_string(iter); + std::shared_ptr prevIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex(prevPath, prevIndex) == ErrorCode::Success); + BOOST_REQUIRE(prevIndex != nullptr); + + auto cloneIndex = prevIndex->Clone(clone_path); + auto *cloneIndexPtr = static_cast *>(cloneIndex.get()); + std::shared_ptr tmpvecs = + get_embeddings((iter + 1) * insertBatchSize, (iter + 2) * insertBatchSize, dimension, -1); + std::shared_ptr tmpmetas = + TestUtils::TestDataGenerator::GenerateMetadataSet(insertBatchSize, (iter + 1) * insertBatchSize); + InsertVectors(cloneIndexPtr, 1, insertBatchSize, tmpvecs, tmpmetas); + + BOOST_REQUIRE(cloneIndex->SaveIndex(clone_path) == ErrorCode::Success); + cloneIndex = nullptr; + + std::shared_ptr loadedIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex(clone_path, loadedIndex) == ErrorCode::Success); + BOOST_REQUIRE(loadedIndex != nullptr); + + std::shared_ptr embedding = + get_embeddings((1000 * iter) + 500, ((1000 * iter) + 501), dimension, -1); + std::shared_ptr resultIterator = loadedIndex->GetIterator(embedding->GetData(), false); + int batch = 100; + int ri = 0; + float current = INT_MAX, previous = INT_MAX; + bool relaxMono = false; + while (!relaxMono) + { + auto results = resultIterator->Next(batch); + int resultCount = results->GetResultNum(); + if (resultCount <= 0) + break; + + previous = current; + current = 0; + for (int j = 0; j < resultCount; j++) + { + std::cout << "Result[" << ri << "] VID:" << results->GetResult(j)->VID + << " Dist:" << results->GetResult(j)->Dist + << " RelaxedMono:" << results->GetResult(j)->RelaxedMono << " current:" << current + << " previous:" << previous << std::endl; + relaxMono = results->GetResult(j)->RelaxedMono; + current += results->GetResult(j)->Dist; + 
ri++; + } + current /= resultCount; + } + resultIterator->Close(); + loadedIndex = nullptr; + } + + for (int iter = 0; iter < insertIterations; iter++) + { + std::filesystem::remove_all("clone_index_" + std::to_string(iter)); + } + std::filesystem::remove_all("original_index"); +} + +BOOST_AUTO_TEST_CASE(RefineIndex) +{ + using namespace SPFreshTest; + + int iterations = 5; + int insertBatchSize = N / iterations; + int deleteBatchSize = N / iterations; + + // Generate test data + std::shared_ptr vecset, addvecset, queryset, truth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.RunBatches(vecset, metaset, addvecset, addmetaset, queryset, N, insertBatchSize, deleteBatchSize, + iterations, truth); + + // Build and save index + auto originalIndex = BuildIndex("original_index", vecset, metaset); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + + float recall = Search(originalIndex, queryset, vecset, addvecset, K, truth, N); + std::cout << "original: recall@" << K << "= " << recall << std::endl; + + for (int iter = 0; iter < iterations; iter++) + { + + InsertVectors(static_cast *>(originalIndex.get()), 1, insertBatchSize, addvecset, + metaset, iter * insertBatchSize); + for (int i = 0; i < deleteBatchSize; i++) + originalIndex->DeleteIndex(iter * deleteBatchSize + i); + + recall = Search(originalIndex, queryset, vecset, addvecset, K, truth, N, iter + 1); + std::cout << "iter " << iter << ": recall@" << K << "=" << recall << std::endl; + } + std::cout << "Before Refine:" << " recall@" << K << "=" << recall << std::endl; + static_cast *>(originalIndex.get())->GetDBStat(); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + originalIndex = nullptr; + + BOOST_REQUIRE(VectorIndex::LoadIndex("original_index", originalIndex) == ErrorCode::Success); + BOOST_REQUIRE(originalIndex != nullptr); + 
BOOST_REQUIRE(originalIndex->Check() == ErrorCode::Success); + + recall = Search(originalIndex, queryset, vecset, addvecset, K, truth, N, iterations); + std::cout << "After Refine:" << " recall@" << K << "=" << recall << std::endl; + static_cast *>(originalIndex.get())->GetDBStat(); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + originalIndex = nullptr; + + std::filesystem::remove_all("original_index"); +} + +BOOST_AUTO_TEST_CASE(CacheTest) +{ + using namespace SPFreshTest; + + int iterations = 5; + int insertBatchSize = N / iterations; + int deleteBatchSize = N / iterations; + + // Generate test data + std::shared_ptr vecset, addvecset, queryset, truth; + std::shared_ptr metaset, addmetaset; + + TestUtils::TestDataGenerator generator(N, queries, M, K, "L2"); + generator.RunBatches(vecset, metaset, addvecset, addmetaset, queryset, N, insertBatchSize, deleteBatchSize, + iterations, truth); + + // Build and save index + std::shared_ptr originalIndex, finalIndex; + + std::filesystem::remove_all("original_index"); + + originalIndex = BuildIndex("original_index", vecset, metaset); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + originalIndex = nullptr; + + + for (int iter = 0; iter < iterations; iter++) + { + if (direxists(("clone_index_" + std::to_string(iter)).c_str())) + { + std::filesystem::remove_all("clone_index_" + std::to_string(iter)); + } + } + + std::string prevPath = "original_index"; + float recall = 0.0; + + std::cout << "=================No Cache===================" << std::endl; + + for (int iter = 0; iter < iterations; iter++) + { + std::string clone_path = "clone_index_" + std::to_string(iter); + std::shared_ptr prevIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex(prevPath, prevIndex) == ErrorCode::Success); + BOOST_REQUIRE(prevIndex != nullptr); + auto t0 = std::chrono::high_resolution_clock::now(); + BOOST_REQUIRE(prevIndex->Check() == 
ErrorCode::Success); + std::cout << "[INFO] Check time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - t0).count() + << " ms" << std::endl; + + + auto cloneIndex = prevIndex->Clone(clone_path); + prevIndex = nullptr; + BOOST_REQUIRE(cloneIndex->Check() == ErrorCode::Success); + + recall = Search(cloneIndex, queryset, vecset, addvecset, K, truth, N, iter); + std::cout << "[INFO] After Save, Clone and Load:" << " recall@" << K << "=" << recall << std::endl; + static_cast *>(cloneIndex.get())->GetDBStat(); + + auto t1 = std::chrono::high_resolution_clock::now(); + InsertVectors(static_cast *>(cloneIndex.get()), 1, insertBatchSize, addvecset, + metaset, iter * insertBatchSize); + std::cout << "[INFO] Insert time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - t1).count() + << " ms" << std::endl; + + for (int i = 0; i < deleteBatchSize; i++) + cloneIndex->DeleteIndex(iter * deleteBatchSize + i); + + recall = Search(cloneIndex, queryset, vecset, addvecset, K, truth, N, iter + 1); + std::cout << "[INFO] After iter " << iter << ": recall@" << K << "=" << recall << std::endl; + static_cast *>(cloneIndex.get())->GetDBStat(); + + BOOST_REQUIRE(cloneIndex->SaveIndex(clone_path) == ErrorCode::Success); + cloneIndex = nullptr; + prevPath = clone_path; + } + + BOOST_REQUIRE(VectorIndex::LoadIndex(prevPath, finalIndex) == ErrorCode::Success); + BOOST_REQUIRE(finalIndex != nullptr); + auto t = std::chrono::high_resolution_clock::now(); + BOOST_REQUIRE(finalIndex->Check() == ErrorCode::Success); + std::cout << "[INFO] Check time for iteration " << iterations << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - t).count() + << " ms" << std::endl; + + recall = Search(finalIndex, queryset, vecset, addvecset, K, truth, N, iterations); + std::cout << "[INFO] After Save and Load:" << " recall@" << K << "=" << recall << std::endl; + 
static_cast *>(finalIndex.get())->GetDBStat(); + finalIndex = nullptr; + for (int iter = 0; iter < iterations; iter++) + { + std::filesystem::remove_all("clone_index_" + std::to_string(iter)); + } + + std::cout << "=================Enable Cache===================" << std::endl; + prevPath = "original_index"; + for (int iter = 0; iter < iterations; iter++) + { + std::string clone_path = "clone_index_" + std::to_string(iter); + std::shared_ptr prevIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex(prevPath, prevIndex) == ErrorCode::Success); + BOOST_REQUIRE(prevIndex != nullptr); + auto t0 = std::chrono::high_resolution_clock::now(); + BOOST_REQUIRE(prevIndex->Check() == ErrorCode::Success); + std::cout << "[INFO] Check time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - t0).count() + << " ms" << std::endl; + + + prevIndex->SetParameter("CacheSizeGB", "4", "BuildSSDIndex"); + prevIndex->SetParameter("CacheShards", "2", "BuildSSDIndex"); + + BOOST_REQUIRE(prevIndex->SaveIndex(prevPath) == ErrorCode::Success); + auto cloneIndex = prevIndex->Clone(clone_path); + + recall = Search(cloneIndex, queryset, vecset, addvecset, K, truth, N, iter); + std::cout << "[INFO] After Save, Clone and Load:" << " recall@" << K << "=" << recall << std::endl; + static_cast *>(cloneIndex.get())->GetDBStat(); + + auto t1 = std::chrono::high_resolution_clock::now(); + InsertVectors(static_cast *>(cloneIndex.get()), 1, insertBatchSize, addvecset, + metaset, iter * insertBatchSize); + std::cout << "[INFO] Insert time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - t1).count() + << " ms" << std::endl; + + for (int i = 0; i < deleteBatchSize; i++) + cloneIndex->DeleteIndex(iter * deleteBatchSize + i); + + recall = Search(cloneIndex, queryset, vecset, addvecset, K, truth, N, iter + 1); + std::cout << "[INFO] After iter " << iter << ": recall@" << K << "=" << recall << std::endl; + 
static_cast *>(cloneIndex.get())->GetDBStat(); + + BOOST_REQUIRE(cloneIndex->SaveIndex(clone_path) == ErrorCode::Success); + cloneIndex = nullptr; + prevPath = clone_path; + } + BOOST_REQUIRE(VectorIndex::LoadIndex(prevPath, finalIndex) == ErrorCode::Success); + BOOST_REQUIRE(finalIndex != nullptr); + auto tt = std::chrono::high_resolution_clock::now(); + BOOST_REQUIRE(finalIndex->Check() == ErrorCode::Success); + std::cout << "[INFO] Check time for iteration " << iterations << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - tt).count() + << " ms" << std::endl; + + recall = Search(finalIndex, queryset, vecset, addvecset, K, truth, N, iterations); + std::cout << "[INFO] After Save and Load:" << " recall@" << K << "=" << recall << std::endl; + static_cast *>(finalIndex.get())->GetDBStat(); + finalIndex = nullptr; + + for (int iter = 0; iter < iterations; iter++) + { + std::filesystem::remove_all("clone_index_" + std::to_string(iter)); + } + + std::filesystem::remove_all("original_index"); +} + +BOOST_AUTO_TEST_CASE(IterativeSearchPerf) +{ + using namespace SPFreshTest; + + constexpr int insertIterations = 5; + constexpr int insertBatchSize = 60000; + constexpr int appendBatchSize = 40000; + constexpr int dimension = 100; + std::shared_ptr vecset = get_embeddings(0, insertBatchSize, dimension, -1); + std::shared_ptr metaset = TestUtils::TestDataGenerator::GenerateMetadataSet(insertBatchSize, 0); + + auto originalIndex = BuildIndex("original_index", vecset, metaset); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + originalIndex = nullptr; + + std::string prevPath = "original_index"; + for (int iter = 0; iter < insertIterations; iter++) + { + std::string clone_path = "clone_index_" + std::to_string(iter); + std::shared_ptr prevIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex(prevPath, prevIndex) == ErrorCode::Success); + BOOST_REQUIRE(prevIndex != nullptr); + 
auto t0 = std::chrono::high_resolution_clock::now(); + BOOST_REQUIRE(prevIndex->Check() == ErrorCode::Success); + std::cout << "Check time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - t0).count() + << " ms" << std::endl; + + auto cloneIndex = prevIndex->Clone(clone_path); + auto *cloneIndexPtr = static_cast *>(cloneIndex.get()); + std::shared_ptr tmpvecs = get_embeddings( + insertBatchSize + iter * appendBatchSize, insertBatchSize + (iter + 1) * appendBatchSize, dimension, -1); + std::shared_ptr tmpmetas = TestUtils::TestDataGenerator::GenerateMetadataSet( + appendBatchSize, insertBatchSize + (iter)*appendBatchSize); + auto t1 = std::chrono::high_resolution_clock::now(); + InsertVectors(cloneIndexPtr, 1, appendBatchSize, tmpvecs, tmpmetas); + std::cout << "Insert time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - + t1) + .count() + << " ms" << std::endl; + + BOOST_REQUIRE(cloneIndex->SaveIndex(clone_path) == ErrorCode::Success); + cloneIndex = nullptr; + } + + for (int iter = 0; iter < insertIterations; iter++) + { + std::filesystem::remove_all("clone_index_" + std::to_string(iter)); + } + std::filesystem::remove_all("original_index"); +} + +BOOST_AUTO_TEST_CASE(RefineTestIdx) +{ + using namespace SPFreshTest; + + constexpr int dimension = 1024; + + std::shared_ptr vecset = get_embeddings(0, 500, dimension, -1); + std::shared_ptr metaset = TestUtils::TestDataGenerator::GenerateMetadataSet(1000, 0); + + for (auto i = 0; i < 2; ++i) { + void* p = vecset->GetVector(i); + for (auto i = 0; i < dimension; ++i) { + std::cout << ((float*)p)[i] << " "; + } + std::cout << std::endl; + } + + auto originalIndex = BuildIndex("original_index", vecset, metaset, "COSINE"); + BOOST_REQUIRE(originalIndex != nullptr); + BOOST_REQUIRE(originalIndex->SaveIndex("original_index") == ErrorCode::Success); + originalIndex = nullptr; + + std::string prevPath = 
"original_index"; + for (int iter = 0; iter < 1; iter++) + { + std::string clone_path = "clone_index_" + std::to_string(iter); + std::shared_ptr prevIndex; + BOOST_REQUIRE(VectorIndex::LoadIndex(prevPath, prevIndex) == ErrorCode::Success); + BOOST_REQUIRE(prevIndex != nullptr); + auto t0 = std::chrono::high_resolution_clock::now(); + BOOST_REQUIRE(prevIndex->Check() == ErrorCode::Success); + std::cout << "Check time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - + t0) + .count() + << " ms" << std::endl; + + auto cloneIndex = prevIndex->Clone(clone_path); + auto *cloneIndexPtr = static_cast *>(cloneIndex.get()); + std::shared_ptr tmpvecs = get_embeddings(500, 1100, dimension, -1); + std::shared_ptr tmpmetas = TestUtils::TestDataGenerator::GenerateMetadataSet(1200, 1000); + auto t1 = std::chrono::high_resolution_clock::now(); + InsertVectors(cloneIndexPtr, 1, 1200, tmpvecs, tmpmetas); + std::cout << "Insert time for iteration " << iter << ": " + << std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - + t1) + .count() + << " ms" << std::endl; + + for (auto i = 1000; i < 1900; ++i) + { + cloneIndexPtr->DeleteIndex(i); + } + + BOOST_REQUIRE(cloneIndex->SaveIndex(clone_path) == ErrorCode::Success); + cloneIndex = nullptr; + } + + for (int iter = 0; iter < 1; iter++) + { + std::filesystem::remove_all("clone_index_" + std::to_string(iter)); + } + // std::filesystem::remove_all("original_index"); +} +BOOST_AUTO_TEST_SUITE_END() diff --git a/Test/src/SSDServingTest.cpp b/Test/src/SSDServingTest.cpp index cba53866c..edebfd405 100644 --- a/Test/src/SSDServingTest.cpp +++ b/Test/src/SSDServingTest.cpp @@ -2,375 +2,397 @@ // Licensed under the MIT License. 
#include "inc/Test.h" -#include +#include #include +#include #include +#include #include -#include -#include #include #include "inc/Core/Common.h" +#include "inc/Core/Common/CommonUtils.h" +#include "inc/Core/Common/DistanceUtils.h" #include "inc/Helper/StringConvert.h" #include "inc/SSDServing/main.h" -#include "inc/Core/Common/DistanceUtils.h" -#include "inc/Core/Common/CommonUtils.h" - -template -void GenerateVectors(std::string fileName, SPTAG::SizeType rows, SPTAG::DimensionType dims, SPTAG::VectorFileType fileType) { - if (boost::filesystem::exists(fileName)) - { - fprintf(stdout, "%s was generated. Skip generation.\n", fileName.c_str()); - return; - } - - std::ofstream of(fileName, std::ofstream::binary); - if (!of.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", fileName.c_str()); - BOOST_CHECK(false); - return; - } - - std::uniform_real_distribution ud(0, 126); - std::mt19937 mt(543); - std::vector tmp(dims); - - if (fileType == SPTAG::VectorFileType::DEFAULT) - { - of.write(reinterpret_cast(&rows), 4); - of.write(reinterpret_cast(&dims), 4); - - for (size_t i = 0; i < rows; i++) - { - for (size_t j = 0; j < dims; j++) - { - float smt = ud(mt); - tmp[j] = static_cast(smt); - } - - SPTAG::COMMON::Utils::Normalize(tmp.data(), dims, SPTAG::COMMON::Utils::GetBase()); - of.write(reinterpret_cast(tmp.data()), dims * sizeof(T)); - } - } - else if (fileType == SPTAG::VectorFileType::XVEC) - { - for (size_t i = 0; i < rows; i++) - { - for (size_t j = 0; j < dims; j++) - { - float smt = ud(mt); - tmp[j] = static_cast(smt); - } - - SPTAG::COMMON::Utils::Normalize(tmp.data(), dims, SPTAG::COMMON::Utils::GetBase()); - of.write(reinterpret_cast(&dims), 4); - of.write(reinterpret_cast(tmp.data()), dims * sizeof(T)); - } - - } +template +void GenerateVectors(std::string fileName, SPTAG::SizeType rows, SPTAG::DimensionType dims, + SPTAG::VectorFileType fileType) +{ + if (boost::filesystem::exists(fileName)) + { + fprintf(stdout, "%s was generated. 
Skip generation.\n", fileName.c_str()); + return; + } + + std::ofstream of(fileName, std::ofstream::binary); + if (!of.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", fileName.c_str()); + BOOST_CHECK(false); + return; + } + + std::uniform_real_distribution ud(0, 126); + std::mt19937 mt(543); + std::vector tmp(dims); + + if (fileType == SPTAG::VectorFileType::DEFAULT) + { + of.write(reinterpret_cast(&rows), 4); + of.write(reinterpret_cast(&dims), 4); + + for (size_t i = 0; i < rows; i++) + { + for (size_t j = 0; j < dims; j++) + { + float smt = ud(mt); + tmp[j] = static_cast(smt); + } + + SPTAG::COMMON::Utils::Normalize(tmp.data(), dims, SPTAG::COMMON::Utils::GetBase()); + of.write(reinterpret_cast(tmp.data()), dims * sizeof(T)); + } + } + else if (fileType == SPTAG::VectorFileType::XVEC) + { + for (size_t i = 0; i < rows; i++) + { + for (size_t j = 0; j < dims; j++) + { + float smt = ud(mt); + tmp[j] = static_cast(smt); + } + + SPTAG::COMMON::Utils::Normalize(tmp.data(), dims, SPTAG::COMMON::Utils::GetBase()); + of.write(reinterpret_cast(&dims), 4); + of.write(reinterpret_cast(tmp.data()), dims * sizeof(T)); + } + } } -void GenVec(std::string vectorsName, SPTAG::VectorValueType vecType, SPTAG::VectorFileType vecFileType, SPTAG::SizeType rows = 1000, SPTAG::DimensionType dims = 100) { - switch (vecType) - { -#define DefineVectorValueType(Name, Type) \ -case SPTAG::VectorValueType::Name: \ -GenerateVectors(vectorsName, rows, dims, vecFileType); \ -break; \ +void GenVec(std::string vectorsName, SPTAG::VectorValueType vecType, SPTAG::VectorFileType vecFileType, + SPTAG::SizeType rows = 1000, SPTAG::DimensionType dims = 100) +{ + switch (vecType) + { +#define DefineVectorValueType(Name, Type) \ + case SPTAG::VectorValueType::Name: \ + GenerateVectors(vectorsName, rows, dims, vecFileType); \ + break; #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType - default: - break; - } + default: + break; + } } -std::string 
CreateBaseConfig(SPTAG::VectorValueType p_valueType, SPTAG::DistCalcMethod p_distCalcMethod, - SPTAG::IndexAlgoType p_indexAlgoType, SPTAG::DimensionType p_dim, - std::string p_vectorPath, SPTAG::VectorFileType p_vectorType, SPTAG::SizeType p_vectorSize, std::string p_vectorDelimiter, - std::string p_queryPath, SPTAG::VectorFileType p_queryType, SPTAG::SizeType p_querySize, std::string p_queryDelimiter, - std::string p_warmupPath, SPTAG::VectorFileType p_warmupType, SPTAG::SizeType p_warmupSize, std::string p_warmupDelimiter, - std::string p_truthPath, SPTAG::TruthFileType p_truthType, - bool p_generateTruth, - std::string p_headIDFile, - std::string p_headVectorsFile, - std::string p_headIndexFolder, - std::string p_ssdIndex -) { - std::ostringstream config; - config << "[Base]" << std::endl; - config << "ValueType=" << SPTAG::Helper::Convert::ConvertToString(p_valueType) << std::endl; - config << "DistCalcMethod=" << SPTAG::Helper::Convert::ConvertToString(p_distCalcMethod) << std::endl; - config << "IndexAlgoType=" << SPTAG::Helper::Convert::ConvertToString(p_indexAlgoType) << std::endl; - config << "Dim=" << p_dim << std::endl; - config << "VectorPath=" << p_vectorPath << std::endl; - config << "VectorType=" << SPTAG::Helper::Convert::ConvertToString(p_vectorType) << std::endl; - config << "VectorSize=" << p_vectorSize << std::endl; - config << "VectorDelimiter=" << p_vectorDelimiter << std::endl; - config << "QueryPath=" << p_queryPath << std::endl; - config << "QueryType=" << SPTAG::Helper::Convert::ConvertToString(p_queryType) << std::endl; - config << "QuerySize=" << p_querySize << std::endl; - config << "QueryDelimiter=" << p_queryDelimiter << std::endl; - config << "WarmupPath=" << p_warmupPath << std::endl; - config << "WarmupType=" << SPTAG::Helper::Convert::ConvertToString(p_warmupType) << std::endl; - config << "WarmupSize=" << p_warmupSize << std::endl; - config << "WarmupDelimiter=" << p_warmupDelimiter << std::endl; - config << "TruthPath=" << 
p_truthPath << std::endl; - config << "TruthType=" << SPTAG::Helper::Convert::ConvertToString(p_truthType) << std::endl; - config << "GenerateTruth=" << SPTAG::Helper::Convert::ConvertToString(p_generateTruth) << std::endl; - config << "IndexDirectory=" << "zbtest" << std::endl; - config << "HeadVectorIDs=" << p_headIDFile << std::endl; - config << "HeadVectors=" << p_headVectorsFile << std::endl; - config << "HeadIndexFolder=" << p_headIndexFolder << std::endl; - config << "SSDIndex=" << p_ssdIndex << std::endl; - config << std::endl; - return config.str(); +std::string CreateBaseConfig(SPTAG::VectorValueType p_valueType, SPTAG::DistCalcMethod p_distCalcMethod, + SPTAG::IndexAlgoType p_indexAlgoType, SPTAG::DimensionType p_dim, std::string p_vectorPath, + SPTAG::VectorFileType p_vectorType, SPTAG::SizeType p_vectorSize, + std::string p_vectorDelimiter, std::string p_queryPath, SPTAG::VectorFileType p_queryType, + SPTAG::SizeType p_querySize, std::string p_queryDelimiter, std::string p_warmupPath, + SPTAG::VectorFileType p_warmupType, SPTAG::SizeType p_warmupSize, + std::string p_warmupDelimiter, std::string p_truthPath, SPTAG::TruthFileType p_truthType, + bool p_generateTruth, std::string p_headIDFile, std::string p_headVectorsFile, + std::string p_headIndexFolder, std::string p_ssdIndex) +{ + std::ostringstream config; + config << "[Base]" << std::endl; + config << "ValueType=" << SPTAG::Helper::Convert::ConvertToString(p_valueType) << std::endl; + config << "DistCalcMethod=" << SPTAG::Helper::Convert::ConvertToString(p_distCalcMethod) << std::endl; + config << "IndexAlgoType=" << SPTAG::Helper::Convert::ConvertToString(p_indexAlgoType) << std::endl; + config << "Dim=" << p_dim << std::endl; + config << "VectorPath=" << p_vectorPath << std::endl; + config << "VectorType=" << SPTAG::Helper::Convert::ConvertToString(p_vectorType) << std::endl; + config << "VectorSize=" << p_vectorSize << std::endl; + config << "VectorDelimiter=" << p_vectorDelimiter << std::endl; + 
config << "QueryPath=" << p_queryPath << std::endl; + config << "QueryType=" << SPTAG::Helper::Convert::ConvertToString(p_queryType) << std::endl; + config << "QuerySize=" << p_querySize << std::endl; + config << "QueryDelimiter=" << p_queryDelimiter << std::endl; + config << "WarmupPath=" << p_warmupPath << std::endl; + config << "WarmupType=" << SPTAG::Helper::Convert::ConvertToString(p_warmupType) << std::endl; + config << "WarmupSize=" << p_warmupSize << std::endl; + config << "WarmupDelimiter=" << p_warmupDelimiter << std::endl; + config << "TruthPath=" << p_truthPath << std::endl; + config << "TruthType=" << SPTAG::Helper::Convert::ConvertToString(p_truthType) << std::endl; + config << "GenerateTruth=" << SPTAG::Helper::Convert::ConvertToString(p_generateTruth) << std::endl; + config << "IndexDirectory=" + << "zbtest" << std::endl; + config << "HeadVectorIDs=" << p_headIDFile << std::endl; + config << "HeadVectors=" << p_headVectorsFile << std::endl; + config << "HeadIndexFolder=" << p_headIndexFolder << std::endl; + config << "SSDIndex=" << p_ssdIndex << std::endl; + config << std::endl; + return config.str(); } -void TestHead(std::string configName, std::string OutputIDFile, std::string OutputVectorFile, std::string baseConfig) { - - std::ofstream config(configName); - if (!config.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", configName.c_str()); - BOOST_CHECK(false); - return; - } - - config << baseConfig; - - config << "[SelectHead]" << std::endl; - config << "isExecute=true" << std::endl; - config << "TreeNumber=" << "1" << std::endl; - config << "BKTKmeansK=" << 3 << std::endl; - config << "BKTLeafSize=" << 6 << std::endl; - config << "SamplesNumber=" << 100 << std::endl; - config << "NumberOfThreads=" << "2" << std::endl; - config << "SaveBKT=" << "false" << std::endl; - - config << "AnalyzeOnly=" << "false" << std::endl; - config << "CalcStd=" << "true" << std::endl; - config << "SelectDynamically=" << "true" << std::endl; - config << 
"NoOutput=" << "false" << std::endl; - - config << "SelectThreshold=" << 12 << std::endl; - config << "SplitFactor=" << 9 << std::endl; - config << "SplitThreshold=" << 18 << std::endl; - config << "Ratio=" << "0.2" << std::endl; - config << "RecursiveCheckSmallCluster=" << "true" << std::endl; - config << "PrintSizeCount=" << "true" << std::endl; - - config.close(); - std::map> my_map; - SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); +void TestHead(std::string configName, std::string OutputIDFile, std::string OutputVectorFile, std::string baseConfig) +{ + + std::ofstream config(configName); + if (!config.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", configName.c_str()); + BOOST_CHECK(false); + return; + } + + config << baseConfig; + + config << "[SelectHead]" << std::endl; + config << "isExecute=true" << std::endl; + config << "TreeNumber=" + << "1" << std::endl; + config << "BKTKmeansK=" << 3 << std::endl; + config << "BKTLeafSize=" << 6 << std::endl; + config << "SamplesNumber=" << 100 << std::endl; + config << "NumberOfThreads=" + << "2" << std::endl; + config << "SaveBKT=" + << "false" << std::endl; + + config << "AnalyzeOnly=" + << "false" << std::endl; + config << "CalcStd=" + << "true" << std::endl; + config << "SelectDynamically=" + << "true" << std::endl; + config << "NoOutput=" + << "false" << std::endl; + + config << "SelectThreshold=" << 12 << std::endl; + config << "SplitFactor=" << 9 << std::endl; + config << "SplitThreshold=" << 18 << std::endl; + config << "Ratio=" + << "0.2" << std::endl; + config << "RecursiveCheckSmallCluster=" + << "true" << std::endl; + config << "PrintSizeCount=" + << "true" << std::endl; + + config.close(); + std::map> my_map; + SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); } -void TestBuildHead( - std::string configName, - std::string p_headVectorFile, - std::string p_headIndexFile, - SPTAG::IndexAlgoType p_indexAlgoType, - std::string p_builderFile, - std::string 
baseConfig) { - - std::ofstream config(configName); - if (!config.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", configName.c_str()); - BOOST_CHECK(false); - return; - } - - { - if (p_builderFile.empty()) - { - fprintf(stderr, "no builder file for head index build.\n"); - BOOST_CHECK(false); - return; - } - - if (boost::filesystem::exists(p_builderFile)) - { - fprintf(stdout, "%s was generated. Skip generation.\n", p_builderFile.c_str()); - } - else { - std::ofstream bf(p_builderFile); - if (!bf.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", p_builderFile.c_str()); - BOOST_CHECK(false); - return; - } - bf.close(); - } - } - - config << baseConfig; - - config << "[BuildHead]" << std::endl; - config << "isExecute=true" << std::endl; - config << "NumberOfThreads=" << 2 << std::endl; - - config.close(); - - std::map> my_map; - SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); +void TestBuildHead(std::string configName, std::string p_headVectorFile, std::string p_headIndexFile, + SPTAG::IndexAlgoType p_indexAlgoType, std::string p_builderFile, std::string baseConfig) +{ + + std::ofstream config(configName); + if (!config.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", configName.c_str()); + BOOST_CHECK(false); + return; + } + + { + if (p_builderFile.empty()) + { + fprintf(stderr, "no builder file for head index build.\n"); + BOOST_CHECK(false); + return; + } + + if (boost::filesystem::exists(p_builderFile)) + { + fprintf(stdout, "%s was generated. 
Skip generation.\n", p_builderFile.c_str()); + } + else + { + std::ofstream bf(p_builderFile); + if (!bf.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", p_builderFile.c_str()); + BOOST_CHECK(false); + return; + } + bf.close(); + } + } + + config << baseConfig; + + config << "[BuildHead]" << std::endl; + config << "isExecute=true" << std::endl; + config << "NumberOfThreads=" << 2 << std::endl; + + config.close(); + + std::map> my_map; + SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); } -void TestBuildSSDIndex(std::string configName, - std::string p_vectorIDTranslate, - std::string p_headIndexFolder, - std::string p_headConfig, - std::string p_ssdIndex, - bool p_outputEmptyReplicaID, - std::string baseConfig -) { - std::ofstream config(configName); - if (!config.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", configName.c_str()); - BOOST_CHECK(false); - return; - } - - if (!p_headConfig.empty()) - { - if (boost::filesystem::exists(p_headConfig)) - { - fprintf(stdout, "%s was generated. 
Skip generation.\n", p_headConfig.c_str()); - } - else { - std::ofstream bf(p_headConfig); - if (!bf.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", p_headConfig.c_str()); - BOOST_CHECK(false); - return; - } - bf << "[Index]" << std::endl; - bf.close(); - } - } - - config << baseConfig; - - config << "[BuildSSDIndex]" << std::endl; - config << "isExecute=true" << std::endl; - config << "BuildSsdIndex=" << "true" << std::endl; - config << "InternalResultNum=" << 60 << std::endl; - config << "NumberOfThreads=" << 2 << std::endl; - config << "HeadConfig=" << p_headConfig << std::endl; - - config << "ReplicaCount=" << 4 << std::endl; - config << "PostingPageLimit=" << 2 << std::endl; - config << "OutputEmptyReplicaID=" << p_outputEmptyReplicaID << std::endl; - - config.close(); - std::map> my_map; - SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); +void TestBuildSSDIndex(std::string configName, std::string p_vectorIDTranslate, std::string p_headIndexFolder, + std::string p_headConfig, std::string p_ssdIndex, bool p_outputEmptyReplicaID, + std::string baseConfig) +{ + std::ofstream config(configName); + if (!config.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", configName.c_str()); + BOOST_CHECK(false); + return; + } + + if (!p_headConfig.empty()) + { + if (boost::filesystem::exists(p_headConfig)) + { + fprintf(stdout, "%s was generated. 
Skip generation.\n", p_headConfig.c_str()); + } + else + { + std::ofstream bf(p_headConfig); + if (!bf.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", p_headConfig.c_str()); + BOOST_CHECK(false); + return; + } + bf << "[Index]" << std::endl; + bf.close(); + } + } + + config << baseConfig; + + config << "[BuildSSDIndex]" << std::endl; + config << "isExecute=true" << std::endl; + config << "BuildSsdIndex=" + << "true" << std::endl; + config << "InternalResultNum=" << 60 << std::endl; + config << "NumberOfThreads=" << 2 << std::endl; + config << "HeadConfig=" << p_headConfig << std::endl; + + config << "ReplicaCount=" << 4 << std::endl; + config << "PostingPageLimit=" << 2 << std::endl; + config << "OutputEmptyReplicaID=" << p_outputEmptyReplicaID << std::endl; + + config << "Storage=" + << "FILEIO" << std::endl; + config << "Update=" + << "true" << std::endl; + config << "StartFileSizeGB=" << 1 << std::endl; + config << "SsdInfoFile=" << p_ssdIndex + "_info" << std::endl; + config << "SsdMappingFile=" << p_ssdIndex + "_mapping" << std::endl; + + config.close(); + std::map> my_map; + SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); } -void TestSearchSSDIndex( - std::string configName, - std::string p_vectorIDTranslate, - std::string p_headIndexFolder, - std::string p_headConfig, - std::string p_ssdIndex, - std::string p_searchResult, - std::string p_logFile, - std::string baseConfig -) { - std::ofstream config(configName); - if (!config.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", configName.c_str()); - BOOST_CHECK(false); - return; - } - - if (!p_headConfig.empty()) - { - if (boost::filesystem::exists(p_headConfig)) - { - fprintf(stdout, "%s was generated. 
Skip generation.\n", p_headConfig.c_str()); - } - else { - std::ofstream bf(p_headConfig); - if (!bf.is_open()) - { - fprintf(stderr, "%s can't be opened.\n", p_headConfig.c_str()); - BOOST_CHECK(false); - return; - } - bf << "[Index]" << std::endl; - bf.close(); - } - } - - config << baseConfig; - - config << "[SearchSSDIndex]" << std::endl; - config << "isExecute=true" << std::endl; - config << "BuildSsdIndex=" << "false" << std::endl; - config << "InternalResultNum=" << 64 << std::endl; - config << "NumberOfThreads=" << 2 << std::endl; - config << "HeadConfig=" << p_headConfig << std::endl; - - config << "SearchResult=" << p_searchResult << std::endl; - config << "LogFile=" << p_logFile << std::endl; - config << "QpsLimit=" << 0 << std::endl; - config << "ResultNum=" << 32 << std::endl; - config << "MaxCheck=" << 2048 << std::endl; - config << "SearchPostingPageLimit=" << 2 << std::endl; - config << "QueryCountLimit=" << 10000 << std::endl; - - config.close(); - std::map> my_map; - SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); +void TestSearchSSDIndex(std::string configName, std::string p_vectorIDTranslate, std::string p_headIndexFolder, + std::string p_headConfig, std::string p_ssdIndex, std::string p_searchResult, + std::string p_logFile, std::string baseConfig) +{ + std::ofstream config(configName); + if (!config.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", configName.c_str()); + BOOST_CHECK(false); + return; + } + + if (!p_headConfig.empty()) + { + if (boost::filesystem::exists(p_headConfig)) + { + fprintf(stdout, "%s was generated. 
Skip generation.\n", p_headConfig.c_str()); + } + else + { + std::ofstream bf(p_headConfig); + if (!bf.is_open()) + { + fprintf(stderr, "%s can't be opened.\n", p_headConfig.c_str()); + BOOST_CHECK(false); + return; + } + bf << "[Index]" << std::endl; + bf.close(); + } + } + + config << baseConfig; + + config << "[SearchSSDIndex]" << std::endl; + config << "isExecute=true" << std::endl; + config << "BuildSsdIndex=" + << "false" << std::endl; + config << "InternalResultNum=" << 64 << std::endl; + config << "NumberOfThreads=" << 2 << std::endl; + config << "HeadConfig=" << p_headConfig << std::endl; + + config << "SearchResult=" << p_searchResult << std::endl; + config << "LogFile=" << p_logFile << std::endl; + config << "QpsLimit=" << 0 << std::endl; + config << "ResultNum=" << 32 << std::endl; + config << "MaxCheck=" << 2048 << std::endl; + config << "SearchPostingPageLimit=" << 2 << std::endl; + config << "QueryCountLimit=" << 10000 << std::endl; + + config << "Storage=" + << "FILEIO" << std::endl; + config << "Update=" + << "true" << std::endl; + config << "StartFileSizeGB=" << 1 << std::endl; + config << "SsdInfoFile=" << p_ssdIndex + "_info" << std::endl; + config << "SsdMappingFile=" << p_ssdIndex + "_mapping" << std::endl; + + config.close(); + std::map> my_map; + SPTAG::SSDServing::BootProgram(false, &my_map, configName.c_str()); } -void RunFromMap() { - std::map> myMap; - myMap["Base"]["IndexAlgoType"] = "KDT"; - std::string dataFilePath = "sddtest/vectors_Int8_DEFAULT.bin"; - std::string indexFilePath = "zbtest"; - SPTAG::SSDServing::BootProgram( - true, - &myMap, - nullptr, - SPTAG::VectorValueType::Int8, - SPTAG::DistCalcMethod::L2, - dataFilePath.c_str(), - indexFilePath.c_str() - ); - - std::ostringstream config; - config << "[Base]" << std::endl; - config << "ValueType=" << "Int8" << std::endl; - config << "DistCalcMethod=" << "L2" << std::endl; - config << "IndexAlgoType=" << "KDT" << std::endl; - config << "VectorPath=" << 
"sddtest/vectors_Int8_DEFAULT.bin" << std::endl; - config << "QueryPath=" << "sddtest/vectors_Int8_DEFAULT.query" << std::endl; - config << "QueryType=" << "DEFAULT" << std::endl; - config << "TruthPath=" << "sddtest/vectors_Int8_L2_DEFAULT_DEFAULT.truth" << std::endl; - config << "TruthType=" << "DEFAULT" << std::endl; - config << "IndexDirectory=" << "zbtest" << std::endl; - config << "GenerateTruth=" << "true" << std::endl; - config << std::endl; - - TestSearchSSDIndex( - "run_from_map_search_config.ini", - "", - "", - "", - "", - "", - "", - config.str() - ); +void RunFromMap() +{ + std::map> myMap; + myMap["Base"]["IndexAlgoType"] = "KDT"; + myMap["Base"]["Dim"] = "100"; + myMap["BuildSSDIndex"]["Storage"] = "FILEIO"; + myMap["BuildSSDIndex"]["Update"] = "true"; + myMap["BuildSSDIndex"]["StartFileSizeGB"] = "1"; + myMap["BuildSSDIndex"]["SsdInfoFile"] = "_info"; + myMap["BuildSSDIndex"]["SsdMappingFile"] = "_mapping"; + + std::string dataFilePath = "sddtest/vectors_Int8_DEFAULT.bin"; + std::string indexFilePath = "zbtest"; + SPTAG::SSDServing::BootProgram(true, &myMap, nullptr, SPTAG::VectorValueType::Int8, SPTAG::DistCalcMethod::L2, + dataFilePath.c_str(), indexFilePath.c_str()); + + std::ostringstream config; + config << "[Base]" << std::endl; + config << "ValueType=" + << "Int8" << std::endl; + config << "DistCalcMethod=" + << "L2" << std::endl; + config << "IndexAlgoType=" + << "KDT" << std::endl; + config << "Dim=" << 100 << std::endl; + config << "VectorPath=" + << "sddtest/vectors_Int8_DEFAULT.bin" << std::endl; + config << "QueryPath=" + << "sddtest/vectors_Int8_DEFAULT.query" << std::endl; + config << "QueryType=" + << "DEFAULT" << std::endl; + config << "TruthPath=" + << "sddtest/vectors_Int8_L2_DEFAULT_DEFAULT.truth" << std::endl; + config << "TruthType=" + << "DEFAULT" << std::endl; + config << "IndexDirectory=" + << "zbtest" << std::endl; + config << "GenerateTruth=" + << "true" << std::endl; + config << std::endl; + + 
TestSearchSSDIndex("run_from_map_search_config.ini", "", "", "", "", "", "", config.str()); } BOOST_AUTO_TEST_SUITE(SSDServingTest) @@ -380,32 +402,43 @@ BOOST_AUTO_TEST_SUITE(SSDServingTest) #define QUERY_NUM 10 #define VECTOR_DIM 100 +#define FLOAT +#define INT8 +//#define UINT8 +#define INT16 +#define TDEFAULT +#define TXVEC + #define SSDTEST_DIRECTORY_NAME "sddtest" #define SSDTEST_DIRECTORY SSDTEST_DIRECTORY_NAME "/" // #define RAW_VECTORS(VT, FT) "vectors_"#VT"_"#FT".bin" -#define VECTORS(VT, FT) SSDTEST_DIRECTORY "vectors_"#VT"_"#FT".bin" -#define QUERIES(VT, FT) SSDTEST_DIRECTORY "vectors_"#VT"_"#FT".query" -#define TRUTHSET(VT, DM, FT, TFT) SSDTEST_DIRECTORY "vectors_"#VT"_"#DM"_"#FT"_"#TFT".truth" -#define HEAD_IDS(VT, DM, FT) "head_ids_"#VT"_"#DM"_"#FT".bin" -#define HEAD_VECTORS(VT, DM, FT) "head_vectors_"#VT"_"#DM"_"#FT".bin" -#define HEAD_INDEX(VT, DM, ALGO, FT) "head_"#VT"_"#DM"_"#ALGO"_"#FT".head_index" -#define SSD_INDEX(VT, DM, ALGO, FT) "ssd_"#VT"_"#DM"_"#ALGO"_"#FT".ssd_index" - -#define SELECT_HEAD_CONFIG(VT, DM, FT) SSDTEST_DIRECTORY "test_head_"#VT"_"#DM"_"#FT".ini" -#define BUILD_HEAD_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_build_head_"#VT"_"#DM"_"#ALGO".ini" -#define BUILD_HEAD_BUILDER_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_build_head_"#VT"_"#DM"_"#ALGO".builder.ini" -#define BUILD_SSD_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_build_ssd"#VT"_"#DM"_"#ALGO".ini" -#define BUILD_SSD_BUILDER_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_build_ssd"#VT"_"#DM"_"#ALGO".builder.ini" -#define SEARCH_SSD_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_search_ssd_"#VT"_"#DM"_"#ALGO".ini" -#define SEARCH_SSD_BUILDER_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_search_ssd_"#VT"_"#DM"_"#ALGO".builder.ini" -#define SEARCH_SSD_RESULT(VT, DM, ALGO, FT, TFT) SSDTEST_DIRECTORY "test_search_ssd_"#VT"_"#DM"_"#ALGO"_"#FT"_"#TFT".result" - -#define GVQ(VT, FT) \ -BOOST_AUTO_TEST_CASE(GenerateVectorsQueries##VT##FT) { \ 
-boost::filesystem::create_directory(SSDTEST_DIRECTORY_NAME); \ -GenVec(VECTORS(VT, FT), SPTAG::VectorValueType::VT, SPTAG::VectorFileType::FT, VECTOR_NUM, VECTOR_DIM); \ -GenVec(QUERIES(VT, FT), SPTAG::VectorValueType::VT, SPTAG::VectorFileType::FT, QUERY_NUM, VECTOR_DIM); \ -} \ +#define VECTORS(VT, FT) SSDTEST_DIRECTORY "vectors_" #VT "_" #FT ".bin" +#define QUERIES(VT, FT) SSDTEST_DIRECTORY "vectors_" #VT "_" #FT ".query" +#define TRUTHSET(VT, DM, FT, TFT) SSDTEST_DIRECTORY "vectors_" #VT "_" #DM "_" #FT "_" #TFT ".truth" +#define HEAD_IDS(VT, DM, FT) "head_ids_" #VT "_" #DM "_" #FT ".bin" +#define HEAD_VECTORS(VT, DM, FT) "head_vectors_" #VT "_" #DM "_" #FT ".bin" +#define HEAD_INDEX(VT, DM, ALGO, FT) "head_" #VT "_" #DM "_" #ALGO "_" #FT ".head_index" +#define SSD_INDEX(VT, DM, ALGO, FT) "ssd_" #VT "_" #DM "_" #ALGO "_" #FT ".ssd_index" + +#define SELECT_HEAD_CONFIG(VT, DM, FT) SSDTEST_DIRECTORY "test_head_" #VT "_" #DM "_" #FT ".ini" +#define BUILD_HEAD_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_build_head_" #VT "_" #DM "_" #ALGO ".ini" +#define BUILD_HEAD_BUILDER_CONFIG(VT, DM, ALGO) \ + SSDTEST_DIRECTORY "test_build_head_" #VT "_" #DM "_" #ALGO ".builder.ini" +#define BUILD_SSD_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_build_ssd" #VT "_" #DM "_" #ALGO ".ini" +#define BUILD_SSD_BUILDER_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_build_ssd" #VT "_" #DM "_" #ALGO ".builder.ini" +#define SEARCH_SSD_CONFIG(VT, DM, ALGO) SSDTEST_DIRECTORY "test_search_ssd_" #VT "_" #DM "_" #ALGO ".ini" +#define SEARCH_SSD_BUILDER_CONFIG(VT, DM, ALGO) \ + SSDTEST_DIRECTORY "test_search_ssd_" #VT "_" #DM "_" #ALGO ".builder.ini" +#define SEARCH_SSD_RESULT(VT, DM, ALGO, FT, TFT) \ + SSDTEST_DIRECTORY "test_search_ssd_" #VT "_" #DM "_" #ALGO "_" #FT "_" #TFT ".result" + +#define GVQ(VT, FT) \ + BOOST_AUTO_TEST_CASE(GenerateVectorsQueries##VT##FT) \ + { \ + boost::filesystem::create_directory(SSDTEST_DIRECTORY_NAME); \ + GenVec(VECTORS(VT, FT), SPTAG::VectorValueType::VT, 
SPTAG::VectorFileType::FT, VECTOR_NUM, VECTOR_DIM); \ + GenVec(QUERIES(VT, FT), SPTAG::VectorValueType::VT, SPTAG::VectorFileType::FT, QUERY_NUM, VECTOR_DIM); \ + } GVQ(Float, DEFAULT) GVQ(Int16, DEFAULT) @@ -418,251 +451,320 @@ GVQ(UInt8, XVEC) GVQ(Int8, XVEC) #undef GVQ -#define WTEV(VT, DM, FT) \ -BOOST_AUTO_TEST_CASE(TestHead##VT##DM##FT) { \ - std::string configName = SELECT_HEAD_CONFIG(VT, DM, FT); \ - std::string OutputIDFile = HEAD_IDS(VT, DM, FT); \ - std::string OutputVectorFile = HEAD_VECTORS(VT, DM, FT); \ - std::string base_config = CreateBaseConfig(SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::Undefined, VECTOR_DIM, \ - VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", \ - "", SPTAG::VectorFileType::Undefined, -1, "", \ - "", SPTAG::VectorFileType::Undefined, -1, "", \ - "", SPTAG::TruthFileType::Undefined, \ - false, \ - OutputIDFile, \ - OutputVectorFile, \ - "", \ - "" \ - ); \ - TestHead(configName, OutputIDFile, OutputVectorFile, base_config); \ -} \ - +#define WTEV(VT, DM, FT) \ + BOOST_AUTO_TEST_CASE(TestHead##VT##DM##FT) \ + { \ + std::string configName = SELECT_HEAD_CONFIG(VT, DM, FT); \ + std::string OutputIDFile = HEAD_IDS(VT, DM, FT); \ + std::string OutputVectorFile = HEAD_VECTORS(VT, DM, FT); \ + std::string base_config = \ + CreateBaseConfig(SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::Undefined, \ + VECTOR_DIM, VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", "", \ + SPTAG::VectorFileType::Undefined, -1, "", "", SPTAG::VectorFileType::Undefined, -1, "", \ + "", SPTAG::TruthFileType::Undefined, false, OutputIDFile, OutputVectorFile, "", ""); \ + TestHead(configName, OutputIDFile, OutputVectorFile, base_config); \ + } + +#ifdef FLOAT +#ifdef TDEFAULT WTEV(Float, L2, DEFAULT) WTEV(Float, Cosine, DEFAULT) -WTEV(Int16, L2, DEFAULT) -WTEV(Int16, Cosine, DEFAULT) -WTEV(UInt8, L2, DEFAULT) -WTEV(UInt8, Cosine, DEFAULT) -WTEV(Int8, L2, DEFAULT) -WTEV(Int8, 
Cosine, DEFAULT) +#endif +#ifdef TXVEC WTEV(Float, L2, XVEC) WTEV(Float, Cosine, XVEC) +#endif +#endif + +#ifdef INT16 +#ifdef TDEFAULT +WTEV(Int16, L2, DEFAULT) +WTEV(Int16, Cosine, DEFAULT) +#endif + +#ifdef TXVEC WTEV(Int16, L2, XVEC) WTEV(Int16, Cosine, XVEC) +#endif +#endif + +#ifdef UINT8 +#ifdef TDEFAULT +WTEV(UInt8, L2, DEFAULT) +WTEV(UInt8, Cosine, DEFAULT) +#endif + +#ifdef TXVEC WTEV(UInt8, L2, XVEC) WTEV(UInt8, Cosine, XVEC) +#endif +#endif + +#ifdef INT8 +#ifdef TDEFAULT +WTEV(Int8, L2, DEFAULT) +WTEV(Int8, Cosine, DEFAULT) +#endif + +#ifdef TXVEC WTEV(Int8, L2, XVEC) WTEV(Int8, Cosine, XVEC) +#endif +#endif #undef WTEV -#define BDHD(VT, DM, ALGO, FT) \ -BOOST_AUTO_TEST_CASE(TestBuildHead##VT##DM##ALGO##FT) { \ - std::string configName = BUILD_HEAD_CONFIG(VT, DM, ALGO); \ - std::string builderFile = BUILD_HEAD_BUILDER_CONFIG(VT, DM, ALGO); \ - std::string base_config = CreateBaseConfig(SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::ALGO, VECTOR_DIM, \ - VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", \ - "", SPTAG::VectorFileType::Undefined, -1, "", \ - "", SPTAG::VectorFileType::Undefined, -1, "", \ - "", SPTAG::TruthFileType::Undefined, \ - false, \ - "", \ - HEAD_VECTORS(VT, DM, FT), \ - HEAD_INDEX(VT, DM, ALGO, FT), \ - "" \ - ); \ -TestBuildHead( \ - configName, \ - HEAD_VECTORS(VT, DM, FT), \ - HEAD_INDEX(VT, DM, ALGO, FT), \ - SPTAG::IndexAlgoType::ALGO, \ - builderFile, \ - base_config \ -); \ -} \ - +#define BDHD(VT, DM, ALGO, FT) \ + BOOST_AUTO_TEST_CASE(TestBuildHead##VT##DM##ALGO##FT) \ + { \ + std::string configName = BUILD_HEAD_CONFIG(VT, DM, ALGO); \ + std::string builderFile = BUILD_HEAD_BUILDER_CONFIG(VT, DM, ALGO); \ + std::string base_config = CreateBaseConfig( \ + SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::ALGO, VECTOR_DIM, \ + VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", "", SPTAG::VectorFileType::Undefined, -1, "", \ + "", 
SPTAG::VectorFileType::Undefined, -1, "", "", SPTAG::TruthFileType::Undefined, false, "", \ + HEAD_VECTORS(VT, DM, FT), HEAD_INDEX(VT, DM, ALGO, FT), ""); \ + TestBuildHead(configName, HEAD_VECTORS(VT, DM, FT), HEAD_INDEX(VT, DM, ALGO, FT), SPTAG::IndexAlgoType::ALGO, \ + builderFile, base_config); \ + } + +#ifdef FLOAT +#ifdef TDEFAULT BDHD(Float, L2, BKT, DEFAULT) BDHD(Float, L2, KDT, DEFAULT) BDHD(Float, Cosine, BKT, DEFAULT) BDHD(Float, Cosine, KDT, DEFAULT) +#endif -BDHD(Int8, L2, BKT, DEFAULT) -BDHD(Int8, L2, KDT, DEFAULT) -BDHD(Int8, Cosine, BKT, DEFAULT) -BDHD(Int8, Cosine, KDT, DEFAULT) - -BDHD(UInt8, L2, BKT, DEFAULT) -BDHD(UInt8, L2, KDT, DEFAULT) -BDHD(UInt8, Cosine, BKT, DEFAULT) -BDHD(UInt8, Cosine, KDT, DEFAULT) - -BDHD(Int16, L2, BKT, DEFAULT) -BDHD(Int16, L2, KDT, DEFAULT) -BDHD(Int16, Cosine, BKT, DEFAULT) -BDHD(Int16, Cosine, KDT, DEFAULT) - -//XVEC +#ifdef TXVEC BDHD(Float, L2, BKT, XVEC) BDHD(Float, L2, KDT, XVEC) BDHD(Float, Cosine, BKT, XVEC) BDHD(Float, Cosine, KDT, XVEC) +#endif +#endif + +#ifdef INT8 +#ifdef TDEFAULT +BDHD(Int8, L2, BKT, DEFAULT) +BDHD(Int8, L2, KDT, DEFAULT) +BDHD(Int8, Cosine, BKT, DEFAULT) +BDHD(Int8, Cosine, KDT, DEFAULT) +#endif +#ifdef TXVEC BDHD(Int8, L2, BKT, XVEC) BDHD(Int8, L2, KDT, XVEC) BDHD(Int8, Cosine, BKT, XVEC) BDHD(Int8, Cosine, KDT, XVEC) +#endif +#endif + +#ifdef UINT8 +#ifdef TDEFAULT +BDHD(UInt8, L2, BKT, DEFAULT) +BDHD(UInt8, L2, KDT, DEFAULT) +BDHD(UInt8, Cosine, BKT, DEFAULT) +BDHD(UInt8, Cosine, KDT, DEFAULT) +#endif +#ifdef TXVEC BDHD(UInt8, L2, BKT, XVEC) BDHD(UInt8, L2, KDT, XVEC) BDHD(UInt8, Cosine, BKT, XVEC) BDHD(UInt8, Cosine, KDT, XVEC) +#endif +#endif + +#ifdef INT16 +#ifdef TDEFAULT +BDHD(Int16, L2, BKT, DEFAULT) +BDHD(Int16, L2, KDT, DEFAULT) +BDHD(Int16, Cosine, BKT, DEFAULT) +BDHD(Int16, Cosine, KDT, DEFAULT) +#endif +#ifdef TXVEC BDHD(Int16, L2, BKT, XVEC) BDHD(Int16, L2, KDT, XVEC) BDHD(Int16, Cosine, BKT, XVEC) BDHD(Int16, Cosine, KDT, XVEC) +#endif +#endif #undef BDHD -#define 
BDSSD(VT, DM, ALGO, FT) \ -BOOST_AUTO_TEST_CASE(TestBuildSSDIndex##VT##DM##ALGO##FT) { \ - std::string configName = BUILD_SSD_CONFIG(VT, DM, ALGO); \ - std::string base_config = CreateBaseConfig(SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::ALGO, VECTOR_DIM, \ - VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", \ - "", SPTAG::VectorFileType::Undefined, -1, "", \ - "", SPTAG::VectorFileType::Undefined, -1, "", \ - "", SPTAG::TruthFileType::Undefined, \ - false, \ - HEAD_IDS(VT, DM, FT), \ - "", \ - HEAD_INDEX(VT, DM, ALGO, FT), \ - SSD_INDEX(VT, DM, ALGO, FT) \ - ); \ -TestBuildSSDIndex(\ - configName, \ - HEAD_IDS(VT, DM, FT), \ - HEAD_INDEX(VT, DM, ALGO, FT), \ - BUILD_SSD_BUILDER_CONFIG(VT, DM, ALGO), \ - SSD_INDEX(VT, DM, ALGO, FT), \ - true, \ - base_config \ -);} \ - -// DEFAULT +#define BDSSD(VT, DM, ALGO, FT) \ + BOOST_AUTO_TEST_CASE(TestBuildSSDIndex##VT##DM##ALGO##FT) \ + { \ + std::string configName = BUILD_SSD_CONFIG(VT, DM, ALGO); \ + std::string base_config = CreateBaseConfig( \ + SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::ALGO, VECTOR_DIM, \ + VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", "", SPTAG::VectorFileType::Undefined, -1, "", \ + "", SPTAG::VectorFileType::Undefined, -1, "", "", SPTAG::TruthFileType::Undefined, false, \ + HEAD_IDS(VT, DM, FT), "", HEAD_INDEX(VT, DM, ALGO, FT), SSD_INDEX(VT, DM, ALGO, FT)); \ + TestBuildSSDIndex(configName, HEAD_IDS(VT, DM, FT), HEAD_INDEX(VT, DM, ALGO, FT), \ + BUILD_SSD_BUILDER_CONFIG(VT, DM, ALGO), SSD_INDEX(VT, DM, ALGO, FT), true, base_config); \ + } + +#ifdef FLOAT +#ifdef TDEFAULT BDSSD(Float, L2, BKT, DEFAULT) BDSSD(Float, L2, KDT, DEFAULT) BDSSD(Float, Cosine, BKT, DEFAULT) BDSSD(Float, Cosine, KDT, DEFAULT) +#endif -BDSSD(Int8, L2, BKT, DEFAULT) -BDSSD(Int8, L2, KDT, DEFAULT) -BDSSD(Int8, Cosine, BKT, DEFAULT) -BDSSD(Int8, Cosine, KDT, DEFAULT) - -BDSSD(UInt8, L2, BKT, DEFAULT) -BDSSD(UInt8, L2, KDT, DEFAULT) 
-BDSSD(UInt8, Cosine, BKT, DEFAULT) -BDSSD(UInt8, Cosine, KDT, DEFAULT) - -BDSSD(Int16, L2, BKT, DEFAULT) -BDSSD(Int16, L2, KDT, DEFAULT) -BDSSD(Int16, Cosine, BKT, DEFAULT) -BDSSD(Int16, Cosine, KDT, DEFAULT) - -// XVEC +#ifdef TXVEC BDSSD(Float, L2, BKT, XVEC) BDSSD(Float, L2, KDT, XVEC) BDSSD(Float, Cosine, BKT, XVEC) BDSSD(Float, Cosine, KDT, XVEC) +#endif +#endif + +#ifdef INT8 +#ifdef TDEFAULT +BDSSD(Int8, L2, BKT, DEFAULT) +BDSSD(Int8, L2, KDT, DEFAULT) +BDSSD(Int8, Cosine, BKT, DEFAULT) +BDSSD(Int8, Cosine, KDT, DEFAULT) +#endif +#ifdef TXVEC BDSSD(Int8, L2, BKT, XVEC) BDSSD(Int8, L2, KDT, XVEC) BDSSD(Int8, Cosine, BKT, XVEC) BDSSD(Int8, Cosine, KDT, XVEC) +#endif +#endif + +#ifdef UINT8 +#ifdef TDEFAULT +BDSSD(UInt8, L2, BKT, DEFAULT) +BDSSD(UInt8, L2, KDT, DEFAULT) +BDSSD(UInt8, Cosine, BKT, DEFAULT) +BDSSD(UInt8, Cosine, KDT, DEFAULT) +#endif +#ifdef TXVEC BDSSD(UInt8, L2, BKT, XVEC) BDSSD(UInt8, L2, KDT, XVEC) BDSSD(UInt8, Cosine, BKT, XVEC) BDSSD(UInt8, Cosine, KDT, XVEC) +#endif +#endif + +#ifdef INT16 +#ifdef TDEFAULT +BDSSD(Int16, L2, BKT, DEFAULT) +BDSSD(Int16, L2, KDT, DEFAULT) +BDSSD(Int16, Cosine, BKT, DEFAULT) +BDSSD(Int16, Cosine, KDT, DEFAULT) +#endif +#ifdef TXVEC BDSSD(Int16, L2, BKT, XVEC) BDSSD(Int16, L2, KDT, XVEC) BDSSD(Int16, Cosine, BKT, XVEC) BDSSD(Int16, Cosine, KDT, XVEC) -#undef BDSSD - +#endif +#endif -#define SCSSD(VT, DM, ALGO, FT, TFT) \ -BOOST_AUTO_TEST_CASE(TestSearchSSDIndex##VT##DM##ALGO##FT##TFT) { \ - std::string configName = SEARCH_SSD_CONFIG(VT, DM, ALGO); \ - std::string base_config = CreateBaseConfig(SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::ALGO, VECTOR_DIM, \ - VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", \ - QUERIES(VT, FT), SPTAG::VectorFileType::FT, QUERY_NUM, "", \ - QUERIES(VT, FT), SPTAG::VectorFileType::FT, QUERY_NUM, "", \ - TRUTHSET(VT, DM, FT, TFT), SPTAG::TruthFileType::TFT, \ - true, \ - HEAD_IDS(VT, DM, FT), \ - "", \ - HEAD_INDEX(VT, DM, ALGO, FT), \ - 
SSD_INDEX(VT, DM, ALGO, FT) \ - ); \ -TestSearchSSDIndex( \ - configName, \ - HEAD_IDS(VT, DM, FT), \ - HEAD_INDEX(VT, DM, ALGO, FT), \ - SEARCH_SSD_BUILDER_CONFIG(VT, DM, ALGO), \ - SSD_INDEX(VT, DM, ALGO, FT), \ - SEARCH_SSD_RESULT(VT, DM, ALGO, FT, TFT), \ - "", \ - base_config \ -);} \ +#undef BDSSD +#define SCSSD(VT, DM, ALGO, FT, TFT) \ + BOOST_AUTO_TEST_CASE(TestSearchSSDIndex##VT##DM##ALGO##FT##TFT) \ + { \ + std::string configName = SEARCH_SSD_CONFIG(VT, DM, ALGO); \ + std::string base_config = \ + CreateBaseConfig(SPTAG::VectorValueType::VT, SPTAG::DistCalcMethod::DM, SPTAG::IndexAlgoType::ALGO, \ + VECTOR_DIM, VECTORS(VT, FT), SPTAG::VectorFileType::FT, VECTOR_NUM, "", QUERIES(VT, FT), \ + SPTAG::VectorFileType::FT, QUERY_NUM, "", QUERIES(VT, FT), SPTAG::VectorFileType::FT, \ + QUERY_NUM, "", TRUTHSET(VT, DM, FT, TFT), SPTAG::TruthFileType::TFT, true, \ + HEAD_IDS(VT, DM, FT), "", HEAD_INDEX(VT, DM, ALGO, FT), SSD_INDEX(VT, DM, ALGO, FT)); \ + TestSearchSSDIndex(configName, HEAD_IDS(VT, DM, FT), HEAD_INDEX(VT, DM, ALGO, FT), \ + SEARCH_SSD_BUILDER_CONFIG(VT, DM, ALGO), SSD_INDEX(VT, DM, ALGO, FT), \ + SEARCH_SSD_RESULT(VT, DM, ALGO, FT, TFT), "", base_config); \ + } + +#ifdef FLOAT +#ifdef TDEFAULT SCSSD(Float, L2, BKT, DEFAULT, TXT) SCSSD(Float, L2, KDT, DEFAULT, TXT) SCSSD(Float, Cosine, BKT, DEFAULT, TXT) SCSSD(Float, Cosine, KDT, DEFAULT, TXT) +#endif -SCSSD(Int8, L2, BKT, DEFAULT, TXT) -SCSSD(Int8, L2, KDT, DEFAULT, TXT) -SCSSD(Int8, Cosine, BKT, DEFAULT, TXT) -SCSSD(Int8, Cosine, KDT, DEFAULT, TXT) - -SCSSD(UInt8, L2, BKT, DEFAULT, TXT) -SCSSD(UInt8, L2, KDT, DEFAULT, TXT) -SCSSD(UInt8, Cosine, BKT, DEFAULT, TXT) -SCSSD(UInt8, Cosine, KDT, DEFAULT, TXT) - -SCSSD(Int16, L2, BKT, DEFAULT, TXT) -SCSSD(Int16, L2, KDT, DEFAULT, TXT) -SCSSD(Int16, Cosine, BKT, DEFAULT, TXT) -SCSSD(Int16, Cosine, KDT, DEFAULT, TXT) - - -//Another +#ifdef TXVEC SCSSD(Float, L2, BKT, XVEC, XVEC) SCSSD(Float, L2, KDT, XVEC, XVEC) SCSSD(Float, Cosine, BKT, XVEC, XVEC) 
SCSSD(Float, Cosine, KDT, XVEC, XVEC) +#endif +#endif + +#ifdef INT8 +#ifdef TDEFAULT +SCSSD(Int8, L2, BKT, DEFAULT, TXT) +SCSSD(Int8, L2, KDT, DEFAULT, TXT) +SCSSD(Int8, Cosine, BKT, DEFAULT, TXT) +SCSSD(Int8, Cosine, KDT, DEFAULT, TXT) +#endif +#ifdef TXVEC SCSSD(Int8, L2, BKT, XVEC, XVEC) SCSSD(Int8, L2, KDT, XVEC, XVEC) SCSSD(Int8, Cosine, BKT, XVEC, XVEC) SCSSD(Int8, Cosine, KDT, XVEC, XVEC) +#endif +#endif + +#ifdef UINT8 +#ifdef TDEFAULT +SCSSD(UInt8, L2, BKT, DEFAULT, TXT) +SCSSD(UInt8, L2, KDT, DEFAULT, TXT) +SCSSD(UInt8, Cosine, BKT, DEFAULT, TXT) +SCSSD(UInt8, Cosine, KDT, DEFAULT, TXT) +#endif +#ifdef TXVEC SCSSD(UInt8, L2, BKT, XVEC, XVEC) SCSSD(UInt8, L2, KDT, XVEC, XVEC) SCSSD(UInt8, Cosine, BKT, XVEC, XVEC) SCSSD(UInt8, Cosine, KDT, XVEC, XVEC) +#endif +#endif + +#ifdef INT16 +#ifdef TDEFAULT +SCSSD(Int16, L2, BKT, DEFAULT, TXT) +SCSSD(Int16, L2, KDT, DEFAULT, TXT) +SCSSD(Int16, Cosine, BKT, DEFAULT, TXT) +SCSSD(Int16, Cosine, KDT, DEFAULT, TXT) +#endif +#ifdef TXVEC SCSSD(Int16, L2, BKT, XVEC, XVEC) SCSSD(Int16, L2, KDT, XVEC, XVEC) SCSSD(Int16, Cosine, BKT, XVEC, XVEC) SCSSD(Int16, Cosine, KDT, XVEC, XVEC) +#endif +#endif + #undef SCSSD -BOOST_AUTO_TEST_CASE(RUN_FROM_MAP) { - RunFromMap(); +#undef FLOAT +#undef INT8 +//#undef UINT8 +#undef INT16 +#undef TDEFAULT +#undef TXVEC + +BOOST_AUTO_TEST_CASE(RUN_FROM_MAP) +{ + RunFromMap(); } BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/Test/src/StringConvertTest.cpp b/Test/src/StringConvertTest.cpp index fa457debe..568eee00d 100644 --- a/Test/src/StringConvertTest.cpp +++ b/Test/src/StringConvertTest.cpp @@ -1,33 +1,32 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "inc/Test.h" #include "inc/Helper/StringConvert.h" +#include "inc/Test.h" namespace { - namespace Local - { - - template - void TestConvertSuccCase(ValueType p_val, const char* p_valStr) - { - using namespace SPTAG::Helper::Convert; - - std::string str = ConvertToString(p_val); - if (nullptr != p_valStr) - { - BOOST_CHECK(str == p_valStr); - } +namespace Local +{ - ValueType val; - BOOST_CHECK(ConvertStringTo(str.c_str(), val)); - BOOST_CHECK(val == p_val); - } +template void TestConvertSuccCase(ValueType p_val, const char *p_valStr) +{ + using namespace SPTAG::Helper::Convert; + std::string str = ConvertToString(p_val); + if (nullptr != p_valStr) + { + BOOST_CHECK(str == p_valStr); } + + ValueType val; + BOOST_CHECK(ConvertStringTo(str.c_str(), val)); + BOOST_CHECK(val == p_val); } +} // namespace Local +} // namespace + BOOST_AUTO_TEST_SUITE(StringConvertTest) BOOST_AUTO_TEST_CASE(ConvertInt8) diff --git a/Test/src/TestDataGenerator.cpp b/Test/src/TestDataGenerator.cpp new file mode 100644 index 000000000..6e94425ac --- /dev/null +++ b/Test/src/TestDataGenerator.cpp @@ -0,0 +1,303 @@ +#include "inc/TestDataGenerator.h" +#include +#include +#include + +using namespace SPTAG; + +namespace TestUtils +{ + +template +TestDataGenerator::TestDataGenerator(int n, int q, int m, int k, std::string distMethod) + : m_n(n), m_q(q), m_m(m), m_k(k), m_distMethod(std::move(distMethod)) +{ +} + +template +void TestDataGenerator::Run(std::shared_ptr &vecset, std::shared_ptr &metaset, + std::shared_ptr &queryset, std::shared_ptr &truth, + std::shared_ptr &addvecset, std::shared_ptr &addmetaset, + std::shared_ptr &addtruth) +{ + LoadOrGenerateBase(vecset, metaset); + LoadOrGenerateQuery(queryset); + LoadOrGenerateAdd(addvecset, addmetaset); + LoadOrGenerateTruth("perftest_truth." + m_distMethod, vecset, queryset, truth, true); + LoadOrGenerateTruth("perftest_addtruth." 
+ m_distMethod, CombineVectorSets(vecset, addvecset), queryset, addtruth, + true); +} +template +void TestDataGenerator::RunBatches(std::shared_ptr& vecset, + std::shared_ptr& metaset, + std::shared_ptr& addvecset, std::shared_ptr& addmetaset, + std::shared_ptr& queryset, int base, int batchinsert, int batchdelete, int batches, + std::shared_ptr& truths) +{ + LoadOrGenerateBase(vecset, metaset); + LoadOrGenerateQuery(queryset); + LoadOrGenerateAdd(addvecset, addmetaset); + LoadOrGenerateBatchTruth("perftest_batchtruth." + m_distMethod, CombineVectorSets(vecset, addvecset), queryset, + truths, base, batchinsert, batchdelete, batches, true); +} + +template +void TestDataGenerator::LoadOrGenerateBase(std::shared_ptr &vecset, std::shared_ptr &metaset) +{ + if (fileexists("perftest_vector.bin") && fileexists("perftest_meta.bin") && fileexists("perftest_metaidx.bin")) + { + auto reader = LoadReader("perftest_vector.bin"); + vecset = reader->GetVectorSet(); + metaset.reset( + new MemMetadataSet("perftest_meta.bin", "perftest_metaidx.bin", vecset->Count() * 2, MaxSize, 10)); + } + else + { + vecset = GenerateRandomVectorSet(m_n, m_m); + vecset->Save("perftest_vector.bin"); + + metaset = GenerateMetadataSet(m_n, 0); + metaset->SaveMetadata("perftest_meta.bin", "perftest_metaidx.bin"); + } +} + +template void TestDataGenerator::LoadOrGenerateQuery(std::shared_ptr &queryset) +{ + if (fileexists("perftest_query.bin")) + { + auto reader = LoadReader("perftest_query.bin"); + queryset = reader->GetVectorSet(); + } + else + { + queryset = GenerateRandomVectorSet(m_q, m_m); + queryset->Save("perftest_query.bin"); + } +} + +template +void TestDataGenerator::LoadOrGenerateAdd(std::shared_ptr &addvecset, + std::shared_ptr &addmetaset) +{ + if (fileexists("perftest_addvector.bin") && fileexists("perftest_addmeta.bin") && + fileexists("perftest_addmetaidx.bin")) + { + auto reader = LoadReader("perftest_addvector.bin"); + addvecset = reader->GetVectorSet(); + addmetaset.reset( + new 
MemMetadataSet("perftest_addmeta.bin", "perftest_addmetaidx.bin", addvecset->Count() * 2, MaxSize, 10)); + } + else + { + addvecset = GenerateRandomVectorSet(m_n, m_m); + addvecset->Save("perftest_addvector.bin"); + + addmetaset = GenerateMetadataSet(m_n, m_n); + addmetaset->SaveMetadata("perftest_addmeta.bin", "perftest_addmetaidx.bin"); + } +} + +template +void TestDataGenerator::LoadOrGenerateTruth(const std::string &filename, std::shared_ptr vecset, + std::shared_ptr queryset, std::shared_ptr &truth, + bool normalize) +{ + if (fileexists(filename.c_str())) + { + auto opts = std::make_shared(GetEnumValueType(), m_m, VectorFileType::DEFAULT); + auto reader = Helper::VectorSetReader::CreateInstance(opts); + if (ErrorCode::Success != reader->LoadFile(filename)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s\n", filename.c_str()); + exit(1); + } + truth = reader->GetVectorSet(); + return; + } + + DistCalcMethod distMethod; + Helper::Convert::ConvertStringTo(m_distMethod.c_str(), distMethod); + if (normalize && distMethod == DistCalcMethod::Cosine) + { + COMMON::Utils::BatchNormalize((T *)vecset->GetData(), vecset->Count(), vecset->Dimension(), + COMMON::Utils::GetBase(), 5); + } + + ByteArray tru = ByteArray::Alloc(sizeof(float) * queryset->Count() * m_k); + + std::vector mythreads; + int maxthreads = std::thread::hardware_concurrency(); + mythreads.reserve(maxthreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < maxthreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < queryset->Count()) + { + SizeType *neighbors = ((SizeType *)tru.Data()) + i * m_k; + COMMON::QueryResultSet res((const T *)queryset->GetVector(i), m_k); + for (SizeType j = 0; j < vecset->Count(); ++j) + { + float dist = COMMON::DistanceUtils::ComputeDistance( + res.GetTarget(), reinterpret_cast(vecset->GetVector(j)), m_m, distMethod); + res.AddPoint(j, dist); + } + res.SortResult(); + for 
(int j = 0; j < m_k; ++j) + neighbors[j] = res.GetResult(j)->VID; + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + + truth = std::make_shared(tru, GetEnumValueType(), m_k, queryset->Count()); + truth->Save(filename); +} + +template +void TestDataGenerator::LoadOrGenerateBatchTruth(const std::string &filename, + std::shared_ptr vecset, + std::shared_ptr queryset, + std::shared_ptr &truths, int base, + int batchinsert, int batchdelete, int batches, bool normalize) +{ + if (fileexists(filename.c_str())) + { + auto opts = std::make_shared(GetEnumValueType(), m_m, VectorFileType::DEFAULT); + auto reader = Helper::VectorSetReader::CreateInstance(opts); + if (ErrorCode::Success != reader->LoadFile(filename)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s\n", filename.c_str()); + exit(1); + } + truths = reader->GetVectorSet(); + return; + } + + DistCalcMethod distMethod; + Helper::Convert::ConvertStringTo(m_distMethod.c_str(), distMethod); + if (normalize && distMethod == DistCalcMethod::Cosine) + { + COMMON::Utils::BatchNormalize((T *)vecset->GetData(), vecset->Count(), vecset->Dimension(), + COMMON::Utils::GetBase(), 5); + } + + ByteArray tru = ByteArray::Alloc(sizeof(float) * (batches + 1) * queryset->Count() * m_k); + int start = 0; + int end = base; + int maxthreads = std::thread::hardware_concurrency(); + for (int iter = 0; iter < batches + 1; iter++) + { + std::vector mythreads; + mythreads.reserve(maxthreads); + std::atomic_size_t sent(0); + for (int tid = 0; tid < maxthreads; tid++) + { + mythreads.emplace_back([&, tid]() { + size_t i = 0; + while (true) + { + i = sent.fetch_add(1); + if (i < queryset->Count()) + { + SizeType *neighbors = ((SizeType *)tru.Data()) + iter * (queryset->Count() * m_k) + i * m_k; + COMMON::QueryResultSet res((const T *)queryset->GetVector(i), m_k); + for (SizeType j = start; j < end; ++j) + { + float dist = COMMON::DistanceUtils::ComputeDistance( + 
res.GetTarget(), reinterpret_cast(vecset->GetVector(j)), m_m, distMethod); + res.AddPoint(j, dist); + } + res.SortResult(); + for (int j = 0; j < m_k; ++j) + neighbors[j] = res.GetResult(j)->VID; + } + else + { + return; + } + } + }); + } + for (auto &t : mythreads) + { + t.join(); + } + mythreads.clear(); + start += batchdelete; + end += batchinsert; + } + truths = std::make_shared(tru, GetEnumValueType(), m_k, (batches + 1) * queryset->Count()); + truths->Save(filename); +} + +template +std::shared_ptr TestDataGenerator::GenerateRandomVectorSet(SizeType count, DimensionType dim) +{ + ByteArray vec = ByteArray::Alloc(sizeof(T) * count * dim); + for (SizeType i = 0; i < count * dim; ++i) + { + ((T *)vec.Data())[i] = (T)COMMON::Utils::rand(127, -127); + } + return std::make_shared(vec, GetEnumValueType(), dim, count); +} + +template +std::shared_ptr TestDataGenerator::GenerateMetadataSet(SizeType count, SizeType offsetBase) +{ + ByteArray meta = ByteArray::Alloc(count * 6); + ByteArray metaoffset = ByteArray::Alloc((count + 1) * sizeof(std::uint64_t)); + std::uint64_t offset = 0; + for (SizeType i = 0; i < count; i++) + { + ((std::uint64_t *)metaoffset.Data())[i] = offset; + std::string id = std::to_string(i + offsetBase); + std::memcpy(meta.Data() + offset, id.c_str(), id.length()); + offset += id.length(); + } + ((std::uint64_t *)metaoffset.Data())[count] = offset; + return std::make_shared(meta, metaoffset, count, count * 2, MaxSize, 10); +} + +template +std::shared_ptr TestDataGenerator::CombineVectorSets(std::shared_ptr base, + std::shared_ptr add) +{ + ByteArray vec = ByteArray::Alloc(sizeof(T) * (base->Count() + add->Count()) * m_m); + memcpy(vec.Data(), base->GetData(), sizeof(T) * (base->Count()) * m_m); + memcpy(vec.Data() + sizeof(T) * (base->Count()) * m_m, add->GetData(), sizeof(T) * (add->Count()) * m_m); + return std::make_shared(vec, GetEnumValueType(), m_m, base->Count() + add->Count()); +} + +template +std::shared_ptr 
TestDataGenerator::LoadReader(const std::string &filename) +{ + auto opts = std::make_shared(GetEnumValueType(), m_m, VectorFileType::DEFAULT); + auto reader = Helper::VectorSetReader::CreateInstance(opts); + if (ErrorCode::Success != reader->LoadFile(filename)) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read file %s\n", filename.c_str()); + exit(1); + } + return reader; +} + +// Explicit instantiation +template class TestDataGenerator; +template class TestDataGenerator; +template class TestDataGenerator; +template class TestDataGenerator; +} // namespace TestUtils diff --git a/Test/src/main.cpp b/Test/src/main.cpp index 7bf61ea11..f13500afe 100644 --- a/Test/src/main.cpp +++ b/Test/src/main.cpp @@ -5,21 +5,21 @@ #define BOOST_TEST_MODULE Main #include "inc/Test.h" -#include #include +#include using namespace boost::unit_test; class SPTAGVisitor : public test_tree_visitor { -public: - void visit(test_case const& test) + public: + void visit(test_case const &test) { std::string prefix(2, '\t'); std::cout << prefix << "Case: " << test.p_name << std::endl; } - bool test_suite_start(test_suite const& suite) + bool test_suite_start(test_suite const &suite) { std::string prefix(1, '\t'); std::cout << prefix << "Suite: " << suite.p_name << std::endl; @@ -34,8 +34,6 @@ struct GlobalFixture SPTAGVisitor visitor; traverse_test_tree(framework::master_test_suite(), visitor, false); } - }; BOOST_TEST_GLOBAL_FIXTURE(GlobalFixture); - diff --git a/ThirdParty/RocksDB b/ThirdParty/RocksDB new file mode 160000 index 000000000..275cd80cd --- /dev/null +++ b/ThirdParty/RocksDB @@ -0,0 +1 @@ +Subproject commit 275cd80cdb2de8e53c7ab805d74394309372005e diff --git a/ThirdParty/isal-l_crypto b/ThirdParty/isal-l_crypto new file mode 160000 index 000000000..08297dc3e --- /dev/null +++ b/ThirdParty/isal-l_crypto @@ -0,0 +1 @@ +Subproject commit 08297dc3e76d65e1bad83a9c9f9e49059cf806b5 diff --git a/ThirdParty/spdk b/ThirdParty/spdk new file mode 160000 index 000000000..10edc60aa 
--- /dev/null +++ b/ThirdParty/spdk @@ -0,0 +1 @@ +Subproject commit 10edc60aa8b5f1b04d6496fea976dec75e276a95 diff --git a/Tools/OPQ/OPQ_gpu_train_infer.py b/Tools/OPQ/OPQ_gpu_train_infer.py new file mode 100644 index 000000000..082b272d3 --- /dev/null +++ b/Tools/OPQ/OPQ_gpu_train_infer.py @@ -0,0 +1,661 @@ +import numpy as np +from struct import pack, unpack, calcsize +from struct import pack, unpack, calcsize +import heapq +import argparse +import copy +from operator import itemgetter +import os +import subprocess + +def get_config(): + parser = argparse.ArgumentParser(description ='implementation of nnsearch.') + parser.add_argument('--data_file', default = 'traindata', type = str, help = 'binary data file') + parser.add_argument('--query_file', default = 'query.tsv', type= str, help='query tsv file') + parser.add_argument('--data_normalize', default = 0, type = int, help='normalize data vectors') + parser.add_argument('--query_normalize', default = 0, type = int, help='normalize query vectors') + parser.add_argument('--data_type', default = 'float32', type = str, help = 'data type for binary file: float32, int8, int16') + parser.add_argument('--target_type', default = 'float32', type = str, help = 'GPU data type') + parser.add_argument('--k', type= int, default = 32, help='knn') + parser.add_argument('--dim', type= int, default = 100, help='vector dimensions') + parser.add_argument('--B', type= int, default = -1, help='batch data size') + parser.add_argument('--Q', type= int, default = 10000, help='batch query size') + parser.add_argument('--S', type= int, default = 1000, help='data split size') + parser.add_argument('--D', type= str, default = "L2", help='distance type') + parser.add_argument('--output_truth', type = str, default = "truth.txt", help='output truth file') + parser.add_argument('--data_format', type = str, default = "DEFAULT", help='data format') + parser.add_argument('--task', type = int, default = 0, help='task id') + 
parser.add_argument('-log_dir', type = str, default = "", help='debug log dir in cosmos') + + parser.add_argument('--T', type = int, default = 32, help="thread number") + parser.add_argument('--train_samples', type = int, default = 1000000, help='OPQ, PQ training samples') + parser.add_argument('--quan_type', type = str, default = 'none', help='quantizer type') + parser.add_argument('--quan_dim', type = int, default = -1, help='quantized vector dimensions') + parser.add_argument('--output_dir', type = str, default = 'quan_tmp', help='output dir') + parser.add_argument('--output_quantizer', type = str, default = "quantizer.bin", help='output quantizer file') + parser.add_argument('--output_quan_vector_file', type = str, default = "", help='quantized vectors') + parser.add_argument('--output_rec_vector_file', type = str, default = "", help = "reconstruct vectors") + parser.add_argument('--quan_test', type = int, default = 0, help='compare with ground truth') + args = parser.parse_args() + return args + +class DataReader: + def __init__(self, filename, featuredim, batchsize, normalize, datatype, targettype='float32'): + self.mytype = targettype + if filename.find('.bin') >= 0: + self.fin = open(filename, 'rb') + R = unpack('i', self.fin.read(4))[0] + self.featuredim = unpack('i', self.fin.read(4))[0] + self.isbinary = True + self.type = datatype + print ('Open Binary DataReader for data(%d,%d)...' 
% (R, self.featuredim)) + else: + with open(filename) as f: + R = sum(1 for _ in f) + self.fin = open(filename, 'r') + self.featuredim = featuredim + self.isbinary = False + self.type = self.mytype + + if batchsize <= 0: batchsize = R + self.query = np.zeros([batchsize, self.featuredim], dtype=self.mytype) + self.normalize = normalize + + def norm(self, data): + square = np.sqrt(np.sum(np.square(data), axis=1)) + data[square < 1e-6] = 1e-6 / math.sqrt(float(self.featuredim)) + square[square < 1e-6] = 1e-6 + data = data / square.reshape([-1, 1]) + return data + + def readbatch(self): + numQuerys = self.query.shape[0] + i = 0 + if self.isbinary: + while i < numQuerys: + vec = self.fin.read((np.dtype(self.type).itemsize)*self.featuredim) + if len(vec) == 0: break + if len(vec) != (np.dtype(self.type).itemsize)*self.featuredim: + print ("%d vector cannot be read correctly: require %d bytes but only read %d bytes" % (i, (np.dtype(self.type).itemsize)*self.featuredim, len(vec))) + continue + self.query[i] = np.frombuffer(vec, dtype=self.type).astype(self.mytype) + i += 1 + else: + while i < numQuerys: + line = self.fin.readline() + if len(line) == 0: break + + index = line.rfind("\t") + if index < 0: continue + + items = line[index+1:].split("|") + if len(items) < self.featuredim: continue + + for j in range(self.featuredim): self.query[i, j] = float(items[j]) + i += 1 + print ('Load batch query size:%r' % (i)) + if self.normalize != 0: return i, self.norm(self.query[0:i]) + return i, self.query[0:i] + + def readallbatches(self): + numQuerys = self.query.shape[0] + data = [] + R = 0 + while True: + i, q = self.readbatch() + if i == numQuerys: + data.append(copy.deepcopy(q)) + R += i + else: + if i > 0: + data.append(copy.deepcopy(q[0:i])) + R += i + break + return R, data + + def close(self): + self.fin.close() + +def gpusearch(args): + import faiss + ngpus = faiss.get_num_gpus() + print ('number of GPUs:', ngpus) + + gpu_resources = [] + for i in range(ngpus): + res = 
faiss.StandardGpuResources() + gpu_resources.append(res) + + datareader = DataReader(args.data_file, args.dim, args.B, args.data_normalize, args.data_type, args.target_type) + queryreader = DataReader(args.query_file, args.dim, args.Q, args.query_normalize, args.data_type, args.target_type) + RQ, dataQ = queryreader.readallbatches() + + batch = 0 + totaldata = 0 + while True: + numData, data = datareader.readbatch() + if numData == 0: + datareader.close() + break + + totaldata += numData + batch += 1 + print ("Begin batch %d" % batch) + + co = faiss.GpuMultipleClonerOptions() + co.shard = True + co.useFloat16 = False if args.target_type == 'float32' else True + co.useFloat16CoarseQuantizer = False + if args.D != 'Cosine': + cpu_index = faiss.IndexFlatL2(args.dim) + else: + cpu_index = faiss.IndexFlatIP(args.dim) + + gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co, ngpu=ngpus) + gpu_index.add(data) + + fout = open('truth.txt.%d' % batch, 'w') + foutd = open('dist.bin.%d' % batch, 'wb') + + foutd.write(pack('i', RQ)) + foutd.write(pack('i', args.k)) + + for query in dataQ: + D, I = gpu_index.search(query, args.k) + foutd.write(D.tobytes()) + for i in range(I.shape[0]): + for j in range(I.shape[1]): + fout.write(str(I[i][j]) + " ") + fout.write('\n') + + fout.close() + foutd.close() + + if args.B <= 0 or args.B >= totaldata: args.B = totaldata + + truth = [[] for j in range(RQ)] + for i in range(1, batch + 1): + f = open('truth.txt.%d' % i, 'r') + fd = open('dist.bin.%d' % i, 'rb') + r = unpack('i', fd.read(4))[0] + c = unpack('i', fd.read(4))[0] + print ('batch %d: r:%d c:%d RQ:%d k:%d' % (i, r, c, RQ, args.k)) + currdist = np.frombuffer(fd.read(4 * RQ * args.k), dtype=np.float32).reshape((RQ, args.k)) + fd.close() + + for j in range(RQ): + items = f.readline()[0:-1].split() + truth[j].extend([(int(items[k]) + args.B * (i-1), currdist[j][k]) for k in range(args.k)]) + truth[j].sort(key=itemgetter(1, 0)) + truth[j] = truth[j][0:args.k] + f.close() + + if not 
os.path.exists(args.output_truth + '.dist'): + os.mkdir(args.output_truth + '.dist') + + fout = open(args.output_truth, 'w') + foutd = open(args.output_truth + '.dist\\dist.bin.' + str(args.task), 'wb') + foutd.write(pack('i', RQ)) + foutd.write(pack('i', args.k)) + for i in range(RQ): + for j in range(args.k): + fout.write(str(truth[i][j][0]) + " ") + foutd.write(pack('i', truth[i][j][0])) + foutd.write(pack('f', truth[i][j][1])) + fout.write('\n') + fout.close() + foutd.close() + +def train_pq(args): + import faiss + from LibVQ.base_index import FaissIndex + + output_dir = args.output_dir + + datareader = DataReader(args.data_file, args.dim, args.train_samples, args.data_normalize, args.data_type, args.target_type) + + print ('train PQ...') + + index_method = 'pq' + ivf_centers_num = -1 + subvector_num = args.quan_dim + subvector_bits = 8 + numData, data = datareader.readbatch() + + faiss.omp_set_num_threads(args.T) + index = FaissIndex(index_method=index_method, + emb_size=len(data[0]), + ivf_centers_num=ivf_centers_num, + subvector_num=subvector_num, + subvector_bits=subvector_bits, + dist_mode='l2') + + print('Training the index with doc embeddings') + + index.fit(data) + + rtype = np.uint8(0) + if args.data_type == 'uint8': + rtype = np.uint8(1) + elif args.data_type == 'int16': + rtype = np.uint8(2) + elif args.data_type == 'float32': + rtype = np.uint8(3) + + faiss_index = index.index + ivf_index = faiss.downcast_index(faiss_index) + centroid_embedings = faiss.vector_to_array(ivf_index.pq.centroids) + codebooks = centroid_embedings.reshape(ivf_index.pq.M, ivf_index.pq.ksub, ivf_index.pq.dsub) + print ('codebooks shape:') + print (codebooks.shape) + + codebooks = codebooks.astype(np.float32) + with open(os.path.join(output_dir, args.output_quantizer + '.' 
+ str(args.task)),'wb') as f: + f.write(pack('B', 1)) + f.write(pack('B', rtype)) + f.write(pack('i', codebooks.shape[0])) + f.write(pack('i', codebooks.shape[1])) + f.write(pack('i', codebooks.shape[2])) + f.write(codebooks.tobytes()) + + if args.quan_test == 0 and len(args.output_quan_vector_file) == 0 and len(args.output_rec_vector_file) == 0: + os.rename(args.output_truth, os.path.join(output_dir, 'truth.txt' + '.' + str(args.task))) + ret = subprocess.run(['ZipKDTree.exe', output_dir, args.output_truth]) + print (ret) + return + + if len(args.output_quan_vector_file) > 0: + fquan = open(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task) + '.tmp'), 'wb') + fquan.write(pack('i', 0)) + fquan.write(pack('i', args.quan_dim)) + + if len(args.output_rec_vector_file) > 0: + frec = open(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task) + '.tmp'), 'wb') + frec.write(pack('i', 0)) + frec.write(pack('i', data.shape[1])) + + writeitems = 0 + while numData > 0: + if args.quan_test > 0: index.add(data) + + codes = ivf_index.pq.compute_codes(data) + + print ('codes shape:') + print (codes.shape) + + if len(args.output_quan_vector_file) > 0: + fquan.write(codes.tobytes()) + + if len(args.output_rec_vector_file) > 0: + reconstructed = ivf_index.pq.decode(codes).astype(args.data_type) + frec.write(reconstructed.tobytes()) + + writeitems += numData + numData, data = datareader.readbatch() + + datareader.close() + + if len(args.output_quan_vector_file) > 0: + p = fquan.tell() + fquan.seek(0) + fquan.write(pack('i', writeitems)) + fquan.seek(p) + fquan.close() + if os.path.exists(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))): + os.remove(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))) + os.rename(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task) + '.tmp'), os.path.join(output_dir, args.output_quan_vector_file + '.' 
+ str(args.task))) + if len(args.output_rec_vector_file) > 0: + p = frec.tell() + frec.seek(0) + frec.write(pack('i', writeitems)) + frec.seek(p) + frec.close() + if os.path.exists(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))): + os.remove(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))) + os.rename(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task) + '.tmp'), os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))) + + os.rename(args.output_truth, os.path.join(output_dir, 'truth.txt' + '.' + str(args.task))) + ret = subprocess.run(['ZipKDTree.exe', output_dir, args.output_truth]) + print (ret) + + if args.quan_test > 0: + queryreader = DataReader(args.query_file, args.dim, -1, args.query_normalize, args.data_type, args.target_type) + numQuery, query = queryreader.readbatch() + + qid2ground_truths = {} + f = open(os.path.join(output_dir, 'truth.txt.' + str(args.task)), 'r') + for i in range(numQuery): + items = f.readline()[0:-1].strip().split(' ') + qid2ground_truths[i] = set([int(gt) for gt in items]) + f.close() + + # Test the performance + index.test(query, qid2ground_truths, topk=args.k, batch_size=64, + MRR_cutoffs=[5, 10], Recall_cutoffs=[5, 10, 20, 40]) + +def train_opq(args): + import faiss + from LibVQ.base_index import FaissIndex + + output_dir = args.output_dir + + datareader = DataReader(args.data_file, args.dim, args.train_samples, args.data_normalize, args.data_type, args.target_type) + + print ('train OPQ...') + + index_method = 'opq' + ivf_centers_num = -1 + subvector_num = args.quan_dim + subvector_bits = 8 + numData, data = datareader.readbatch() + + faiss.omp_set_num_threads(args.T) + index = FaissIndex(index_method=index_method, + emb_size=len(data[0]), + ivf_centers_num=ivf_centers_num, + subvector_num=subvector_num, + subvector_bits=subvector_bits, + dist_mode='l2') + + print('Training the index with doc embeddings') + + index.fit(data) 
+ + rtype = np.uint8(0) + if args.data_type == 'uint8': + rtype = np.uint8(1) + elif args.data_type == 'int16': + rtype = np.uint8(2) + elif args.data_type == 'float32': + rtype = np.uint8(3) + + faiss_index = index.index + if isinstance(faiss_index, faiss.IndexPreTransform): + vt = faiss.downcast_VectorTransform(faiss_index.chain.at(0)) + assert isinstance(vt, faiss.LinearTransform) + rotate = faiss.vector_to_array(vt.A).reshape(vt.d_out, vt.d_in) + rotate_matrix = rotate.T + print ('rotate shape:') + print (rotate.shape) + + ivf_index = faiss.downcast_index(faiss_index.index) + centroid_embedings = faiss.vector_to_array(ivf_index.pq.centroids) + codebooks = centroid_embedings.reshape(ivf_index.pq.M, ivf_index.pq.ksub, ivf_index.pq.dsub) + print ('codebooks shape:') + print (codebooks.shape) + + codebooks = codebooks.astype(np.float32) + rotate = rotate.astype(np.float32) + with open(os.path.join(output_dir, args.output_quantizer + '.' + str(args.task)), 'wb') as f: + f.write(pack('B', 2)) + f.write(pack('B', rtype)) + f.write(pack('i', codebooks.shape[0])) + f.write(pack('i', codebooks.shape[1])) + f.write(pack('i', codebooks.shape[2])) + f.write(codebooks.tobytes()) + f.write(rotate_matrix.tobytes()) + + if args.quan_test == 0 and len(args.output_quan_vector_file) == 0 and len(args.output_rec_vector_file) == 0: + os.rename(args.output_truth, os.path.join(output_dir, 'truth.txt' + '.' + str(args.task))) + ret = subprocess.run(['ZipKDTree.exe', output_dir, args.output_truth]) + print (ret) + return + + if len(args.output_quan_vector_file) > 0: + fquan = open(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task) + '.tmp'), 'wb') + fquan.write(pack('i', 0)) + fquan.write(pack('i', args.quan_dim)) + + if len(args.output_rec_vector_file) > 0: + frec = open(os.path.join(output_dir, args.output_rec_vector_file + '.' 
+ str(args.task) + '.tmp'), 'wb') + frec.write(pack('i', 0)) + frec.write(pack('i', data.shape[1])) + + writeitems = 0 + while numData > 0: + if args.quan_test > 0: index.add(data) + + rdata = np.matmul(data, rotate.T) + codes = ivf_index.pq.compute_codes(rdata) + + print ('codes shape:') + print (codes.shape) + if len(args.output_quan_vector_file) > 0: + fquan.write(codes.tobytes()) + + if len(args.output_rec_vector_file) > 0: + Y = ivf_index.pq.decode(codes) + reconstructed = np.matmul(Y, rotate).astype(args.data_type) + frec.write(reconstructed.tobytes()) + + writeitems += numData + numData, data = datareader.readbatch() + + datareader.close() + + if len(args.output_quan_vector_file) > 0: + p = fquan.tell() + fquan.seek(0) + fquan.write(pack('i', writeitems)) + fquan.seek(p) + fquan.close() + if os.path.exists(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))): + os.remove(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))) + os.rename(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task) + '.tmp'), os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))) + + if len(args.output_rec_vector_file) > 0: + p = frec.tell() + frec.seek(0) + frec.write(pack('i', writeitems)) + frec.seek(p) + frec.close() + if os.path.exists(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))): + os.remove(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))) + os.rename(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task) + '.tmp'), os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))) + + os.rename(args.output_truth, os.path.join(output_dir, 'truth.txt' + '.' 
+ str(args.task))) + ret = subprocess.run(['ZipKDTree.exe', output_dir, args.output_truth]) + print (ret) + + if args.quan_test > 0: + queryreader = DataReader(args.query_file, args.dim, -1, args.query_normalize, args.data_type, args.target_type) + numQuery, query = queryreader.readbatch() + + qid2ground_truths = {} + f = open(os.path.join(output_dir, 'truth.txt.' + str(args.task)), 'r') + for i in range(numQuery): + items = f.readline()[0:-1].strip().split(' ') + qid2ground_truths[i] = set([int(gt) for gt in items]) + f.close() + + # Test the performance + index.test(query, qid2ground_truths, topk=args.k, batch_size=64, + MRR_cutoffs=[5, 10], Recall_cutoffs=[5, 10, 20, 40]) + +def quan_reconstruct_vectors(args): + import faiss + + output_dir = args.output_dir + + datareader = DataReader(args.data_file, args.dim, args.train_samples, args.data_normalize, args.data_type, args.target_type) + numData, data = datareader.readbatch() + + print ('Quantize and Reconstruct Vectors...') + + quantizer_path = os.path.join(os.path.dirname(args.query_file), args.output_quantizer) + f = open(quantizer_path, 'rb') + pqtype = unpack('B', f.read(1))[0] + rectype = unpack('B', f.read(1))[0] + + d0 = unpack('i', f.read(4))[0] + d1 = unpack('i', f.read(4))[0] + d2 = unpack('i', f.read(4))[0] + + codebooks = np.frombuffer(f.read(d0*d1*d2*4), dtype=np.float32).reshape((d0,d1,d2)) + if pqtype == 2: + rotate_matrix = np.frombuffer(f.read(data.shape[1]*data.shape[1]*4), dtype = np.float32).reshape((data.shape[1], data.shape[1])) + rotate = np.transpose(rotate_matrix.copy()) + print (rotate_matrix.shape) + f.close() + + with open(os.path.join(output_dir, args.output_quantizer + '.' 
+ str(args.task)), 'wb') as f: + f.write(pack('B', pqtype)) + f.write(pack('B', rectype)) + f.write(pack('i', codebooks.shape[0])) + f.write(pack('i', codebooks.shape[1])) + f.write(pack('i', codebooks.shape[2])) + f.write(codebooks.tobytes()) + if pqtype == 2: f.write(rotate_matrix.tobytes()) + f.close() + + if len(args.output_quan_vector_file) == 0 and len(args.output_rec_vector_file) == 0: + os.rename(args.output_truth, os.path.join(output_dir, 'truth.txt' + '.' + str(args.task))) + ret = subprocess.run(['ZipKDTree.exe', output_dir, args.output_truth]) + print (ret) + return + + if len(args.output_quan_vector_file) > 0: + fquan = open(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task) + '.tmp'), 'wb') + fquan.write(pack('i', 0)) + fquan.write(pack('i', args.quan_dim)) + + if len(args.output_rec_vector_file) > 0: + frec = open(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task) + '.tmp'), 'wb') + frec.write(pack('i', 0)) + frec.write(pack('i', data.shape[1])) + + def fourcc(x): + h = np.uint32(0) + h = h | ord(x[0]) | ord(x[1]) << 8 | ord(x[2]) << 16 | ord(x[3]) << 24 + return h + + with open('tmp_faiss_index', 'wb') as f: + h = fourcc('IxPq') + d = np.uint64(data.shape[1]) + M = np.uint64(codebooks.shape[0]) + nbits = np.uint64(math.log2(codebooks.shape[1])) + codesize = np.uint64(d0 * d1 * d2) + totalitems = np.uint64(0) + print ('h:%u d:%u M:%u nbits:%u codesize:%u lencode:%u' % (h, d, M, nbits, codesize, len(codebooks.tobytes()))) + f.write(pack('I', h)) + f.write(pack('i', d)) + f.write(pack('q', np.int64(0))) + dummy = np.int64(1048576) + f.write(pack('q', dummy)) + f.write(pack('q', dummy)) + f.write(pack('B', np.int8(1))) # is_trained + f.write(pack('i', np.int32(1))) # metric_type + f.write(pack('Q', d)) # size_t + f.write(pack('Q', M)) # size_t + f.write(pack('Q', nbits)) # size_t + f.write(pack('Q', codesize)) + f.write(codebooks.tobytes()) + f.write(pack('Q', totalitems)) # size_t + + f.write(pack('i', 
np.int32(0))) # search_type + f.write(pack('B', np.int8(0))) + f.write(pack('i', np.int32(nbits * M + 1))) + f.close() + + ivf_index = faiss.read_index('tmp_faiss_index') + if not ivf_index: + print ('Error: faiss index cannot be loaded!') + exit (1) + print ('ksubs:%d dsub:%d code_size:%d nbits:%d M:%d d:%d polysemous_ht:%d' % (ivf_index.pq.ksub, ivf_index.pq.dsub, ivf_index.pq.code_size, ivf_index.pq.nbits, ivf_index.pq.M, ivf_index.pq.d, ivf_index.polysemous_ht)) + + writeitems = 0 + while numData > 0: + print (data[0]) + if pqtype == 2: + data = np.matmul(data, rotate_matrix) + print ('rotate:') + print (data[0]) + + codes = ivf_index.pq.compute_codes(data) + print ('encode:') + print (codes[0]) + + if len(args.output_quan_vector_file) > 0: + fquan.write(codes.tobytes()) + + if len(args.output_rec_vector_file) > 0: + recY = ivf_index.pq.decode(codes) + print ('decode:') + print (recY[0]) + if pqtype == 2: + recY = np.matmul(recY, rotate).astype(args.data_type) + print ('rotateback:') + print (recY[0]) + frec.write(recY.tobytes()) + + writeitems += numData + numData, data = datareader.readbatch() + + datareader.close() + + if len(args.output_quan_vector_file) > 0: + p = fquan.tell() + fquan.seek(0) + fquan.write(pack('i', writeitems)) + fquan.seek(p) + fquan.close() + if os.path.exists(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))): + os.remove(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))) + os.rename(os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task) + '.tmp'), os.path.join(output_dir, args.output_quan_vector_file + '.' + str(args.task))) + + if len(args.output_rec_vector_file) > 0: + p = frec.tell() + frec.seek(0) + frec.write(pack('i', writeitems)) + frec.seek(p) + frec.close() + if os.path.exists(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))): + os.remove(os.path.join(output_dir, args.output_rec_vector_file + '.' 
+ str(args.task))) + os.rename(os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task) + '.tmp'), os.path.join(output_dir, args.output_rec_vector_file + '.' + str(args.task))) + + if os.path.exists(os.path.join(output_dir, 'truth.txt' + '.' + str(args.task))): + os.remove(os.path.join(output_dir, 'truth.txt' + '.' + str(args.task))) + os.rename(args.output_truth, os.path.join(output_dir, 'truth.txt' + '.' + str(args.task))) + ret = subprocess.run(['ZipKDTree.exe', output_dir, args.output_truth]) + print (ret) + +if __name__ == '__main__': + args = get_config() + print ('log_dir:%s' % args.log_dir) + print ('output_dir:%s' % args.output_dir) + + if not os.path.exists(args.output_dir): os.mkdir(args.output_dir) + + if args.data_format != 'DEFAULT': + target = 'PreprocessData.exe' if args.data_format == 'BOND' else 'ProcessData.exe' + casttype = 'BYTE' + if args.data_type == 'uint8': casttype = 'UBYTE' + elif args.data_type == 'int16': casttype = 'SHORT' + elif args.data_type == 'float32': casttype = 'FLOAT' + ret = subprocess.run([target, args.data_file, os.path.join(args.output_dir, 'vectors.bin.%d' % args.task), os.path.join(args.output_dir, 'meta.bin.%d' % args.task), os.path.join(args.output_dir, 'metaindex.bin.%d' % args.task), str(args.dim), casttype, '0']) + args.data_file = os.path.join(args.output_dir, 'vectors.bin.%d' % args.task) + print(ret) + + if args.query_file[-4:] != '.bin': + casttype = 'BYTE' + if args.data_type == 'uint8': casttype = 'UBYTE' + elif args.data_type == 'int16': casttype = 'SHORT' + elif args.data_type == 'float32': casttype = 'FLOAT' + ret = subprocess.run(['SearchPreprocess.exe', '-q', args.query_file, '-o', args.query_file + '_queryVector.bin', '-v', args.query_file + '_validQuery.bin', '-d', str(args.dim), '-t', casttype, '-n', '0']) + args.query_file = args.query_file + '_queryVector.bin' + print(ret) + + gpusearch(args) + + if args.quan_type != 'none': + if args.quan_type == 'pq': + train_pq(args) + elif 
args.quan_type == 'opq':
+            train_opq(args)
+        elif args.quan_type == 'quan_reconstruct':
+            quan_reconstruct_vectors(args)
+
+    if args.log_dir != '':
+        localpath = args.output_truth + '.dist\\dist.bin.' + str(args.task)
+        ret = subprocess.run(['CosmosFolderTransfer.exe', 'uploadStream', localpath.replace('\\', '/'), args.log_dir + '/dist/', 'ap'])
+        print (ret)
diff --git a/Tools/OPQ/README.md b/Tools/OPQ/README.md
new file mode 100644
index 000000000..5931a3d8d
--- /dev/null
+++ b/Tools/OPQ/README.md
@@ -0,0 +1,12 @@
+# OPQ gpu training and inference tool
+
+## Package Requirements (tbd)
+
+1. Python>=3.7
+2. numpy>=1.18.1
+3. faiss>=1.7.0
+4. LibVQ
+
+
+## Parameter Sample
+--data_file [input_path]\vectors.bin.0 --query_file [model_path]\query.bin --output_truth [output_path] --output_dir [output_path]\5474\cluster_unzip --task 0 --data_type float32 --k 5 --dim 1024 --B 1000000 --Q 1000 --D L2 --data_format DEFAULT --T 18 --train_samples 1000000 --quan_type opq --quan_dim 1024 --output_quantizer quantizer.bin --output_quan_vector_file dssm_vectors.bin --output_rec_vector_file vectors.bin --quan_test 1 --data_normalize 0 --query_normalize 0
diff --git a/Tools/nni-auto-tune/README.md b/Tools/nni-auto-tune/README.md
new file mode 100644
index 000000000..923bd7d12
--- /dev/null
+++ b/Tools/nni-auto-tune/README.md
@@ -0,0 +1,125 @@
+# ANN-auto-tune using NNI
+
+
+## Requirements
+
+This example requires NNI >= 2.8, python == 3.6/3.8
+
+
+```sh
+pip install nni
+```
+For more details, please view [NNI installation](https://nni.readthedocs.io/en/stable/installation.html)
+
+Install SPTAG
+
+```sh
+pip install -i https://test.pypi.org/simple/ sptag
+```
+
+Install other dependencies
+
+```sh
+pip install multiprocess
+```
+
+## Dataset
+
+We support multiple types of data for training, including text files, binary files and [ann-benchmark](https://github.com/erikbern/ann-benchmarks) format hdf5 files. But the ground truth file should only contain plain-text indices.
+
+In many cases, parameters that work fine on a sampled dataset also work on the original dataset. So we also provided a preprocessing script for dataset sampling and ground truth calculation. You can use the following command to do sampling and ground truth pre-calculation.
+
+```sh
+python preprocessing.py --train_file vectors.bin --query_file query.bin --output_dir sampled/ --distance euclidean --num_sample 100000
+```
+
+If you only need to calculate ground truth to save time in auto-tune, which we highly recommend, you can set `--num_sample` to -1
+
+## Quickstart
+
+
+Use this command to start a NNI trial to tune SPTAG model on ann-benchmark format hdf5 sift-128-euclidean dataset.
+```sh
+nnictl create --config config.yml
+```
+
+If you wish to tune SPTAG on a binary dataset or text dataset, the input file formats are as below.
+
+#### DEFAULT (Binary)
+> Input raw data for index build and input query file for index search (suppose vector dimension is 3):
+
+```
+<4 bytes int representing num_vectors><4 bytes int representing num_dimension>
+
+```
+
+#### TXT
+> Input raw data for index build and input query file for index search (suppose vector dimension is 3):
+
+```
+\t|||
+\t|||
+...
+```
+where each line represents a vector with its metadata and its value separated by a tab space. Each dimension of a vector is separated by | or use --delimiter to define the separator.
+
+> Truth file to calculate recall (suppose K is 2):
+```
+
+
+...
+```
+where each line represents the K nearest neighbors of a query separated by a blank space. Each neighbor is given by its vector id.
+
+Then you can change `trialCommand` in config.yml to:
+
+```sh
+python main.py --train_file vectors.bin --query_file query.bin --label_file truth.txt --distance euclidean
+```
+
+You can also specify `--max_build_time` or `--max_memory` if you have build time limit or search memory usage limit in your project. They both default to -1, which means no limit.
+ +**NOTE:** Always clear corresponding folder under `results/` before starting a trial on same dataset. + +## Results + +Install matplotlib for figure drawing + +```sh +pip install matplotlib +``` + +During the trial, the results are saved as json files in `results/(dataset_name)`. Use following command to visualize results. + +```sh +plot.py --path sift-128-euclidean +``` +The figure shows the correspondence between recall and qps. And you can see the details of each selected point in console. + + +The following are the results of sptag and other algorithms on different datasets + +sift-128-euclidean +------------------ + +![sift-128-euclidean](picture/sift-128-euclidean.png) + +glove-100-angular +------------------ + +![glove-100-angular](picture/glove-100-angular.png) + +glove-25-angular +------------------ + +![glove-25-angular](picture/glove-25-angular.png) + +nytimes-256-angular +------------------ + +![nytimes-256-angular](picture/nytimes-256-angular.png) + +fashion-mnist-784-euclidean +------------------ + +![fashion-mnist-784-euclidean](picture/fashion-mnist-784-euclidean.png) diff --git a/Tools/nni-auto-tune/config.yml b/Tools/nni-auto-tune/config.yml new file mode 100755 index 000000000..722649767 --- /dev/null +++ b/Tools/nni-auto-tune/config.yml @@ -0,0 +1,14 @@ +experimentName: sift128 +trialConcurrency: 4 +maxExperimentDuration: 168h +searchSpaceFile: search_space.json +trialCommand: python main.py --train_file sift-128-euclidean.hdf5 + +tuner: + name: TPE + classArgs: + optimize_mode: maximize + +trainingService: + platform: local + diff --git a/Tools/nni-auto-tune/config_aml.yml b/Tools/nni-auto-tune/config_aml.yml new file mode 100755 index 000000000..8c708c57e --- /dev/null +++ b/Tools/nni-auto-tune/config_aml.yml @@ -0,0 +1,26 @@ +experimentName: sift128 +trialConcurrency: 4 +maxExperimentDuration: 168h +searchSpaceFile: search_space.json + +trialCommand: + python --version; + pip --version; + pip install -i https://test.pypi.org/simple/ sptag; 
+    git clone https://github.com/microsoft/SPTAG.git;
+    cd SPTAG/Tools/nni-auto-tune;
+    wget http://ann-benchmarks.com/sift-128-euclidean.hdf5;
+    python main.py --train_file sift-128-euclidean.hdf5
+
+tuner:
+  name: TPE
+  classArgs:
+    optimize_mode: maximize
+
+trainingService:
+  platform: aml
+  dockerImage: msranni/nni
+  subscriptionId:
+  resourceGroup:
+  workspaceName:
+  computeTarget:
\ No newline at end of file
diff --git a/Tools/nni-auto-tune/dataset.py b/Tools/nni-auto-tune/dataset.py
new file mode 100755
index 000000000..45e823873
--- /dev/null
+++ b/Tools/nni-auto-tune/dataset.py
@@ -0,0 +1,146 @@
+import h5py
+import numpy as np
+import os
+import random
+from struct import pack, unpack, calcsize
+import math
+import argparse
+import copy
+
+
+class DataReader:
+
+    def __init__(self,
+                 filename,
+                 featuredim,
+                 batchsize,
+                 datatype='float32',
+                 normalize=False,
+                 targettype='float32'):
+        self.mytype = targettype
+        if filename.find('.bin') >= 0:
+            self.fin = open(filename, 'rb')
+            R = unpack('i', self.fin.read(4))[0]
+            self.featuredim = unpack('i', self.fin.read(4))[0]
+            self.isbinary = True
+            self.type = datatype
+            print('Open Binary DataReader for data(%d,%d)...'
% + (R, self.featuredim)) + else: + with open(filename) as f: + R = sum(1 for _ in f) + self.fin = open(filename, 'r') + self.featuredim = featuredim + self.isbinary = False + self.type = self.mytype + + if batchsize <= 0: batchsize = R + self.query = np.zeros([batchsize, self.featuredim], dtype=self.mytype) + self.normalize = normalize + + def norm(self, data): + square = np.sqrt(np.sum(np.square(data), axis=1)) + data[square < 1e-6] = 1e-6 / math.sqrt(float(self.featuredim)) + square[square < 1e-6] = 1e-6 + data = data / square.reshape([-1, 1]) + return data + + def readbatch(self): + numQuerys = self.query.shape[0] + i = 0 + if self.isbinary: + while i < numQuerys: + vec = self.fin.read( + (np.dtype(self.type).itemsize) * self.featuredim) + if len(vec) == 0: break + if len(vec) != (np.dtype( + self.type).itemsize) * self.featuredim: + print( + "%d vector cannot be read correctly: require %d bytes but only read %d bytes" + % (i, (np.dtype(self.type).itemsize) * self.featuredim, + len(vec))) + continue + self.query[i] = np.frombuffer(vec, dtype=self.type).astype( + self.mytype) + i += 1 + else: + while i < numQuerys: + line = self.fin.readline() + if len(line) == 0: break + + index = line.rfind("\t") + if index < 0: continue + + items = line[index + 1:].split("|") + if len(items) < self.featuredim: continue + + for j in range(self.featuredim): + self.query[i, j] = float(items[j]) + i += 1 + print('Load batch query size:%r' % (i)) + if self.normalize != 0: return i, self.norm(self.query[0:i]) + return i, self.query[0:i] + + def readallbatches(self): + numQuerys = self.query.shape[0] + data = [] + R = 0 + while True: + i, q = self.readbatch() + if i == numQuerys: + data.append(copy.deepcopy(q)) + R += i + else: + if i > 0: + data.append(copy.deepcopy(q[0:i])) + R += i + break + return R, np.array(data) + + def close(self): + self.fin.close() + + +def dataset_transform(dataset): + if dataset.attrs.get('type', 'dense') != 'sparse': + return np.array(dataset['train']), 
np.array(dataset['test']) + return sparse_to_lists(dataset['train'], + dataset['size_train']), sparse_to_lists( + dataset['test'], dataset['size_test']) + + +class HDF5Reader: + + def __init__(self, filename, data_type='float32'): + self.data = h5py.File(filename, 'r') + self._data_type = data_type + self.featuredim = int(self.data.attrs['dimension'] + ) if 'dimension' in self.data.attrs else len( + self.data['train'][0]) + self.train, self.test = dataset_transform(self.data) + self.distance = self.data.attrs['distance'] + self.label = np.array(self.data['distances']) + + def norm(self, data): + square = np.sqrt(np.sum(np.square(data), axis=1)) + data[square < 1e-6] = 1e-6 / math.sqrt(float(self.featuredim)) + square[square < 1e-6] = 1e-6 + data = data / square.reshape([-1, 1]) + return data + + def readallbatches(self): + return np.array(self.train, + dtype=self._data_type), np.array(self.test, + dtype=self._data_type) + + def close(self): + pass + + +def sparse_to_lists(data, lengths): + X = [] + index = 0 + for l in lengths: + X.append(data[index:index + l]) + index += l + return X diff --git a/Tools/nni-auto-tune/main.py b/Tools/nni-auto-tune/main.py new file mode 100755 index 000000000..55ccb12d2 --- /dev/null +++ b/Tools/nni-auto-tune/main.py @@ -0,0 +1,290 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import nni +import h5py +import time + +import numpy as np +import os +from model import Sptag, BruteForceBLAS +from runner import run_individual_query +from dataset import DataReader, HDF5Reader +import argparse +import json +import shutil +import itertools +from multiprocess import Pool, Process +import multiprocess + + +def knn_threshold(data, k, epsilon): + return data[k - 1] + epsilon + + +def get_recall_from_distance(dataset_distances, + run_distances, + k, + epsilon=1e-3): + recalls = np.zeros(len(run_distances)) + for i in range(len(run_distances)): + t = knn_threshold(dataset_distances[i], k, epsilon) + actual = 0 + for d in run_distances[i][:k]: + if d <= t: + actual += 1 + recalls[i] = actual + + return (np.mean(recalls) / float(k), np.std(recalls) / float(k), recalls) + + +def get_recall_from_index(dataset_index, run_index, k): + recalls = np.zeros(len(run_index)) + for i in range(len(run_index)): + actual = 0 + for d in run_index[i][:k]: + # need to conver to string because default loaded label are strings + if str(d) in dataset_index[i][:k]: + actual += 1 + recalls[i] = actual + + return (np.mean(recalls) / float(k), np.std(recalls) / float(k), recalls) + + +def queries_per_second(attrs): + return 1.0 / attrs["best_search_time"] + + +def compute_metrics(groundtruth, attrs, results, k, from_index=False): + if from_index: + mean, std, recalls = get_recall_from_index(groundtruth, results, k) + else: + mean, std, recalls = get_recall_from_distance(groundtruth, results, k) + qps = queries_per_second(attrs) + print('mean: %12.3f,std: %12.3f, qps: %12.3f' % (mean, std, qps)) + return mean, qps + + +def grid_search(params): + param_num = len(params) + params = list(params.items()) + max_param_choices = max([len(p[1]) for p in params]) + temp = [] + for i in range(max_param_choices): + temp += [i for _ in range(param_num)] + for c in set(itertools.permutations(temp, param_num)): + res = {} + for i in range(len(c)): + if c[i] >= len(params[i][1]): + break + 
else: + res[params[i][0]] = params[i][1][c[i]] + if len(res) == param_num: + yield res + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--train_file', + help='the data file to load training points from, ' + 'could be text file, binary flie or ann-benchmark format hdf5 file', + default='glove-100-angular.hdf5') + parser.add_argument( + '--query_file', + help='the data file to load query points from, if you use ' + 'ann-benchmark format hdf5 file in train_file, this should be None', + default=None) + parser.add_argument( + '--label_file', + help= + 'the data file to load groundtruth index from, only support text file', + default=None) + parser.add_argument('--algorithm', + help='the name of SPTAG algorithm', + default="BKT") + parser.add_argument("--k", + default=10, + type=int, + help="the number of near neighbours to search for") + parser.add_argument("--distance", + default='angular', + help="the type of distance for searching") + parser.add_argument( + "--max_build_time", + default=-1, + type=int, + help="the limit of index build time in seconds. -1 means no limit") + parser.add_argument( + "--max_memory", + default=-1, + type=int, + help= + "the limit of memory use during searching in bytes. 
-1 means no limit") + parser.add_argument("--dim", + default=100, + type=int, + help="the dimention of training vectors") + parser.add_argument("--input_type", + default="float32", + help="the data type of input vectors") + parser.add_argument("--data_type", + default="float32", + help="the data type for building and search in SPTAG ") + args = parser.parse_args() + + if args.train_file.endswith(".hdf5"): + # ann-benchmark format hdf5 file got all we want, so args like distance are ignored + data_reader = HDF5Reader(args.train_file, args.data_type) + X_train, X_test = data_reader.readallbatches() + distance = data_reader.distance + dimension = data_reader.featuredim + label = data_reader.label + else: + X_train = DataReader(args.train_file, + args.dim, + batchsize=-1, + datatype=args.input_type, + targettype=args.data_type).readbatch()[1] + X_test = DataReader(args.query_file, + args.dim, + batchsize=-1, + datatype=args.input_type, + targettype=args.data_type).readbatch()[1] + distance = args.distance + dimension = args.dim + label = [] + if args.label_file is None: + # if the groundtruth is not provided + # we calculate groundtruth distances with brute force + bf = BruteForceBLAS(distance) + bf.fit(X_train) + for i, x in enumerate(X_test): + if i % 1000 == 0: + print('%d/%d...' 
% (i, len(X_test))) + res = list(bf.query_with_distances(x, args.k)) + res.sort(key=lambda t: t[-1]) + label.append([d for _, d in res]) + else: + label = [] + # we assume the groundtruth index are split by space + with open(args.label_file, 'r') as f: + for line in f: + label.append(line.strip().split()) + + print('got a train set of size (%d * %d)' % (X_train.shape[0], dimension)) + print('got %d queries' % len(X_test)) + + para = nni.get_next_parameter() + algo = Sptag(args.algorithm, distance) + + t0 = time.time() + if args.max_build_time > 0: + pool = Pool(1) + results = pool.apply_async(algo.fit, + kwds=dict(X=X_train, + para=para, + data_type=args.data_type, + save_index=True)) + try: + results.get(args.max_build_time + ) # Wait timeout seconds for func to complete. + algo.load('index') + shutil.rmtree("index") + pool.close() + pool.join() + except multiprocess.TimeoutError: # kill subprocess if timeout + print("Aborting due to timeout", args.max_build_time) + pool.terminate() + nni.report_final_result({ + 'default': -1, + "recall": 0, + "qps": 0, + "build_time": args.max_build_time + }) + return + else: + algo.fit(X=X_train, para=para, data_type=args.data_type) + + build_time = time.time() - t0 + + print('Built index in', build_time) + + search_param_choices = { + "NumberOfInitialDynamicPivots": [1, 2, 4, 8, 16, 32, 50], + "MaxCheck": [512, 3200, 5120, 8192, 12800, 16400, 19600], + "NumberOfOtherDynamicPivots": [1, 2, 4, 8, 10] + } + + best_metric = -1 + best_res = {} + for i, search_params in enumerate(grid_search(search_param_choices)): + algo.set_query_arguments(search_params) + try: + attrs, results = run_individual_query(algo, + X_train, + X_test, + distance, + args.k, + max_mem=args.max_memory) + except MemoryError: + print("Aborting due to exceed memory limit") + nni.report_final_result({ + 'default': -1, + "recall": 0, + "qps": 0, + "build_time": args.max_build_time + }) + return + + neighbors = [0 for _ in results] + distances = [0 for _ in 
results] + + for idx, (t, ds) in enumerate(results): + neighbors[idx] = [n for n, d in ds] + [-1] * (args.k - len(ds)) + distances[idx] = [d for n, d in ds + ] + [float('inf')] * (args.k - len(ds)) + if args.label_file is None: + recalls_mean, qps = compute_metrics(label, attrs, distances, + args.k) + else: + recalls_mean, qps = compute_metrics(label, + attrs, + neighbors, + args.k, + from_index=True) + + combined_metric = -1 * np.log10(1 - recalls_mean) + 0.1 * np.log10(qps) + + res = { + "default": combined_metric, + "recall": recalls_mean, + "qps": qps, + "build_time": build_time + } + + if combined_metric > best_metric: + best_metric = combined_metric + best_res = res.copy() + + res["build_params"] = para + res["search_params"] = search_params + + experiment_id = nni.get_experiment_id() + result_dir = os.path.join('results', args.train_file.split('.')[0]) + if not os.path.exists(result_dir): + os.makedirs(result_dir) + trial_id = nni.get_trial_id() + with open( + os.path.join( + result_dir, + "result_" + str(trial_id) + ' ' + str(i) + ".json"), + "w") as f: + json.dump(res, f) + + nni.report_final_result(best_res) + + +if __name__ == '__main__': + main() diff --git a/Tools/nni-auto-tune/model.py b/Tools/nni-auto-tune/model.py new file mode 100755 index 000000000..4cd091d25 --- /dev/null +++ b/Tools/nni-auto-tune/model.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from sptag import SPTAG as SPTAG +from scipy.spatial.distance import pdist as scipy_pdist +import numpy +""" We use the implementation of metric and brute force algorithm from ann-banchmark""" + + +def pdist(a, b, metric): + return scipy_pdist([a, b], metric=metric)[0] + + +def jaccard(a, b): + if len(a) == 0 or len(b) == 0: + return 0 + intersect = len(set(a) & set(b)) + return intersect / (float)(len(a) + len(b) - intersect) + + +metrics = { + 'hamming': { + 'distance': lambda a, b: pdist(a, b, "hamming"), + 'distance_valid': lambda a: True + }, + # return 1 - jaccard similarity, because smaller distances are better. + # modified to use pdist jaccard given new data format (boolean ndarray) + 'jaccard': { + 'distance': lambda a, b: 1 - jaccard(a, b), #pdist(a, b, "jaccard"), + 'distance_valid': lambda a: a < 1 - 1e-5 + }, + 'euclidean': { + 'distance': lambda a, b: pdist(a, b, "euclidean"), + 'distance_valid': lambda a: True + }, + 'angular': { + 'distance': lambda a, b: pdist(a, b, "cosine"), + 'distance_valid': lambda a: True + } +} + + +class Sptag: + + def __init__(self, algo, metric): + self._algo = str(algo) + self._para = {} + self._metric = {'angular': 'Cosine', 'euclidean': 'L2'}[metric] + + def fit(self, X, para=None, data_type='float32', save_index=False): + self._data_type = { + 'float32': 'Float', + 'int8': 'Int8', + 'int16': 'Int16' + }[data_type] + self._sptag = SPTAG.AnnIndex(self._algo, self._data_type, X.shape[1]) + self._sptag.SetBuildParam("NumberOfThreads", '32', "Index") + self._sptag.SetBuildParam("DistCalcMethod", self._metric, "Index") + + if para: + self._para = para + for k, v in para.items(): + self._sptag.SetBuildParam(k, str(v), "Index") + + self._sptag.Build(X, X.shape[0], False) + # temp save for auto tune, please use self.save() for intended saving + if save_index: + self.save("index") + + def set_query_arguments(self, s_para=None): + if s_para: + self.s_para = s_para + for k, v in s_para.items(): + self._sptag.SetSearchParam(k, 
str(v), "Index") + + def query(self, v, k): + return self._sptag.Search(v, k)[0] + + def save(self, fn): + self._sptag.Save(fn) + + def load(self, fn): + self._sptag = SPTAG.AnnIndex.Load(fn) + + def __str__(self): + s = '' + if self._para: + s += ", " + ", ".join( + [k + "=" + str(v) for k, v in self._para.items()]) + if self.s_para: + s += ", " + ", ".join( + [k + "_s" + "=" + str(v) for k, v in self.s_para.items()]) + return 'Sptag(metric=%s, algo=%s' % (self._metric, + self._algo) + s + ')' + + +class BruteForceBLAS: + """kNN search that uses a linear scan = brute force.""" + + def __init__(self, metric, precision=numpy.float32): + if metric not in ('angular', 'euclidean', 'hamming', 'jaccard'): + raise NotImplementedError( + "BruteForceBLAS doesn't support metric %s" % metric) + elif metric == 'hamming' and precision != numpy.bool: + raise NotImplementedError( + "BruteForceBLAS doesn't support precision" + " %s with Hamming distances" % precision) + self._metric = metric + self._precision = precision + self.name = 'BruteForceBLAS()' + + def fit(self, X): + """Initialize the search index.""" + X = X.astype(self._precision) + if self._metric == 'angular': + # precompute (squared) length of each vector + lens = (X**2).sum(-1) + # normalize index vectors to unit length + X /= numpy.sqrt(lens)[..., numpy.newaxis] + self.index = numpy.ascontiguousarray(X, dtype=self._precision) + elif self._metric == 'hamming': + # Regarding bitvectors as vectors in l_2 is faster for blas + X = X.astype(numpy.float32) + # precompute (squared) length of each vector + lens = (X**2).sum(-1) + self.index = numpy.ascontiguousarray(X, dtype=numpy.float32) + self.lengths = numpy.ascontiguousarray(lens, dtype=numpy.float32) + elif self._metric == 'euclidean': + # precompute (squared) length of each vector + self.index = numpy.ascontiguousarray(X, dtype=self._precision) + lens = (self.index**2).sum(-1) + self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision) + elif self._metric 
== 'jaccard': + self.index = X + else: + # shouldn't get past the constructor! + assert False, "invalid metric" + + def query(self, v, n): + return [index for index, _ in self.query_with_distances(v, n)] + + def query_with_distances(self, v, n): + """Find indices of `n` most similar vectors from the index to query + vector `v`.""" + v = v.astype(self._precision) + if self._metric != 'jaccard': + # use same precision for query as for index + v = numpy.ascontiguousarray(v, dtype=self.index.dtype) + + # HACK we ignore query length as that's a constant + # not affecting the final ordering + if self._metric == 'angular': + # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b) # noqa + dists = -numpy.dot(self.index, v) + elif self._metric == 'euclidean': + # argmin_a (a - b)^2 = argmin_a a^2 - 2ab + b^2 = argmin_a a^2 - 2ab # noqa + dists = self.lengths - 2 * numpy.dot(self.index, v) + elif self._metric == 'hamming': + # Just compute hamming distance using euclidean distance + dists = self.lengths - 2 * numpy.dot(self.index, v) + elif self._metric == 'jaccard': + dists = [ + metrics[self._metric]['distance'](v, e) for e in self.index + ] + else: + # shouldn't get past the constructor! 
+ assert False, "invalid metric" + # partition-sort by distance, get `n` closest + nearest_indices = numpy.argpartition(dists, n)[:n] + indices = [ + idx for idx in nearest_indices + if metrics[self._metric]["distance_valid"](dists[idx]) + ] + + def fix(index): + ep = self.index[index] + ev = v + return (index, metrics[self._metric]['distance'](ep, ev)) + + return map(fix, indices) diff --git a/Tools/nni-auto-tune/picture/fashion-mnist-784-euclidean.png b/Tools/nni-auto-tune/picture/fashion-mnist-784-euclidean.png new file mode 100644 index 000000000..f78866321 Binary files /dev/null and b/Tools/nni-auto-tune/picture/fashion-mnist-784-euclidean.png differ diff --git a/Tools/nni-auto-tune/picture/glove-100-angular.png b/Tools/nni-auto-tune/picture/glove-100-angular.png new file mode 100644 index 000000000..34e4fcc72 Binary files /dev/null and b/Tools/nni-auto-tune/picture/glove-100-angular.png differ diff --git a/Tools/nni-auto-tune/picture/glove-25-angular.png b/Tools/nni-auto-tune/picture/glove-25-angular.png new file mode 100644 index 000000000..b1d9ca424 Binary files /dev/null and b/Tools/nni-auto-tune/picture/glove-25-angular.png differ diff --git a/Tools/nni-auto-tune/picture/nytimes-256-angular.png b/Tools/nni-auto-tune/picture/nytimes-256-angular.png new file mode 100644 index 000000000..8cdcf32e2 Binary files /dev/null and b/Tools/nni-auto-tune/picture/nytimes-256-angular.png differ diff --git a/Tools/nni-auto-tune/picture/sift-128-euclidean.png b/Tools/nni-auto-tune/picture/sift-128-euclidean.png new file mode 100644 index 000000000..e38227aaa Binary files /dev/null and b/Tools/nni-auto-tune/picture/sift-128-euclidean.png differ diff --git a/Tools/nni-auto-tune/plot.py b/Tools/nni-auto-tune/plot.py new file mode 100644 index 000000000..62467d6de --- /dev/null +++ b/Tools/nni-auto-tune/plot.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import matplotlib.pyplot as plt +import json +import os +import numpy as np +import argparse +import time +"""This is almost the same plot code as ann-benchmarks""" + + +def create_plot(path, x_scale, y_scale, results): + qps, recalls, build_times, build_params, search_params = [], [], [], [], [] + data = [] + print(len(results)) + start = time.time() + for i, p in enumerate(results): + if i % 2000 == 0: + print("processed", i, 'results in', round(time.time() - start), + 'seconds') + start = time.time() + + with open(os.path.join("results", path, p), "r") as f: + try: + res = json.load(f) + data.append((res["qps"], res["recall"], res["build_time"], + res["build_params"], res["search_params"])) + except: + print(os.path.join("results", path, p)) + + data.sort(key=lambda t: (-1 * t[0], -1 * t[1])) + last_x = -1 + comparator = ((lambda xv, lx: xv > lx) if last_x < 0 else + (lambda xv, lx: xv < lx)) + with open(path + '.json', 'w') as f: + for i, (q, r, b, bp, sp) in enumerate(data): + if comparator(r, last_x): + last_x = r + recalls.append(r) + qps.append(q) + print("Selected point", i, "recalls:", r, "qps:", q, + "build time:", b, "with build params:", bp, + "and search params:", sp) + f.write("Selected point" + str(i) + "recalls:" + str(r) + + "qps:" + str(q) + "build time:" + str(b) + + "with build params:" + str(bp) + "and search params:" + + str(sp) + "\n") + + handles, labels = [], [] + handle, = plt.plot(recalls, + qps, + '-', + label="SPTAG", + color="r", + ms=7, + mew=3, + lw=3, + linestyle='-', + marker='+') + handles.append(handle) + labels.append('sptag') + + ax = plt.gca() + ax.set_ylabel("Queries per second (1/s)") + ax.set_xlabel("Recall") + + if x_scale[0] == 'a': + alpha = float(x_scale[1:]) + fun = lambda x: 1 - (1 - x)**(1 / alpha) + inv_fun = lambda x: 1 - (1 - x)**alpha + ax.set_xscale('function', functions=(fun, inv_fun)) + if alpha <= 3: + ticks = [inv_fun(x) for x in np.arange(0, 1.2, .2)] + plt.xticks(ticks) + if alpha > 3: + from 
matplotlib import ticker + ax.xaxis.set_major_formatter(ticker.LogitFormatter()) + plt.xticks([0, 1 / 2, 1 - 1e-1, 1 - 1e-2, 1 - 1e-3, 1 - 1e-4, 1]) + else: + ax.set_xscale(x_scale) + ax.set_yscale(y_scale) + ax.set_title( + "Recall-Queries per second (1/s) tradeoff - up and to the right is better" + ) + box = plt.gca().get_position() + ax.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(1, 0.5), + prop={'size': 9}) + plt.grid(b=True, which='major', color='0.65', linestyle='-') + plt.setp(ax.get_xminorticklabels(), visible=True) + + ax.spines['bottom']._adjust_location() + + plt.savefig(path + '.png', bbox_inches='tight') + plt.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--path', + help='the path to load result from', + default='glove-100-angular') + parser.add_argument( + '-X', + '--x_scale', + help= + 'Scale to use when drawing the X-axis. Typically linear, logit or a2', + default='a4') + parser.add_argument('-Y', + '--y_scale', + help='Scale to use when drawing the Y-axis', + choices=["linear", "log", "symlog", "logit"], + default='log') + args = parser.parse_args() + + if os.path.exists(os.path.join("results", args.path)): + results = os.listdir(os.path.join("results", args.path)) + else: + raise FileNotFoundError("No auto tune result found") + print("ploting", args.path) + create_plot(args.path, args.x_scale, args.y_scale, results) diff --git a/Tools/nni-auto-tune/preprocessing.py b/Tools/nni-auto-tune/preprocessing.py new file mode 100755 index 000000000..aca8506f9 --- /dev/null +++ b/Tools/nni-auto-tune/preprocessing.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import h5py +import time + +import numpy as np +import os +from model import BruteForceBLAS +from dataset import DataReader, HDF5Reader +import argparse +import json + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--train_file', + help='the data file to load training points from, ' + 'could be text file, binary flie or ann-benchmark format hdf5 file', + default='glove-100-angular.hdf5') + parser.add_argument( + '--query_file', + help='the data file to load query points from, if you use ' + 'ann-benchmark format hdf5 file in train_file, this should be None', + default=None) + parser.add_argument( + '--output_dir', + help= + 'the dir to save sampled train set and pre-calculated ground truth', + default='./sampled') + parser.add_argument( + "--num_sample", + default=1000000, + type=int, + help="the size of sampled train set. -1 means don't do sample") + parser.add_argument("--k", + default=32, + type=int, + help="the number of near neighbours to search for") + parser.add_argument("--distance", + default='angular', + help="the type of distance for searching") + parser.add_argument("--dim", + default=100, + type=int, + help="the dimention of training vectors") + parser.add_argument("--input_type", + default="float32", + help="the data type of input vectors") + parser.add_argument("--data_type", + default="float32", + help="the data type for building and search in SPTAG ") + args = parser.parse_args() + + def tostring(a): + return str(a) + + if args.train_file.endswith(".hdf5"): + # ann-benchmark format hdf5 file got all we want, so args like distance are ignored + data_reader = HDF5Reader(args.train_file, args.data_type) + X_train, X_test = data_reader.readallbatches() + distance = data_reader.distance + dimension = data_reader.featuredim + + else: + X_train = DataReader(args.train_file, + args.dim, + batchsize=-1, + datatype=args.input_type, + 
targettype=args.data_type).readbatch()[1] + X_test = DataReader(args.query_file, + args.dim, + batchsize=-1, + datatype=args.input_type, + targettype=args.data_type).readbatch()[1] + distance = args.distance + dimension = args.dim + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + # -1 means don't do sample + if args.num_sample != -1: + np.random.shuffle(X_train) + X_train = X_train[:args.num_sample] + + print('sampled train set to size (%d)' % (X_train.shape[0])) + prefix = args.train_file.replace(".bin", "") + target_path = os.path.join( + args.output_dir, + os.path.splitext(os.path.split(prefix)[1])[0] + '-sampled-' + + str(args.num_sample) + '.txt') + with open(target_path, 'w') as f: + for i in range(len(X_train)): + f.write('\t') + f.write('|'.join(list(map(tostring, X_train[i])))) + f.write('\n') + + print('write sampled train set to', target_path, "complete") + + label_index = [] + + # we calculate groundtruth index with brute force + bf = BruteForceBLAS(distance) + bf.fit(X_train) + for i, x in enumerate(X_test): + if i % 1000 == 0: + print('%d/%d...' % (i, len(X_test))) + res = list(bf.query_with_distances(x, args.k)) + res.sort(key=lambda t: t[-1]) + label_index.append([n for n, _ in res]) + + print('generated a ground truth index set of size (%d)' % + (len(label_index))) + + with open(os.path.join(args.output_dir, 'ground_truth.txt'), 'w') as f: + for i in range(len(label_index)): + f.write(' '.join(list(map(tostring, label_index[i])))) + f.write('\n') + + print('write ground truth to', + os.path.join(args.output_dir, 'ground_truth.txt'), "complete") + + +if __name__ == '__main__': + main() diff --git a/Tools/nni-auto-tune/runner.py b/Tools/nni-auto-tune/runner.py new file mode 100755 index 000000000..d4650017c --- /dev/null +++ b/Tools/nni-auto-tune/runner.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from model import metrics +import time +import os +import psutil + + +def run_individual_query(algo, + X_train, + X_test, + distance, + k, + run_count=1, + max_mem=-1): + + best_search_time = float('inf') + for i in range(run_count): + print('Run %d/%d...' % (i + 1, run_count)) + # a bit dumb but can't be a scalar since of Python's scoping rules + n_items_processed = [0] + + def single_query(v): + start = time.time() + candidates = algo.query(v, k) + if max_mem > 0 and psutil.Process( + os.getpid()).memory_info().rss > max_mem: + raise MemoryError + total = (time.time() - start) + + candidates = [ + (int(idx), + float(metrics[distance]['distance'](v, X_train[idx]))) # noqa + for idx in candidates + ] + n_items_processed[0] += 1 + if n_items_processed[0] % 1000 == 0: + print('Processed %d/%d queries...' % + (n_items_processed[0], len(X_test))) + if len(candidates) > k: + print('warning: algorithm returned %d results, but k' + ' is only %d)' % (len(candidates), k)) + return (total, candidates) + + results = [single_query(x) for x in X_test] + + total_time = sum(t for t, _ in results) + total_candidates = sum(len(candidates) for _, candidates in results) + search_time = total_time / len(X_test) + avg_candidates = total_candidates / len(X_test) + best_search_time = min(best_search_time, search_time) + + attrs = { + "best_search_time": best_search_time, + "candidates": avg_candidates, + "name": 'BKT', + "run_count": run_count, + "distance": distance, + "count": int(k) + } + + return (attrs, results) diff --git a/Tools/nni-auto-tune/search_space.json b/Tools/nni-auto-tune/search_space.json new file mode 100755 index 000000000..e11158368 --- /dev/null +++ b/Tools/nni-auto-tune/search_space.json @@ -0,0 +1,13 @@ +{ + "BKTKmeansK": {"_type": "quniform", "_value": [4,32,8]}, + "Samples": {"_type": "quniform", "_value": [1000, 10000, 2000]}, + "TPTNumber": {"_type": "quniform", "_value": [32, 192, 16]}, + "RefineIterations": {"_type": "choice", "_value": [2, 3]}, + 
"NeighborhoodSize": {"_type": "quniform", "_value": [16, 192, 8]}, + "CEF": {"_type": "quniform", "_value": [1000, 2000,100]}, + "MaxCheckForRefineGraph": {"_type": "quniform", "_value": [4096, 16324, 1024]}, + "NumberOfInitialDynamicPivots": {"_type": "quniform", "_value": [1, 50, 10]}, + "GraphNeighborhoodScale": {"_type": "choice", "_value": [2, 3, 4]}, + "NumberOfOtherDynamicPivots": {"_type": "quniform", "_value": [1, 10, 2]} +} + diff --git a/Tools/nni-auto-tune/search_space_small.json b/Tools/nni-auto-tune/search_space_small.json new file mode 100755 index 000000000..3029b8e82 --- /dev/null +++ b/Tools/nni-auto-tune/search_space_small.json @@ -0,0 +1,5 @@ +{ + "BKTKmeansK": {"_type": "choice", "_value": [4,8,16,32]}, + "CEF": {"_type": "choice", "_value": [1000, 1100, 1200, 1400, 1800,2000]} +} + diff --git a/Wrappers/CLRCore.vcxproj b/Wrappers/CLRCore.vcxproj index dcbf66bf3..e3f4d85ff 100644 --- a/Wrappers/CLRCore.vcxproj +++ b/Wrappers/CLRCore.vcxproj @@ -25,36 +25,39 @@ ManagedCProj CLRCore 10.0 + net5.0 + true DynamicLibrary true - v142 + v143 true Unicode DynamicLibrary false - v142 + v143 true Unicode DynamicLibrary true - v142 + v143 true MultiByte DynamicLibrary false - v142 + v143 true MultiByte + Spectre @@ -115,10 +118,16 @@ NotUsing true stdcpp17 + /ZH:SHA_256 %(AdditionalOptions) CoreLibrary.lib;%(AdditionalDependencies) + true + /ZH:SHA_256 %(AdditionalOptions) + + sn -k $(SolutionDir)sgKey.snk + @@ -142,27 +151,5 @@ - - - - - - - - - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - - - - - + \ No newline at end of file diff --git a/Wrappers/CMakeLists.txt b/Wrappers/CMakeLists.txt index 7c16ea82e..9a73541a1 100644 --- a/Wrappers/CMakeLists.txt +++ b/Wrappers/CMakeLists.txt @@ -1,34 +1,34 @@ # Copyright (c) Microsoft Corporation. All rights reserved. 
# Licensed under the MIT License. -find_package(Python2 COMPONENTS Development) + +find_package(Python3 COMPONENTS Development) find_package(CUDA) -if (NOT GPU) +if ((NOT GPU) OR (NOT Python3_FOUND)) set (CUDA_FOUND false) endif() if (CUDA_FOUND) message (STATUS "Found cuda.") - set (Python2_FOUND false) else() message (STATUS "Could not find cuda.") endif() -if (Python2_FOUND) - include_directories (${Python2_INCLUDE_DIRS}) - link_directories (${Python2_LIBRARY_DIRS}) - set (Python_INCLUDE_DIRS ${Python2_INCLUDE_DIRS}) - set (Python_LIBRARIES ${Python2_LIBRARIES}) +if (Python3_FOUND) + include_directories (${Python3_INCLUDE_DIRS}) + link_directories (${Python3_LIBRARY_DIRS}) + set (Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS}) + set (Python_LIBRARIES ${Python3_LIBRARIES}) set (Python_FOUND true) else() - find_package(Python3 COMPONENTS Development) - if (Python3_FOUND) - include_directories (${Python3_INCLUDE_DIRS}) - link_directories (${Python3_LIBRARY_DIRS}) - set (Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS}) - set (Python_LIBRARIES ${Python3_LIBRARIES}) + find_package(Python2 COMPONENTS Development) + if (Python2_FOUND) + include_directories (${Python2_INCLUDE_DIRS}) + link_directories (${Python2_LIBRARY_DIRS}) + set (Python_INCLUDE_DIRS ${Python2_INCLUDE_DIRS}) + set (Python_LIBRARIES ${Python2_LIBRARIES}) set (Python_FOUND true) endif() endif() @@ -54,9 +54,9 @@ if (Python_FOUND) add_library (_SPTAG SHARED ${CORE_SRC_FILES} ${CORE_HDR_FILES}) set_target_properties(_SPTAG PROPERTIES PREFIX "" SUFFIX ${PY_SUFFIX}) if (CUDA_FOUND) - target_link_libraries(_SPTAG GPUSPTAGLibStatic ${Python_LIBRARIES}) + target_link_libraries(_SPTAG GPUSPTAGLib ${Python_LIBRARIES}) else() - target_link_libraries(_SPTAG SPTAGLibStatic ${Python_LIBRARIES}) + target_link_libraries(_SPTAG SPTAGLib ${Python_LIBRARIES}) endif() add_custom_command(TARGET _SPTAG POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/Wrappers/inc/SPTAG.py ${EXECUTABLE_OUTPUT_PATH}) @@ -64,7 +64,7 
@@ if (Python_FOUND) file(GLOB CLIENT_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/ClientInterface.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_pwrap.cpp) add_library (_SPTAGClient SHARED ${CLIENT_SRC_FILES} ${CLIENT_HDR_FILES}) set_target_properties(_SPTAGClient PROPERTIES PREFIX "" SUFFIX ${PY_SUFFIX}) - target_link_libraries(_SPTAGClient SPTAGLibStatic ${Python_LIBRARIES} ${Boost_LIBRARIES}) + target_link_libraries(_SPTAGClient SPTAGLib ${Python_LIBRARIES} ${Boost_LIBRARIES}) add_custom_command(TARGET _SPTAGClient POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/Wrappers/inc/SPTAGClient.py ${EXECUTABLE_OUTPUT_PATH}) install(TARGETS _SPTAG _SPTAGClient @@ -92,6 +92,7 @@ if (JNI_FOUND) execute_process(COMMAND swig -java -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_jwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/JavaCore.i) execute_process(COMMAND swig -java -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_jwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/JavaClient.i) + execute_process(COMMAND swig -java -package site.ycsb.fileiointerface -c++ -I${PROJECT_SOURCE_DIR}/Wrappers/inc -o ${PROJECT_SOURCE_DIR}/Wrappers/inc/FileIOInterface_jwrap.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/JavaFileIO.i) include_directories(${JNI_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}/AnnService ${PROJECT_SOURCE_DIR}/Wrappers) @@ -99,13 +100,19 @@ if (JNI_FOUND) file(GLOB CORE_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/CoreInterface.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_jwrap.cpp) add_library (JAVASPTAG SHARED ${CORE_SRC_FILES} ${CORE_HDR_FILES}) set_target_properties(JAVASPTAG PROPERTIES SUFFIX ${JAVA_SUFFIX}) - target_link_libraries(JAVASPTAG SPTAGLibStatic ${JNI_LIBRARIES}) + target_link_libraries(JAVASPTAG SPTAGLib ${JNI_LIBRARIES}) file(GLOB CLIENT_HDR_FILES 
${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h) file(GLOB CLIENT_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/ClientInterface.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_jwrap.cpp) add_library (JAVASPTAGClient SHARED ${CLIENT_SRC_FILES} ${CLIENT_HDR_FILES}) set_target_properties(JAVASPTAGClient PROPERTIES SUFFIX ${JAVA_SUFFIX}) - target_link_libraries(JAVASPTAGClient SPTAGLibStatic ${JNI_LIBRARIES} ${Boost_LIBRARIES}) + target_link_libraries(JAVASPTAGClient SPTAGLib ${JNI_LIBRARIES} ${Boost_LIBRARIES}) + + file(GLOB FILEIO_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/FileIOInterface.h) + file(GLOB FILEIO_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/FileIOInterface_jwrap.cpp) + add_library (JAVASPTAGFileIO SHARED ${FILEIO_SRC_FILES} ${FILEIO_HDR_FILES}) + set_target_properties(JAVASPTAGFileIO PROPERTIES SUFFIX ${JAVA_SUFFIX}) + target_link_libraries(JAVASPTAGFileIO SPTAGLib ${JNI_LIBRARIES}) file(GLOB JAVA_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/*.java) foreach(JAVA_FILE ${JAVA_FILES}) @@ -113,7 +120,7 @@ if (JNI_FOUND) add_custom_command(TARGET JAVASPTAGClient POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${JAVA_FILE} ${EXECUTABLE_OUTPUT_PATH}) endforeach(JAVA_FILE) - install(TARGETS JAVASPTAG JAVASPTAGClient + install(TARGETS JAVASPTAG JAVASPTAGClient JAVASPTAGFileIO RUNTIME DESTINATION bin ARCHIVE DESTINATION lib LIBRARY DESTINATION lib) @@ -164,13 +171,13 @@ if (DOTNET_FOUND) file(GLOB CORE_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/CoreInterface.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/CoreInterface_cwrap.cpp) add_library (CSHARPSPTAG SHARED ${CORE_SRC_FILES} ${CORE_HDR_FILES}) set_target_properties(CSHARPSPTAG PROPERTIES SUFFIX ${CSHARP_SUFFIX}) - target_link_libraries(CSHARPSPTAG SPTAGLibStatic) + target_link_libraries(CSHARPSPTAG SPTAGLib) file(GLOB 
CLIENT_HDR_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h) file(GLOB CLIENT_SRC_FILES ${PROJECT_SOURCE_DIR}/Wrappers/src/ClientInterface.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/Wrappers/inc/ClientInterface_cwrap.cpp) add_library (CSHARPSPTAGClient SHARED ${CLIENT_SRC_FILES} ${CLIENT_HDR_FILES}) set_target_properties(CSHARPSPTAGClient PROPERTIES SUFFIX ${CSHARP_SUFFIX}) - target_link_libraries(CSHARPSPTAGClient SPTAGLibStatic ${Boost_LIBRARIES}) + target_link_libraries(CSHARPSPTAGClient SPTAGLib ${Boost_LIBRARIES}) file(GLOB CSHARP_FILES ${PROJECT_SOURCE_DIR}/Wrappers/inc/*.cs) foreach(CSHARP_FILE ${CSHARP_FILES}) diff --git a/Wrappers/CsharpClient.vcxproj b/Wrappers/CsharpClient.vcxproj index 1e48b766f..2953c2c71 100644 --- a/Wrappers/CsharpClient.vcxproj +++ b/Wrappers/CsharpClient.vcxproj @@ -29,31 +29,33 @@ DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte + Spectre + @@ -87,60 +89,21 @@ + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) %(AdditionalIncludeDirectories) - - - - - Level3 - MaxSpeed - true - true - true - true - _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) - stdcpp17 - - - true - true - - - - - Level3 - Disabled - true - true - - - - - Level3 - Disabled - true - true - _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) - Guard - ProgramDatabase - stdcpp17 + Guard + ProgramDatabase + _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + stdcpp17 + stdcpp17 + Level3 + Guard - /guard:cf %(AdditionalOptions) + /guard:cf %(AdditionalOptions) - - - - Level3 - MaxSpeed - true - true - true - true - - true - true + true @@ -165,15 +128,18 @@ - - - - - + + + + + + + + - + @@ -185,11 +151,14 @@ 
This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - - - + + + + + + + + + \ No newline at end of file diff --git a/Wrappers/CsharpCore.vcxproj b/Wrappers/CsharpCore.vcxproj index eea069631..6b4ec384c 100644 --- a/Wrappers/CsharpCore.vcxproj +++ b/Wrappers/CsharpCore.vcxproj @@ -29,33 +29,32 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte + Spectre - - - + @@ -95,10 +94,15 @@ _WINDLL;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) stdcpp17 stdcpp17 + Level3 + Guard /guard:cf %(AdditionalOptions) + + true + @@ -120,7 +124,7 @@ - + @@ -128,11 +132,4 @@ - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - \ No newline at end of file diff --git a/Wrappers/JavaClient.vcxproj b/Wrappers/JavaClient.vcxproj index a22a88b6e..b4861d39b 100644 --- a/Wrappers/JavaClient.vcxproj +++ b/Wrappers/JavaClient.vcxproj @@ -29,26 +29,26 @@ DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte @@ -165,14 +165,18 @@ - - - - + + + + + + + + - + @@ -184,10 +188,14 @@ This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
- - - - - + + + + + + + + + \ No newline at end of file diff --git a/Wrappers/JavaCore.vcxproj b/Wrappers/JavaCore.vcxproj index a0ed98069..325aeaa25 100644 --- a/Wrappers/JavaCore.vcxproj +++ b/Wrappers/JavaCore.vcxproj @@ -29,26 +29,26 @@ Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte @@ -115,11 +115,10 @@ - - + - + @@ -127,10 +126,4 @@ - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - \ No newline at end of file diff --git a/Wrappers/PythonClient.vcxproj b/Wrappers/PythonClient.vcxproj index 25f5b7293..9160772cb 100644 --- a/Wrappers/PythonClient.vcxproj +++ b/Wrappers/PythonClient.vcxproj @@ -1,6 +1,7 @@  - + + Debug @@ -26,30 +27,29 @@ 10.0 - DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte @@ -162,15 +162,18 @@ - - - - - + + + + + + + + - + @@ -179,12 +182,15 @@ This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
- - - - - - - + + + + + + + + + + \ No newline at end of file diff --git a/Wrappers/PythonCore.vcxproj b/Wrappers/PythonCore.vcxproj index db9b11845..53c7cafbb 100644 --- a/Wrappers/PythonCore.vcxproj +++ b/Wrappers/PythonCore.vcxproj @@ -1,6 +1,7 @@  - + + Debug @@ -26,30 +27,29 @@ 10.0 - Application true - v142 + v143 MultiByte Application false - v142 + v143 true MultiByte DynamicLibrary true - v142 + v143 MultiByte DynamicLibrary false - v142 + v143 true MultiByte @@ -117,10 +117,9 @@ - - + @@ -129,8 +128,7 @@ This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - + + \ No newline at end of file diff --git a/Wrappers/WinRT/AnnIndex.cpp b/Wrappers/WinRT/AnnIndex.cpp new file mode 100644 index 000000000..3ba0354dc --- /dev/null +++ b/Wrappers/WinRT/AnnIndex.cpp @@ -0,0 +1,97 @@ +#include "AnnIndex.h" +#include "pch.h" +#if __has_include("AnnIndex.g.cpp") +#include "AnnIndex.g.cpp" +#endif +#include "SearchResult.g.cpp" + +namespace winrt::SPTAG::implementation +{ +template >> +sptag::ByteArray GetByteArray(const winrt::array_view &data) +{ + auto copy = new T[data.size()]; + for (auto i = 0u; i < data.size(); i++) + copy[i] = data[i]; + const auto byteSize = data.size() * sizeof(data.at(0)) / sizeof(byte); + auto byteArray = sptag::ByteArray(reinterpret_cast(copy), byteSize, true); + return byteArray; +} + +SPTAG::SearchResult AnnIndex::GetResultFromMetadata(const sptag::BasicResult &r) const +{ + return winrt::make(r.Meta.Data(), r.Meta.Length(), r.Dist); +} + +winrt::Windows::Foundation::Collections::IVector AnnIndex::Search(EmbeddingVector p_data, + uint32_t p_resultNum) const +{ + auto vec = std::vector{}; + p_resultNum = (std::min)(static_cast(p_resultNum), m_index->GetNumSamples()); + auto results = std::make_shared(p_data.data(), p_resultNum, true); + + if (nullptr != m_index) + { + 
m_index->SearchIndex(*results); + } + for (const auto &r : *results) + { + auto sr = GetResultFromMetadata(r); + vec.push_back(sr); + } + return winrt::single_threaded_vector(std::move(vec)); +} + +void AnnIndex::Load(winrt::Windows::Storage::StorageFile file) +{ + auto path = file.Path(); + if (sptag::ErrorCode::Success != sptag::VectorIndex::LoadIndexFromFile(winrt::to_string(path), m_index)) + { + throw winrt::hresult_error{}; + } +} + +void AnnIndex::Save(winrt::Windows::Storage::StorageFile file) +{ + auto path = file.Path(); + if (sptag::ErrorCode::Success != m_index->SaveIndexToFile(winrt::to_string(path))) + { + throw winrt::hresult_error{}; + } +} + +template void AnnIndex::_AddWithMetadataImpl(EmbeddingVector p_data, T metadata) +{ + if (m_dimension == 0) + { + m_dimension = p_data.size(); + } + else if (m_dimension != static_cast(p_data.size())) + { + throw winrt::hresult_invalid_argument{}; + } + int p_num{1}; + auto byteArray = GetByteArray(p_data); + std::shared_ptr vectors(new sptag::BasicVectorSet(byteArray, m_inputValueType, + static_cast(m_dimension), + static_cast(p_num))); + + sptag::ByteArray p_meta = GetByteArray(metadata); + bool p_withMetaIndex{true}; + bool p_normalized{true}; + + std::uint64_t *offsets = new std::uint64_t[p_num + 1]{0, metadata.size() * sizeof(metadata[0])}; + sptag::ByteArray offsetsArray(reinterpret_cast(offsets), p_num + 1, true); + std::shared_ptr meta(new sptag::MemMetadataSet(p_meta, offsetsArray, p_num)); + + if (sptag::ErrorCode::Success != m_index->AddIndex(vectors, meta, p_withMetaIndex, p_normalized)) + { + throw winrt::hresult_error(E_UNEXPECTED); + } +} + +void AnnIndex::AddWithMetadata(array_view data, array_view metadata) +{ + _AddWithMetadataImpl(data, metadata); +} +} // namespace winrt::SPTAG::implementation diff --git a/Wrappers/WinRT/AnnIndex.h b/Wrappers/WinRT/AnnIndex.h new file mode 100644 index 000000000..7e728e667 --- /dev/null +++ b/Wrappers/WinRT/AnnIndex.h @@ -0,0 +1,71 @@ +#pragma once + 
+#include "AnnIndex.g.h" +#include "SearchResult.g.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Core/SearchQuery.h" +#include +#include + +#include + +template +struct ReadOnlyProperty { + ReadOnlyProperty() = default; + ReadOnlyProperty(const T& value) : m_value(value) {} + T operator()() const noexcept { return m_value; } + T m_value{}; +}; + +template +struct Property : ReadOnlyProperty { + void operator()(const T& value) { m_value = value; } +}; + +namespace sptag = ::SPTAG; +namespace winrt::SPTAG::implementation +{ + using EmbeddingVector = winrt::array_view; + + struct SearchResult : SearchResultT { + winrt::com_array Metadata() { return winrt::com_array(m_metadata); } + std::vector m_metadata; + ReadOnlyProperty Distance; + + SearchResult() = default; + SearchResult(winrt::array_view metadata, float d) : m_metadata(metadata.begin(), metadata.end()), Distance(d) {} + SearchResult(uint8_t* metadata, size_t length, float d) : m_metadata(metadata, metadata + length), Distance(d) {} + }; + + + struct AnnIndex : AnnIndexT + { + sptag::DimensionType m_dimension{ }; + sptag::VectorValueType m_inputValueType{ sptag::VectorValueType::Float }; + + + AnnIndex() { + sptag::SetLogger(std::make_shared(sptag::Helper::LogLevel::LL_Empty)); + m_index = sptag::VectorIndex::CreateInstance(sptag::IndexAlgoType::BKT, sptag::GetEnumValueType()); + } + + void AddWithMetadata(array_view data, array_view metadata); + + void Save(winrt::Windows::Storage::StorageFile file); + void Load(winrt::Windows::Storage::StorageFile file); + + SPTAG::SearchResult GetResultFromMetadata(const sptag::BasicResult& r) const; + + winrt::Windows::Foundation::Collections::IVector Search(EmbeddingVector p_data, uint32_t p_resultNum) const; + + std::shared_ptr m_index; + template + void _AddWithMetadataImpl(EmbeddingVector p_data, T metadata); + }; + +} + +namespace winrt::SPTAG::factory_implementation +{ + struct AnnIndex : AnnIndexT {}; +} diff --git a/Wrappers/WinRT/AnnIndex.idl 
b/Wrappers/WinRT/AnnIndex.idl new file mode 100644 index 000000000..9bb3a9428 --- /dev/null +++ b/Wrappers/WinRT/AnnIndex.idl @@ -0,0 +1,34 @@ +namespace SPTAG +{ + + enum LogLevel + { + Debug = 0, + Info, + Status, + Warning, + Error, + Assert, + Count, + Empty + }; + + [default_interface] + runtimeclass SearchResult + { + UInt8[] Metadata {get; }; + Single Distance{ get; }; + }; + + [default_interface] + runtimeclass AnnIndex + { + AnnIndex(); + [default_overload] + void AddWithMetadata(Single[] data, UInt8[] metadata); + + void Save(Windows.Storage.StorageFile file); + void Load(Windows.Storage.StorageFile file); + Windows.Foundation.Collections.IVector Search(Single[] vector, UInt32 neighborCount); + } +} diff --git a/Wrappers/WinRT/PropertySheet.props b/Wrappers/WinRT/PropertySheet.props new file mode 100644 index 000000000..e34141b01 --- /dev/null +++ b/Wrappers/WinRT/PropertySheet.props @@ -0,0 +1,16 @@ + + + + + + + + \ No newline at end of file diff --git a/Wrappers/WinRT/SPTAG.WinRT.targets b/Wrappers/WinRT/SPTAG.WinRT.targets new file mode 100644 index 000000000..904f02408 --- /dev/null +++ b/Wrappers/WinRT/SPTAG.WinRT.targets @@ -0,0 +1,19 @@ + + + + + x86 + $(Platform) + <_nugetNativeFolder>$(MSBuildThisFileDirectory)..\..\runtimes\win10-$(Native-Platform)\native\ + + + + + SPTAG.dll + + + + + + + \ No newline at end of file diff --git a/Wrappers/WinRT/SPTAG.WinRT.vcxproj b/Wrappers/WinRT/SPTAG.WinRT.vcxproj new file mode 100644 index 000000000..72599ed89 --- /dev/null +++ b/Wrappers/WinRT/SPTAG.WinRT.vcxproj @@ -0,0 +1,162 @@ + + + + + true + true + true + true + {8dc74c33-6e15-43ed-9300-2a140589e3da} + SPTAG + SPTAG + en-US + 14.0 + true + Windows Store + 10.0 + 10.0 + 10.0.17134.0 + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + DynamicLibrary + v143 + v143 + v142 + v141 + v140 + Unicode + false + + + true + true + + + false + true + false + + + true + + + true + + + true + + + true + + + + + + + + + + + + 
+ + + + + ../../AnnService/;$(IncludePath) + + + + Use + pch.h + $(IntDir)pch.pch + Level4 + %(AdditionalOptions) /bigobj + _WINRT_DLL;WIN32_LEAN_AND_MEAN;WINRT_LEAN_AND_MEAN;%(PreprocessorDefinitions) + $(WindowsSDK_WindowsMetadata);$(AdditionalUsingDirectories) + + + Console + false + SPTAG.def + + + + + _DEBUG;%(PreprocessorDefinitions) + + + + + NDEBUG;%(PreprocessorDefinitions) + + + true + true + + + + + AnnIndex.idl + Code + + + + + + AnnIndex.idl + Code + + + Create + + + + + + + + + + + false + + + + + true + + + + + Designer + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + \ No newline at end of file diff --git a/Wrappers/WinRT/SPTAG.def b/Wrappers/WinRT/SPTAG.def new file mode 100644 index 000000000..24e7c1235 --- /dev/null +++ b/Wrappers/WinRT/SPTAG.def @@ -0,0 +1,3 @@ +EXPORTS +DllCanUnloadNow = WINRT_CanUnloadNow PRIVATE +DllGetActivationFactory = WINRT_GetActivationFactory PRIVATE diff --git a/Wrappers/WinRT/packages.config b/Wrappers/WinRT/packages.config new file mode 100644 index 000000000..bbda0712b --- /dev/null +++ b/Wrappers/WinRT/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/Wrappers/WinRT/pch.cpp b/Wrappers/WinRT/pch.cpp new file mode 100644 index 000000000..bcb5590be --- /dev/null +++ b/Wrappers/WinRT/pch.cpp @@ -0,0 +1 @@ +#include "pch.h" diff --git a/Wrappers/WinRT/pch.h b/Wrappers/WinRT/pch.h new file mode 100644 index 000000000..21199686d --- /dev/null +++ b/Wrappers/WinRT/pch.h @@ -0,0 +1,4 @@ +#pragma once +#include +#include +#include diff --git a/Wrappers/inc/CLRCoreInterface.h b/Wrappers/inc/CLRCoreInterface.h index 3f5ec36f5..3596a93ae 100644 --- a/Wrappers/inc/CLRCoreInterface.h +++ b/Wrappers/inc/CLRCoreInterface.h @@ -4,6 +4,8 @@ #pragma once #include "ManagedObject.h" #include "inc/Core/VectorIndex.h" 
+#include "inc/Core/ResultIterator.h" +#include "inc/Core/MultiIndexScan.h" using namespace System; @@ -62,6 +64,29 @@ namespace Microsoft { } } + + property bool RelaxedMono + { + public: + bool get() + { + return m_Instance->RelaxedMono; + } + private: + void set(bool p_relaxedMono) + { + } + } + }; + + public ref class RIterator : + public ManagedObject> + { + public: + RIterator(std::shared_ptr result_iterator); + array^ Next(int p_batch); + bool GetRelaxedMono(); + void Close(); }; public ref class AnnIndex : @@ -76,6 +101,18 @@ namespace Microsoft void SetSearchParam(String^ p_name, String^ p_value, String^ p_section); + bool LoadQuantizer(String^ p_quantizerFile); + + void SetQuantizerADC(bool p_adc); + + array^ QuantizeVector(array^ p_data, int p_num); + + array^ ReconstructVector(array^ p_data, int p_num); + + bool BuildSPANN(bool p_normalized); + + bool BuildSPANNWithMetaData(array^ p_meta, int p_num, bool p_withMetaIndex, bool p_normalized); + bool Build(array^ p_data, int p_num); bool BuildWithMetaData(array^ p_data, array^ p_meta, int p_num, bool p_withMetaIndex); @@ -88,6 +125,10 @@ namespace Microsoft array^ SearchWithMetaData(array^ p_data, int p_resultNum); + RIterator^ GetIterator(array^ p_data); + + void UpdateIndex(); + bool Save(String^ p_saveFile); array^>^ Dump(); @@ -116,6 +157,18 @@ namespace Microsoft size_t m_inputVectorSize; }; + + public ref class MultiIndexScan : + public ManagedObject> + { + public: + MultiIndexScan(std::shared_ptr multi_index_scan); + MultiIndexScan(array^ indice, array^>^ p_data, array^ weight, int p_resultNum, + bool useTimer, int termCondVal, int searchLimit); + BasicResult^ Next(); + void Close(); + }; + } } } diff --git a/Wrappers/inc/CoreInterface.h b/Wrappers/inc/CoreInterface.h index b855a46cb..861e0b5ea 100644 --- a/Wrappers/inc/CoreInterface.h +++ b/Wrappers/inc/CoreInterface.h @@ -28,6 +28,10 @@ class AnnIndex void SetQuantizerADC(bool p_adc); + ByteArray QuantizeVector(ByteArray p_data, int p_num); + + 
ByteArray ReconstructVector(ByteArray p_data, int p_num); + bool BuildSPANN(bool p_normalized); bool BuildSPANNWithMetaData(ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, bool p_normalized); @@ -36,6 +40,8 @@ class AnnIndex bool BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, bool p_normalized); + std::shared_ptr GetIterator(ByteArray p_target); + std::shared_ptr Search(ByteArray p_data, int p_resultNum); std::shared_ptr SearchWithMetaData(ByteArray p_data, int p_resultNum); @@ -56,6 +62,12 @@ class AnnIndex bool DeleteByMetaData(ByteArray p_meta); + uint64_t CalculateBufferSize(); + + ByteArray Dump(ByteArray p_blobs); + + static AnnIndex LoadFromDump(ByteArray p_config, ByteArray p_blobs); + static AnnIndex Load(const char* p_loaderFile); static AnnIndex Merge(const char* p_indexFilePath1, const char* p_indexFilePath2); diff --git a/Wrappers/inc/CsharpCommon.i b/Wrappers/inc/CsharpCommon.i index 20062d971..7102cdd5c 100644 --- a/Wrappers/inc/CsharpCommon.i +++ b/Wrappers/inc/CsharpCommon.i @@ -10,6 +10,10 @@ void deleteArrayOfWrapperArray(void* ptr) { delete[] (WrapperArray*)ptr; } + + void deleteWrapperArray(void* ptr) { + delete[] (std::uint8_t*)ptr; + } %} %pragma(csharp) imclasscode=%{ @@ -24,6 +28,7 @@ %apply void *VOID_INT_PTR { void * } void deleteArrayOfWrapperArray(void* ptr); +void deleteWrapperArray(void* ptr); %typemap(ctype) ByteArray "WrapperArray" %typemap(imtype) ByteArray "WrapperArray" @@ -31,10 +36,7 @@ void deleteArrayOfWrapperArray(void* ptr); %typemap(in) ByteArray { $1.Set((std::uint8_t*)$input._data, $input._size, false); } -%typemap(out) ByteArray { - $result._data = $1.Data(); - $result._size = $1.Length(); -} + %typemap(csin, pre="unsafe { fixed(byte* ptr$csinput = $csinput) { $modulePINVOKE.WrapperArray temp$csinput = new $modulePINVOKE.WrapperArray( (System.IntPtr)ptr$csinput, (ulong)$csinput.LongLength );", terminator="} }" @@ -51,10 +53,16 @@ void deleteArrayOfWrapperArray(void* 
ptr); } %} +%typemap(out) ByteArray { + $result._size = $1.Length(); + $result._data = $1.Data(); +} + %typemap(csout, excode=SWIGEXCODE) ByteArray %{ $modulePINVOKE.WrapperArray data = $imcall;$excode byte[] ret = new byte[data._size]; System.Runtime.InteropServices.Marshal.Copy(data._data, ret, 0, (int)data._size); + if (data._size > 0) $modulePINVOKE.deleteWrapperArray(data._data); return ret; %} @@ -63,6 +71,7 @@ void deleteArrayOfWrapperArray(void* ptr); $modulePINVOKE.WrapperArray data = $imcall; byte[] ret = new byte[data._size]; System.Runtime.InteropServices.Marshal.Copy(data._data, ret, 0, (int)data._size); + if (data._size > 0) $modulePINVOKE.deleteWrapperArray(data._data); return ret; } %} diff --git a/Wrappers/inc/CsharpCore.i b/Wrappers/inc/CsharpCore.i index 6434239b9..3608577e3 100644 --- a/Wrappers/inc/CsharpCore.i +++ b/Wrappers/inc/CsharpCore.i @@ -2,11 +2,14 @@ %{ #include "inc/CoreInterface.h" +#include "inc/Core/ResultIterator.h" %} %include +%include %shared_ptr(AnnIndex) %shared_ptr(QueryResult) +%shared_ptr(ResultIterator) %include "CsharpCommon.i" %{ @@ -15,3 +18,4 @@ %include "CoreInterface.h" %include "../../AnnService/inc/Core/SearchResult.h" +%include "../../AnnService/inc/Core/ResultIterator.h" \ No newline at end of file diff --git a/Wrappers/inc/FileIOInterface.h b/Wrappers/inc/FileIOInterface.h new file mode 100644 index 000000000..7d8f1666c --- /dev/null +++ b/Wrappers/inc/FileIOInterface.h @@ -0,0 +1,115 @@ +#ifndef _SPTAG_PW_FILEIOINTERFACE_H_ +#define _SPTAG_PW_FILEIOINTERFACE_H_ + +#include "inc/Core/SPANN/Options.h" +#include "inc/Core/SPANN/IExtraSearcher.h" +#include "inc/Core/SPANN/ExtraFileController.h" + +typedef std::int64_t AddressType; +typedef std::int32_t SizeType; +typedef SPTAG::ErrorCode ErrorCode; +typedef SPTAG::ByteArray ByteArray; +class FileIOInterface { +public: + FileIOInterface(const char* filePath, SizeType postingBlocks, SizeType bufferSize = 1024, int batchSize = 64) { + SPTAG::SPANN::Options opt; + 
std::string tmp(filePath); + opt.m_indexDirectory = tmp.substr(0, tmp.find_last_of(FolderSep)); + opt.m_ssdMappingFile = tmp.substr(tmp.find_last_of(FolderSep) + 1); + opt.m_startFileSize = (postingBlocks >> (30 - SPTAG::PageSizeEx)); + opt.m_bufferLength = bufferSize; + opt.m_spdkBatchSize = batchSize; + + m_fileIO = std::make_unique(opt); + m_workSpace = std::make_unique(); + m_workSpace->Initialize(opt.m_maxCheck, opt.m_hashExp, opt.m_searchInternalResultNum, max(opt.m_postingPageLimit, opt.m_searchPostingPageLimit + 1) << SPTAG::PageSizeEx, true, opt.m_enableDataCompression); + } + + void ShutDown() { + m_fileIO->ShutDown(); + } + + bool Get(SizeType key, std::string& value) { + ErrorCode result = m_fileIO->Get(key, &value, SPTAG::MaxTimeout, &(m_workSpace->m_diskRequests)); + if (result != ErrorCode::Success) { + value = ""; + } + return true; + } + + std::string Get(SizeType key) { + std::string value; + ErrorCode result = m_fileIO->Get(key, &value, SPTAG::MaxTimeout, &(m_workSpace->m_diskRequests)); + if (result != ErrorCode::Success) { + return ""; + } + return value; + } + + std::string Get(const std::string& key) { + std::string value; + ErrorCode result = m_fileIO->Get(key, &value, SPTAG::MaxTimeout, &(m_workSpace->m_diskRequests)); + if (result != ErrorCode::Success) { + return ""; + } + return value; + } + + bool MultiGet(const std::vector& keys, std::vector* values, const std::chrono::microseconds &timeout = std::chrono::microseconds::max()) { + return m_fileIO->MultiGet(keys, values, timeout, &(m_workSpace->m_diskRequests)) == ErrorCode::Success; + } + + bool MultiGet(const std::vector& keys, std::vector* values, const std::chrono::microseconds &timeout = std::chrono::microseconds::max()) { + return m_fileIO->MultiGet(keys, values, timeout, &(m_workSpace->m_diskRequests)) == ErrorCode::Success; + } + + bool Put(SizeType key, const std::string& value) { + return m_fileIO->Put(key, value, SPTAG::MaxTimeout, &(m_workSpace->m_diskRequests)) == 
ErrorCode::Success; + } + + bool Put(const std::string& key, const std::string& value) { + return m_fileIO->Put(key, value, SPTAG::MaxTimeout, &(m_workSpace->m_diskRequests)) == ErrorCode::Success; + } + + bool Merge(SizeType key, const std::string& value) { + return m_fileIO->Merge(key, value, SPTAG::MaxTimeout, &(m_workSpace->m_diskRequests)) == ErrorCode::Success; + } + + bool Merge(const std::string& key, const std::string& value) { + return m_fileIO->Merge(key, value, SPTAG::MaxTimeout, &(m_workSpace->m_diskRequests)) == ErrorCode::Success; + } + + bool Delete(SizeType key) { + return m_fileIO->Delete(key) == ErrorCode::Success; + } + + bool Delete(const std::string& key) { + return m_fileIO->Delete(key) == ErrorCode::Success; + } + + void ForceCompaction() { + m_fileIO->ForceCompaction(); + } + + void GetStat() { + m_fileIO->GetStat(); + } + + bool Load(std::string path, SizeType blockSize, SizeType capacity) { + return m_fileIO->Load(path, blockSize, capacity) == ErrorCode::Success; + } + + bool Save(std::string path) { + return m_fileIO->Save(path) == ErrorCode::Success; + } + + bool Checkpoint(std::string prefix) { + return m_fileIO->Checkpoint(prefix) == ErrorCode::Success; + } + +private: +std::unique_ptr m_fileIO; +std::unique_ptr m_workSpace; +}; + +#endif // _SPTAG_PW_FILEIOINTERFACE_H_ \ No newline at end of file diff --git a/Wrappers/inc/JavaCommon.i b/Wrappers/inc/JavaCommon.i index 366052d4f..9688dd682 100644 --- a/Wrappers/inc/JavaCommon.i +++ b/Wrappers/inc/JavaCommon.i @@ -10,6 +10,7 @@ %typemap(out) ByteArray { $result = JCALL1(NewByteArray, jenv, $1.Length()); JCALL4(SetByteArrayRegion, jenv, $result, 0, $1.Length(), (jbyte *)$1.Data()); + if ($1.Length() > 0) delete[] $1.Data(); } %typemap(javain) ByteArray "$javainput" %typemap(javaout) ByteArray { return $jnicall; } @@ -25,7 +26,7 @@ auto& meta = $1->GetMetadata(i); jbyteArray bptr = jenv->NewByteArray(meta.Length()); jenv->SetByteArrayRegion(bptr, 0, meta.Length(), (jbyte *)meta.Data()); 
- jenv->SetObjectArrayElement(jresult, i, jenv->NewObject(retClass, jenv->GetMethodID(retClass, "", "(IF[B)V"), (jint)($1->GetResult(i)->VID), (jfloat)($1->GetResult(i)->Dist), bptr)); + jenv->SetObjectArrayElement(jresult, i, jenv->NewObject(retClass, jenv->GetMethodID(retClass, "", "(IF[BZ)V"), (jint)($1->GetResult(i)->VID), (jfloat)($1->GetResult(i)->Dist), bptr, (jboolean)($1->GetResult(i)->RelaxedMono))); } } %typemap(javaout) std::shared_ptr { return $jnicall; } @@ -48,7 +49,7 @@ auto& meta = ptr.GetMetadata(j); jbyteArray bptr = jenv->NewByteArray(meta.Length()); jenv->SetByteArrayRegion(bptr, 0, meta.Length(), (jbyte *)meta.Data()); - jenv->SetObjectArrayElement(jresult, id, jenv->NewObject(retClass, jenv->GetMethodID(retClass, "", "(IF[B)V"), (jint)(ptr.GetResult(j)->VID), (jfloat)(ptr.GetResult(j)->Dist), bptr)); + jenv->SetObjectArrayElement(jresult, id, jenv->NewObject(retClass, jenv->GetMethodID(retClass, "", "(IF[BZ)V"), (jint)(ptr.GetResult(j)->VID), (jfloat)(ptr.GetResult(j)->Dist), bptr, (jboolean)(ptr.GetResult(j)->RelaxedMono))); id++; } } diff --git a/Wrappers/inc/JavaCore.i b/Wrappers/inc/JavaCore.i index 78d9dd72e..685f7d275 100644 --- a/Wrappers/inc/JavaCore.i +++ b/Wrappers/inc/JavaCore.i @@ -2,11 +2,14 @@ %{ #include "inc/CoreInterface.h" +#include "inc/Core/ResultIterator.h" %} %include +%include %shared_ptr(AnnIndex) %shared_ptr(QueryResult) +%shared_ptr(ResultIterator) %include "JavaCommon.i" %{ @@ -15,3 +18,4 @@ %include "CoreInterface.h" %include "../../AnnService/inc/Core/SearchResult.h" +%include "../../AnnService/inc/Core/ResultIterator.h" diff --git a/Wrappers/inc/JavaFileIO.i b/Wrappers/inc/JavaFileIO.i new file mode 100644 index 000000000..333abf241 --- /dev/null +++ b/Wrappers/inc/JavaFileIO.i @@ -0,0 +1,16 @@ +%module JAVAFileIO + +%{ +#include "inc/FileIOInterface.h" +%} + +%include +%include +%include "std_string.i" +%include "std_vector.i" +// %shared_ptr(FileIOInterface) +%include "JavaCommon.i" + +%include 
"../../AnnService/inc/Helper/KeyValueIO.h" +%include "../../AnnService/inc/Core/SPANN/ExtraFileController.h" +%include "FileIOInterface.h" \ No newline at end of file diff --git a/Wrappers/inc/PythonCommon.i b/Wrappers/inc/PythonCommon.i index 7b10d50b9..c27b3e5fd 100644 --- a/Wrappers/inc/PythonCommon.i +++ b/Wrappers/inc/PythonCommon.i @@ -1,18 +1,29 @@ #ifdef SWIGPYTHON +%typemap(out) ByteArray +%{ + { + $result = PyBytes_FromStringAndSize(reinterpret_cast($1.Data()), $1.Length()); + if ($1.Length() > 0) delete[] $1.Data(); + } +%} + %typemap(out) std::shared_ptr %{ { - $result = PyTuple_New(3); + $result = PyTuple_New(4); int resNum = $1->GetResultNum(); auto dstVecIDs = PyList_New(resNum); auto dstVecDists = PyList_New(resNum); auto dstMetadata = PyList_New(resNum); + auto dstRelaxMono = PyList_New(resNum); int i = 0; for (const auto& res : *($1)) { PyList_SetItem(dstVecIDs, i, PyInt_FromLong(res.VID)); PyList_SetItem(dstVecDists, i, PyFloat_FromDouble(res.Dist)); + if (res.RelaxedMono) PyList_SetItem(dstRelaxMono, i, Py_True); + else PyList_SetItem(dstRelaxMono, i, Py_False); i++; } @@ -29,22 +40,26 @@ PyTuple_SetItem($result, 0, dstVecIDs); PyTuple_SetItem($result, 1, dstVecDists); PyTuple_SetItem($result, 2, dstMetadata); + PyTuple_SetItem($result, 3, dstRelaxMono); } %} %typemap(out) std::shared_ptr %{ { - $result = PyTuple_New(3); + $result = PyTuple_New(4); auto dstVecIDs = PyList_New(0); auto dstVecDists = PyList_New(0); auto dstMetadata = PyList_New(0); + auto dstRelaxMono = PyList_New(0); for (const auto& indexRes : $1->m_allIndexResults) { for (const auto& res : indexRes.m_results) { PyList_Append(dstVecIDs, PyInt_FromLong(res.VID)); PyList_Append(dstVecDists, PyFloat_FromDouble(res.Dist)); + if (res.RelaxedMono) PyList_Append(dstRelaxMono, Py_True); + else PyList_Append(dstRelaxMono, Py_False); } if (indexRes.m_results.WithMeta()) @@ -60,6 +75,7 @@ PyTuple_SetItem($result, 0, dstVecIDs); PyTuple_SetItem($result, 1, dstVecDists); 
PyTuple_SetItem($result, 2, dstMetadata); + PyTuple_SetItem($result, 3, dstRelaxMono); } %} diff --git a/Wrappers/inc/PythonCore.i b/Wrappers/inc/PythonCore.i index d2f38ca85..68c2acacd 100644 --- a/Wrappers/inc/PythonCore.i +++ b/Wrappers/inc/PythonCore.i @@ -2,15 +2,19 @@ %{ #include "inc/CoreInterface.h" +#include "inc/Core/ResultIterator.h" %} %include +%include %shared_ptr(AnnIndex) %shared_ptr(QueryResult) +%shared_ptr(ResultIterator) %include "PythonCommon.i" %{ #define SWIG_FILE_WITH_INIT %} -%include "CoreInterface.h" \ No newline at end of file +%include "CoreInterface.h" +%include "../../AnnService/inc/Core/ResultIterator.h" \ No newline at end of file diff --git a/Wrappers/inc/TransferDataType.h b/Wrappers/inc/TransferDataType.h index 51ef9614a..0bb42404c 100644 --- a/Wrappers/inc/TransferDataType.h +++ b/Wrappers/inc/TransferDataType.h @@ -6,6 +6,7 @@ #include "inc/Core/CommonDataStructure.h" #include "inc/Core/SearchQuery.h" +#include "inc/Core/ResultIterator.h" #include "inc/Socket/RemoteSearchQuery.h" typedef SPTAG::ByteArray ByteArray; diff --git a/Wrappers/packages.config b/Wrappers/packages.config index d4000416c..0d81e9722 100644 --- a/Wrappers/packages.config +++ b/Wrappers/packages.config @@ -1,13 +1,13 @@  - - - - - - - - - - + + + + + + + + + + \ No newline at end of file diff --git a/Wrappers/src/CLRCoreInterface.cpp b/Wrappers/src/CLRCoreInterface.cpp index 883799a43..1ec9276dd 100644 --- a/Wrappers/src/CLRCoreInterface.cpp +++ b/Wrappers/src/CLRCoreInterface.cpp @@ -3,233 +3,423 @@ #include "inc/CLRCoreInterface.h" - namespace Microsoft { - namespace ANN +namespace ANN +{ +namespace SPTAGManaged +{ +RIterator::RIterator(std::shared_ptr result_iterator) : ManagedObject(result_iterator) +{ +} + +array ^ RIterator::Next(int p_batch) +{ + array ^ res; + if (m_Instance == nullptr) + return res; + + std::shared_ptr results = (*m_Instance)->Next(p_batch); + + res = gcnew array(results->GetResultNum()); + for (int i = 0; i < 
results->GetResultNum(); i++) + res[i] = gcnew BasicResult(new SPTAG::BasicResult(*(results->GetResult(i)))); + + return res; +} + +bool RIterator::GetRelaxedMono() +{ + return (*m_Instance)->GetRelaxedMono(); +} + +void RIterator::Close() +{ + (*m_Instance)->Close(); +} + +AnnIndex::AnnIndex(std::shared_ptr p_index) : ManagedObject(p_index) +{ + m_dimension = p_index->GetFeatureDim(); + m_inputVectorSize = SPTAG::GetValueTypeSize(p_index->GetVectorValueType()) * m_dimension; +} + +AnnIndex::AnnIndex(String ^ p_algoType, String ^ p_valueType, int p_dimension) + : ManagedObject(SPTAG::VectorIndex::CreateInstance(string_to(p_algoType), + string_to(p_valueType))) +{ + m_dimension = p_dimension; + m_inputVectorSize = SPTAG::GetValueTypeSize((*m_Instance)->GetVectorValueType()) * m_dimension; +} + +void AnnIndex::SetBuildParam(String ^ p_name, String ^ p_value, String ^ p_section) +{ + if (m_Instance != nullptr) + (*m_Instance) + ->SetParameter(string_to_char_array(p_name), string_to_char_array(p_value), + string_to_char_array(p_section)); +} + +void AnnIndex::SetSearchParam(String ^ p_name, String ^ p_value, String ^ p_section) +{ + if (m_Instance != nullptr) + (*m_Instance) + ->SetParameter(string_to_char_array(p_name), string_to_char_array(p_value), + string_to_char_array(p_section)); +} + +bool AnnIndex::LoadQuantizer(String ^ p_quantizerFile) +{ + if (m_Instance == nullptr) + return false; + + auto ret = ((*m_Instance)->LoadQuantizer(string_to_char_array(p_quantizerFile)) == SPTAG::ErrorCode::Success); + if (ret) + { + m_inputVectorSize = (*m_Instance)->m_pQuantizer->QuantizeSize(); + } + return ret; +} + +void AnnIndex::SetQuantizerADC(bool p_adc) +{ + if (m_Instance != nullptr) + (*m_Instance)->SetQuantizerADC(p_adc); +} + +array ^ AnnIndex::QuantizeVector(array ^ p_data, int p_num) +{ + array ^ res; + if (m_Instance != nullptr && (*m_Instance)->GetQuantizer() != nullptr) { - namespace SPTAGManaged - { - AnnIndex::AnnIndex(std::shared_ptr p_index) : - 
ManagedObject(p_index) - { - m_dimension = p_index->GetFeatureDim(); - m_inputVectorSize = SPTAG::GetValueTypeSize(p_index->GetVectorValueType()) * m_dimension; - } - - AnnIndex::AnnIndex(String^ p_algoType, String^ p_valueType, int p_dimension) : - ManagedObject(SPTAG::VectorIndex::CreateInstance(string_to(p_algoType), string_to(p_valueType))) - { - m_dimension = p_dimension; - m_inputVectorSize = SPTAG::GetValueTypeSize((*m_Instance)->GetVectorValueType()) * m_dimension; - } - - void AnnIndex::SetBuildParam(String^ p_name, String^ p_value, String^ p_section) - { - if (m_Instance != nullptr) - (*m_Instance)->SetParameter(string_to_char_array(p_name), string_to_char_array(p_value), string_to_char_array(p_section)); - } - - void AnnIndex::SetSearchParam(String^ p_name, String^ p_value, String^ p_section) - { - if (m_Instance != nullptr) - (*m_Instance)->SetParameter(string_to_char_array(p_name), string_to_char_array(p_value), string_to_char_array(p_section)); - } - - bool AnnIndex::Build(array^ p_data, int p_num) - { - return Build(p_data, p_num, false); - } - - bool AnnIndex::Build(array^ p_data, int p_num, bool p_normalized) - { - if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) - return false; - - pin_ptr ptr = &p_data[0]; - return (SPTAG::ErrorCode::Success == (*m_Instance)->BuildIndex(ptr, p_num, m_dimension, p_normalized)); - } - - bool AnnIndex::BuildWithMetaData(array^ p_data, array^ p_meta, int p_num, bool p_withMetaIndex) - { - return BuildWithMetaData(p_data, p_meta, p_num, p_withMetaIndex, false); - } - - bool AnnIndex::BuildWithMetaData(array^ p_data, array^ p_meta, int p_num, bool p_withMetaIndex, bool p_normalized) - { - if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) - return false; - - pin_ptr dataptr = &p_data[0]; - std::shared_ptr vectors(new SPTAG::BasicVectorSet(SPTAG::ByteArray(dataptr, p_data->LongLength, false), 
(*m_Instance)->GetVectorValueType(), m_dimension, p_num)); - - pin_ptr metaptr = &p_meta[0]; - std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; - if (!SPTAG::MetadataSet::GetMetadataOffsets(metaptr, p_meta->LongLength, offsets, p_num + 1, '\n')) return false; - std::shared_ptr meta(new SPTAG::MemMetadataSet(SPTAG::ByteArray(metaptr, p_meta->LongLength, false), - SPTAG::ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num, - (*m_Instance)->m_iDataBlockSize, (*m_Instance)->m_iDataCapacity, (*m_Instance)->m_iMetaRecordSize)); - return (SPTAG::ErrorCode::Success == (*m_Instance)->BuildIndex(vectors, meta, p_withMetaIndex, p_normalized)); - } - - array^ AnnIndex::Search(array^ p_data, int p_resultNum) - { - array^ res; - if (m_Instance == nullptr || m_dimension == 0 || p_data->LongLength != m_inputVectorSize) - return res; - - pin_ptr ptr = &p_data[0]; - SPTAG::QueryResult results(ptr, p_resultNum, false); - (*m_Instance)->SearchIndex(results); - - res = gcnew array(p_resultNum); - for (int i = 0; i < p_resultNum; i++) - res[i] = gcnew BasicResult(new SPTAG::BasicResult(*(results.GetResult(i)))); - - return res; - } - - array^ AnnIndex::SearchWithMetaData(array^ p_data, int p_resultNum) - { - array^ res; - if (m_Instance == nullptr || m_dimension == 0 || p_data->LongLength != m_inputVectorSize) - return res; - - pin_ptr ptr = &p_data[0]; - SPTAG::QueryResult results(ptr, p_resultNum, true); - (*m_Instance)->SearchIndex(results); - - res = gcnew array(p_resultNum); - for (int i = 0; i < p_resultNum; i++) - res[i] = gcnew BasicResult(new SPTAG::BasicResult(*(results.GetResult(i)))); - - return res; - } - - bool AnnIndex::Save(String^ p_saveFile) - { - return SPTAG::ErrorCode::Success == (*m_Instance)->SaveIndex(string_to_char_array(p_saveFile)); - } - - array^>^ AnnIndex::Dump() - { - std::shared_ptr> buffersize = (*m_Instance)->CalculateBufferSize(); - array^>^ res = gcnew array^>(buffersize->size() + 1); - std::vector 
indexBlobs; - for (int i = 1; i < res->Length; i++) - { - res[i] = gcnew array(buffersize->at(i-1)); - pin_ptr ptr = &res[i][0]; - indexBlobs.push_back(SPTAG::ByteArray((std::uint8_t*)ptr, res[i]->LongLength, false)); - } - std::string config; - if (SPTAG::ErrorCode::Success != (*m_Instance)->SaveIndex(config, indexBlobs)) - { - array^>^ null; - return null; - } - res[0] = gcnew array(config.size()); - Marshal::Copy(IntPtr(&config[0]), res[0], 0, config.size()); - return res; - } - - bool AnnIndex::Add(array^ p_data, int p_num) - { - return Add(p_data, p_num, false); - } - - - bool AnnIndex::Add(array^ p_data, int p_num, bool p_normalized) - { - if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) - return false; - - pin_ptr ptr = &p_data[0]; - return (SPTAG::ErrorCode::Success == (*m_Instance)->AddIndex(ptr, p_num, m_dimension, nullptr, false, p_normalized)); - } - - - bool AnnIndex::AddWithMetaData(array^ p_data, array^ p_meta, int p_num, bool p_withMetaIndex) - { - return AddWithMetaData(p_data, p_meta, p_num, p_withMetaIndex, false); - } - - bool AnnIndex::AddWithMetaData(array^ p_data, array^ p_meta, int p_num, bool p_withMetaIndex, bool p_normalized) - { - if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) - return false; - - pin_ptr dataptr = &p_data[0]; - std::shared_ptr vectors(new SPTAG::BasicVectorSet(SPTAG::ByteArray(dataptr, p_data->LongLength, false), (*m_Instance)->GetVectorValueType(), m_dimension, p_num)); - - pin_ptr metaptr = &p_meta[0]; - std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; - if (!SPTAG::MetadataSet::GetMetadataOffsets(metaptr, p_meta->LongLength, offsets, p_num + 1, '\n')) return false; - std::shared_ptr meta(new SPTAG::MemMetadataSet(SPTAG::ByteArray(metaptr, p_meta->LongLength, false), SPTAG::ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num)); - return 
(SPTAG::ErrorCode::Success == (*m_Instance)->AddIndex(vectors, meta, p_withMetaIndex, p_normalized)); - } - - bool AnnIndex::Delete(array^ p_data, int p_num) - { - if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) - return false; - - pin_ptr ptr = &p_data[0]; - return (SPTAG::ErrorCode::Success == (*m_Instance)->DeleteIndex(ptr, p_num)); - } - - bool AnnIndex::DeleteByMetaData(array^ p_meta) - { - if (m_Instance == nullptr) - return false; - - pin_ptr metaptr = &p_meta[0]; - return (SPTAG::ErrorCode::Success == (*m_Instance)->DeleteIndex(SPTAG::ByteArray(metaptr, p_meta->LongLength, false))); - } - - AnnIndex^ AnnIndex::Load(String^ p_loaderFile) - { - std::shared_ptr vecIndex; - AnnIndex^ res; - if (SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(string_to_char_array(p_loaderFile), vecIndex) || nullptr == vecIndex) - { - res = gcnew AnnIndex(nullptr); - } - else { - res = gcnew AnnIndex(vecIndex); - } - return res; - } - - AnnIndex^ AnnIndex::Load(array^>^ p_index) - { - std::vector p_indexBlobs; - for (int i = 1; i < p_index->Length; i++) - { - pin_ptr ptr = &p_index[i][0]; - p_indexBlobs.push_back(SPTAG::ByteArray((std::uint8_t*)ptr, p_index[i]->LongLength, false)); - } - pin_ptr configptr = &p_index[0][0]; - - std::shared_ptr vecIndex; - if (SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(std::string((char*)configptr, p_index[0]->LongLength), p_indexBlobs, vecIndex) || nullptr == vecIndex) - { - return gcnew AnnIndex(nullptr); - } - return gcnew AnnIndex(vecIndex); - } - - AnnIndex^ AnnIndex::Merge(String^ p_indexFilePath1, String^ p_indexFilePath2) - { - AnnIndex^ res = Load(p_indexFilePath1); - AnnIndex^ add = Load(p_indexFilePath2); - if (*(res->m_Instance) == nullptr || *(add->m_Instance) == nullptr || - SPTAG::ErrorCode::Success != (*(res->m_Instance))->MergeIndex(add->m_Instance->get(), std::atoi((*(res->m_Instance))->GetParameter("NumberOfThreads").c_str()), nullptr)) - { 
- return gcnew AnnIndex(nullptr); - } - return res; - } - } + res = gcnew array((*m_Instance)->GetQuantizer()->GetNumSubvectors() * (size_t)p_num); + pin_ptr ptr = &res[0], optr = &p_data[0]; + (*m_Instance)->QuantizeVector(optr, p_num, SPTAG::ByteArray(ptr, res->LongLength, false)); } -} \ No newline at end of file + return res; +} + +array ^ AnnIndex::ReconstructVector(array ^ p_data, int p_num) +{ + array ^ res; + if (m_Instance != nullptr && (*m_Instance)->GetQuantizer() != nullptr) + { + res = gcnew array((*m_Instance)->GetQuantizer()->ReconstructSize() * (size_t)p_num); + pin_ptr ptr = &res[0], optr = &p_data[0]; + (*m_Instance)->ReconstructVector(optr, p_num, SPTAG::ByteArray(ptr, res->LongLength, false)); + } + return res; +} + +bool AnnIndex::BuildSPANN(bool p_normalized) +{ + if (m_Instance == nullptr || (*m_Instance)->GetIndexAlgoType() != SPTAG::IndexAlgoType::SPANN) + return false; + + return (*m_Instance)->BuildIndex(p_normalized) == SPTAG::ErrorCode::Success; +} + +bool AnnIndex::BuildSPANNWithMetaData(array ^ p_meta, int p_num, bool p_withMetaIndex, bool p_normalized) +{ + if (m_Instance == nullptr || (*m_Instance)->GetIndexAlgoType() != SPTAG::IndexAlgoType::SPANN) + return false; + + pin_ptr metaptr = &p_meta[0]; + std::uint64_t *offsets = new std::uint64_t[p_num + 1]{0}; + if (!SPTAG::MetadataSet::GetMetadataOffsets(metaptr, p_meta->LongLength, offsets, p_num + 1, '\n')) + return false; + (*m_Instance) + ->SetMetadata(new SPTAG::MemMetadataSet( + SPTAG::ByteArray(metaptr, p_meta->LongLength, false), + SPTAG::ByteArray((std::uint8_t *)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num, + (*m_Instance)->m_iDataBlockSize, (*m_Instance)->m_iDataCapacity, (*m_Instance)->m_iMetaRecordSize)); + + if (p_withMetaIndex) + (*m_Instance)->BuildMetaMapping(false); + + return (SPTAG::ErrorCode::Success == (*m_Instance)->BuildIndex(p_normalized)); +} + +bool AnnIndex::Build(array ^ p_data, int p_num) +{ + return Build(p_data, p_num, false); +} + +bool 
AnnIndex::Build(array ^ p_data, int p_num, bool p_normalized) +{ + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr ptr = &p_data[0]; + return (SPTAG::ErrorCode::Success == (*m_Instance)->BuildIndex(ptr, p_num, m_dimension, p_normalized)); +} + +bool AnnIndex::BuildWithMetaData(array ^ p_data, array ^ p_meta, int p_num, bool p_withMetaIndex) +{ + return BuildWithMetaData(p_data, p_meta, p_num, p_withMetaIndex, false); +} + +bool AnnIndex::BuildWithMetaData(array ^ p_data, array ^ p_meta, int p_num, bool p_withMetaIndex, + bool p_normalized) +{ + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr dataptr = &p_data[0]; + auto vectorType = (*m_Instance)->m_pQuantizer ? SPTAG::VectorValueType::UInt8 : (*m_Instance)->GetVectorValueType(); + auto vectorSize = (*m_Instance)->m_pQuantizer ? (*m_Instance)->m_pQuantizer->GetNumSubvectors() : m_dimension; + std::shared_ptr vectors( + new SPTAG::BasicVectorSet(SPTAG::ByteArray(dataptr, p_data->LongLength, false), vectorType, vectorSize, p_num)); + + pin_ptr metaptr = &p_meta[0]; + std::uint64_t *offsets = new std::uint64_t[p_num + 1]{0}; + if (!SPTAG::MetadataSet::GetMetadataOffsets(metaptr, p_meta->LongLength, offsets, p_num + 1, '\n')) + return false; + std::shared_ptr meta(new SPTAG::MemMetadataSet( + SPTAG::ByteArray(metaptr, p_meta->LongLength, false), + SPTAG::ByteArray((std::uint8_t *)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num, + (*m_Instance)->m_iDataBlockSize, (*m_Instance)->m_iDataCapacity, (*m_Instance)->m_iMetaRecordSize)); + return (SPTAG::ErrorCode::Success == (*m_Instance)->BuildIndex(vectors, meta, p_withMetaIndex, p_normalized)); +} + +array ^ AnnIndex::Search(array ^ p_data, int p_resultNum) +{ + array ^ res; + if (m_Instance == nullptr) + return res; + + pin_ptr ptr = &p_data[0]; + SPTAG::QueryResult results(ptr, 
p_resultNum, false); + (*m_Instance)->SearchIndex(results); + + res = gcnew array(p_resultNum); + for (int i = 0; i < p_resultNum; i++) + res[i] = gcnew BasicResult(new SPTAG::BasicResult(*(results.GetResult(i)))); + + return res; +} + +array ^ AnnIndex::SearchWithMetaData(array ^ p_data, int p_resultNum) +{ + array ^ res; + if (m_Instance == nullptr) + return res; + + pin_ptr ptr = &p_data[0]; + SPTAG::QueryResult results(ptr, p_resultNum, true); + (*m_Instance)->SearchIndex(results); + + res = gcnew array(p_resultNum); + for (int i = 0; i < p_resultNum; i++) + res[i] = gcnew BasicResult(new SPTAG::BasicResult(*(results.GetResult(i)))); + + return res; +} + +RIterator ^ AnnIndex::GetIterator(array ^ p_data) +{ + RIterator ^ res; + if (m_Instance == nullptr || m_dimension == 0 || p_data->LongLength != m_inputVectorSize) + return res; + + pin_ptr ptr = &p_data[0]; + std::shared_ptr result_iterator = (*m_Instance)->GetIterator(ptr); + + res = gcnew RIterator(result_iterator); + return res; +} + +void AnnIndex::UpdateIndex() +{ + if (m_Instance != nullptr) + (*m_Instance)->UpdateIndex(); +} + +bool AnnIndex::Save(String ^ p_saveFile) +{ + return SPTAG::ErrorCode::Success == (*m_Instance)->SaveIndex(string_to_char_array(p_saveFile)); +} + +array ^> ^ AnnIndex::Dump() +{ + std::shared_ptr> buffersize = (*m_Instance)->CalculateBufferSize(); + array ^> ^ res = gcnew array ^>(buffersize->size() + 1); + std::vector indexBlobs; + for (int i = 1; i < res->Length; i++) + { + res[i] = gcnew array(buffersize->at(i - 1)); + pin_ptr ptr = &res[i][0]; + indexBlobs.push_back(SPTAG::ByteArray((std::uint8_t *)ptr, res[i]->LongLength, false)); + } + std::string config; + if (SPTAG::ErrorCode::Success != (*m_Instance)->SaveIndex(config, indexBlobs)) + { + array ^> ^ null; + return null; + } + res[0] = gcnew array(config.size()); + Marshal::Copy(IntPtr(&config[0]), res[0], 0, config.size()); + return res; +} + +bool AnnIndex::Add(array ^ p_data, int p_num) +{ + return Add(p_data, p_num, 
false); +} + +bool AnnIndex::Add(array ^ p_data, int p_num, bool p_normalized) +{ + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr ptr = &p_data[0]; + return (SPTAG::ErrorCode::Success == + (*m_Instance)->AddIndex(ptr, p_num, m_dimension, nullptr, false, p_normalized)); +} + +bool AnnIndex::AddWithMetaData(array ^ p_data, array ^ p_meta, int p_num, bool p_withMetaIndex) +{ + return AddWithMetaData(p_data, p_meta, p_num, p_withMetaIndex, false); +} + +bool AnnIndex::AddWithMetaData(array ^ p_data, array ^ p_meta, int p_num, bool p_withMetaIndex, + bool p_normalized) +{ + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr dataptr = &p_data[0]; + std::shared_ptr vectors(new SPTAG::BasicVectorSet( + SPTAG::ByteArray(dataptr, p_data->LongLength, false), (*m_Instance)->GetVectorValueType(), m_dimension, p_num)); + + pin_ptr metaptr = &p_meta[0]; + std::uint64_t *offsets = new std::uint64_t[p_num + 1]{0}; + if (!SPTAG::MetadataSet::GetMetadataOffsets(metaptr, p_meta->LongLength, offsets, p_num + 1, '\n')) + return false; + std::shared_ptr meta(new SPTAG::MemMetadataSet( + SPTAG::ByteArray(metaptr, p_meta->LongLength, false), + SPTAG::ByteArray((std::uint8_t *)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num)); + return (SPTAG::ErrorCode::Success == (*m_Instance)->AddIndex(vectors, meta, p_withMetaIndex, p_normalized)); +} + +bool AnnIndex::Delete(array ^ p_data, int p_num) +{ + if (m_Instance == nullptr || p_num == 0 || m_dimension == 0 || p_data->LongLength != p_num * m_inputVectorSize) + return false; + + pin_ptr ptr = &p_data[0]; + return (SPTAG::ErrorCode::Success == (*m_Instance)->DeleteIndex(ptr, p_num)); +} + +bool AnnIndex::DeleteByMetaData(array ^ p_meta) +{ + if (m_Instance == nullptr) + return false; + + pin_ptr metaptr = &p_meta[0]; + return (SPTAG::ErrorCode::Success == 
+ (*m_Instance)->DeleteIndex(SPTAG::ByteArray(metaptr, p_meta->LongLength, false))); +} + +AnnIndex ^ AnnIndex::Load(String ^ p_loaderFile) +{ + std::shared_ptr vecIndex; + AnnIndex ^ res; + if (SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(string_to_char_array(p_loaderFile), vecIndex) || + nullptr == vecIndex) + { + res = gcnew AnnIndex(nullptr); + } + else + { + res = gcnew AnnIndex(vecIndex); + } + return res; +} + +AnnIndex ^ AnnIndex::Load(array ^> ^ p_index) +{ + std::vector p_indexBlobs; + for (int i = 1; i < p_index->Length; i++) + { + pin_ptr ptr = &p_index[i][0]; + p_indexBlobs.push_back(SPTAG::ByteArray((std::uint8_t *)ptr, p_index[i]->LongLength, false)); + } + pin_ptr configptr = &p_index[0][0]; + + std::shared_ptr vecIndex; + if (SPTAG::ErrorCode::Success != + SPTAG::VectorIndex::LoadIndex(std::string((char *)configptr, p_index[0]->LongLength), p_indexBlobs, + vecIndex) || + nullptr == vecIndex) + { + return gcnew AnnIndex(nullptr); + } + return gcnew AnnIndex(vecIndex); +} + +AnnIndex ^ AnnIndex::Merge(String ^ p_indexFilePath1, String ^ p_indexFilePath2) +{ + AnnIndex ^ res = Load(p_indexFilePath1); + AnnIndex ^ add = Load(p_indexFilePath2); + if (*(res->m_Instance) == nullptr || *(add->m_Instance) == nullptr || + SPTAG::ErrorCode::Success != + (*(res->m_Instance)) + ->MergeIndex(add->m_Instance->get(), + std::atoi((*(res->m_Instance))->GetParameter("NumberOfThreads").c_str()), nullptr)) + { + return gcnew AnnIndex(nullptr); + } + return res; +} + +MultiIndexScan::MultiIndexScan(std::shared_ptr multi_index_scan) + : ManagedObject(multi_index_scan) +{ +} + +MultiIndexScan::MultiIndexScan(array ^ indice, array ^> ^ p_data, array ^ weight, + int p_resultNum, bool useTimer, int termCondVal, int searchLimit) + : ManagedObject(std::make_shared()) +{ + std::vector> vecIndices; + for (int i = 0; i < indice->Length; i++) + { + std::shared_ptr index = *(indice[i]->GetInstance()); + vecIndices.push_back(index); + } + + std::vector data_array; + 
for (int i = 0; i < p_data->Length; i++) + { + pin_ptr ptr = &p_data[i][0]; + SPTAG::ByteArray byte_target = SPTAG::ByteArray::Alloc(p_data[i]->LongLength); + byte_target.Set((std::uint8_t *)ptr, p_data[i]->LongLength, false); + data_array.push_back(byte_target); + } + + std::vector weight_array; + for (int i = 0; i < weight->Length; i++) + { + float w = weight[i]; + weight_array.push_back(w); + } + + (*m_Instance)->Init(vecIndices, data_array, weight_array, p_resultNum, useTimer, termCondVal, searchLimit); +} + +BasicResult ^ MultiIndexScan::Next() +{ + SPTAG::BasicResult *result = new SPTAG::BasicResult(); + (*m_Instance)->Next(*result); + return gcnew BasicResult(result); +} + +void MultiIndexScan::Close() +{ + (*m_Instance)->Close(); +} + +} // namespace SPTAGManaged +} // namespace ANN +} // namespace Microsoft \ No newline at end of file diff --git a/Wrappers/src/ClientInterface.cpp b/Wrappers/src/ClientInterface.cpp index f53ccda62..d7ae03e3a 100644 --- a/Wrappers/src/ClientInterface.cpp +++ b/Wrappers/src/ClientInterface.cpp @@ -2,17 +2,15 @@ // Licensed under the MIT License. 
#include "inc/ClientInterface.h" +#include "inc/Helper/Base64Encode.h" #include "inc/Helper/CommonHelper.h" #include "inc/Helper/Concurrent.h" -#include "inc/Helper/Base64Encode.h" #include "inc/Helper/StringConvert.h" #include - -AnnClient::AnnClient(const char* p_serverAddr, const char* p_serverPort) - : m_connectionID(SPTAG::Socket::c_invalidConnectionID), - m_timeoutInMilliseconds(9000) +AnnClient::AnnClient(const char *p_serverAddr, const char *p_serverPort) + : m_connectionID(SPTAG::Socket::c_invalidConnectionID), m_timeoutInMilliseconds(9000) { using namespace SPTAG; @@ -26,8 +24,7 @@ AnnClient::AnnClient(const char* p_serverAddr, const char* p_serverPort) m_server = p_serverAddr; m_port = p_serverPort; - auto connectCallback = [this](Socket::ConnectionID p_cid, ErrorCode p_ec) - { + auto connectCallback = [this](Socket::ConnectionID p_cid, ErrorCode p_ec) { m_connectionID = p_cid; if (ErrorCode::Socket_FailedResolveEndPoint == p_ec) @@ -45,8 +42,7 @@ AnnClient::AnnClient(const char* p_serverAddr, const char* p_serverPort) m_socketClient->AsyncConnectToServer(m_server, m_port, connectCallback); - m_socketClient->SetEventOnConnectionClose([this](Socket::ConnectionID p_cid) - { + m_socketClient->SetEventOnConnectionClose([this](Socket::ConnectionID p_cid) { ErrorCode errCode; m_connectionID = Socket::c_invalidConnectionID; while (Socket::c_invalidConnectionID == m_connectionID) @@ -57,21 +53,16 @@ AnnClient::AnnClient(const char* p_serverAddr, const char* p_serverPort) }); } - AnnClient::~AnnClient() { } - -void -AnnClient::SetTimeoutMilliseconds(int p_timeout) +void AnnClient::SetTimeoutMilliseconds(int p_timeout) { m_timeoutInMilliseconds = p_timeout; } - -void -AnnClient::SetSearchParam(const char* p_name, const char* p_value) +void AnnClient::SetSearchParam(const char *p_name, const char *p_value) { std::lock_guard guard(m_paramMutex); @@ -92,17 +83,14 @@ AnnClient::SetSearchParam(const char* p_name, const char* p_value) m_params[name] = p_value; } - -void 
-AnnClient::ClearSearchParam() +void AnnClient::ClearSearchParam() { std::lock_guard guard(m_paramMutex); m_params.clear(); } - -std::shared_ptr -AnnClient::Search(ByteArray p_data, int p_resultNum, const char* p_valueType, bool p_withMetaData) +std::shared_ptr AnnClient::Search(ByteArray p_data, int p_resultNum, const char *p_valueType, + bool p_withMetaData) { using namespace SPTAG; @@ -116,8 +104,7 @@ AnnClient::Search(ByteArray p_data, int p_resultNum, const char* p_valueType, bo auto signal = std::make_shared(1); - auto callback = [&ret, signal](RemoteSearchResult p_result) - { + auto callback = [&ret, signal](RemoteSearchResult p_result) { if (RemoteSearchResult::ResultStatus::Success == p_result.m_status) { ret = std::move(p_result); @@ -126,8 +113,7 @@ AnnClient::Search(ByteArray p_data, int p_resultNum, const char* p_valueType, bo signal->FinishOne(); }; - auto timeoutCallback = [this](std::shared_ptr p_callback) - { + auto timeoutCallback = [this](std::shared_ptr p_callback) { if (nullptr != p_callback) { RemoteSearchResult result; @@ -137,8 +123,7 @@ AnnClient::Search(ByteArray p_data, int p_resultNum, const char* p_valueType, bo } }; - auto connectCallback = [callback, this](bool p_connectSucc) - { + auto connectCallback = [callback, this](bool p_connectSucc) { if (!p_connectSucc) { RemoteSearchResult result; @@ -153,8 +138,7 @@ AnnClient::Search(ByteArray p_data, int p_resultNum, const char* p_valueType, bo packet.Header().m_packetType = Socket::PacketType::SearchRequest; packet.Header().m_processStatus = Socket::PacketProcessStatus::Ok; packet.Header().m_resourceID = m_callbackManager.Add(std::make_shared(std::move(callback)), - m_timeoutInMilliseconds, - std::move(timeoutCallback)); + m_timeoutInMilliseconds, std::move(timeoutCallback)); Socket::RemoteQuery query; query.m_queryString = CreateSearchQuery(p_data, p_resultNum, p_withMetaData, valueType); @@ -168,39 +152,30 @@ AnnClient::Search(ByteArray p_data, int p_resultNum, const char* p_valueType, 
bo signal->Wait(); } - else { - LOG(Helper::LogLevel::LL_Error, "Error connection or data type!"); + else + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Error connection or data type!"); } return std::make_shared(ret); } - -bool -AnnClient::IsConnected() const +bool AnnClient::IsConnected() const { return m_connectionID != SPTAG::Socket::c_invalidConnectionID; } - -SPTAG::Socket::PacketHandlerMapPtr -AnnClient::GetHandlerMap() +SPTAG::Socket::PacketHandlerMapPtr AnnClient::GetHandlerMap() { using namespace SPTAG; Socket::PacketHandlerMapPtr handlerMap(new Socket::PacketHandlerMap); - handlerMap->emplace(Socket::PacketType::SearchResponse, - std::bind(&AnnClient::SearchResponseHanlder, - this, - std::placeholders::_1, - std::placeholders::_2)); + handlerMap->emplace(Socket::PacketType::SearchResponse, std::bind(&AnnClient::SearchResponseHanlder, this, + std::placeholders::_1, std::placeholders::_2)); return handlerMap; } - -void -AnnClient::SearchResponseHanlder(SPTAG::Socket::ConnectionID p_localConnectionID, - SPTAG::Socket::Packet p_packet) +void AnnClient::SearchResponseHanlder(SPTAG::Socket::ConnectionID p_localConnectionID, SPTAG::Socket::Packet p_packet) { using namespace SPTAG; @@ -225,12 +200,8 @@ AnnClient::SearchResponseHanlder(SPTAG::Socket::ConnectionID p_localConnectionID } } - -std::string -AnnClient::CreateSearchQuery(const ByteArray& p_data, - int p_resultNum, - bool p_extractMetadata, - SPTAG::VectorValueType p_valueType) +std::string AnnClient::CreateSearchQuery(const ByteArray &p_data, int p_resultNum, bool p_extractMetadata, + SPTAG::VectorValueType p_valueType) { std::stringstream out; @@ -244,7 +215,7 @@ AnnClient::CreateSearchQuery(const ByteArray& p_data, { std::lock_guard guard(m_paramMutex); - for (const auto& param : m_params) + for (const auto ¶m : m_params) { out << " $" << param.first << ":" << param.second; } @@ -252,4 +223,3 @@ AnnClient::CreateSearchQuery(const ByteArray& p_data, return out.str(); } - diff --git 
a/Wrappers/src/CoreInterface.cpp b/Wrappers/src/CoreInterface.cpp index 30e4e5ab6..62c37639f 100644 --- a/Wrappers/src/CoreInterface.cpp +++ b/Wrappers/src/CoreInterface.cpp @@ -4,19 +4,14 @@ #include "inc/CoreInterface.h" #include "inc/Helper/StringConvert.h" - AnnIndex::AnnIndex(DimensionType p_dimension) - : m_algoType(SPTAG::IndexAlgoType::BKT), - m_inputValueType(SPTAG::VectorValueType::Float), - m_dimension(p_dimension) + : m_algoType(SPTAG::IndexAlgoType::BKT), m_inputValueType(SPTAG::VectorValueType::Float), m_dimension(p_dimension) { m_inputVectorSize = SPTAG::GetValueTypeSize(m_inputValueType) * m_dimension; } - -AnnIndex::AnnIndex(const char* p_algoType, const char* p_valueType, DimensionType p_dimension) - : m_algoType(SPTAG::IndexAlgoType::Undefined), - m_inputValueType(SPTAG::VectorValueType::Undefined), +AnnIndex::AnnIndex(const char *p_algoType, const char *p_valueType, DimensionType p_dimension) + : m_algoType(SPTAG::IndexAlgoType::Undefined), m_inputValueType(SPTAG::VectorValueType::Undefined), m_dimension(p_dimension) { SPTAG::Helper::Convert::ConvertStringTo(p_algoType, m_algoType); @@ -24,57 +19,53 @@ AnnIndex::AnnIndex(const char* p_algoType, const char* p_valueType, DimensionTyp m_inputVectorSize = SPTAG::GetValueTypeSize(m_inputValueType) * m_dimension; } - -AnnIndex::AnnIndex(const std::shared_ptr& p_index) - : m_algoType(p_index->GetIndexAlgoType()), - m_inputValueType(p_index->GetVectorValueType()), - m_dimension(p_index->GetFeatureDim()), - m_index(p_index) +AnnIndex::AnnIndex(const std::shared_ptr &p_index) + : m_algoType(p_index->GetIndexAlgoType()), m_inputValueType(p_index->GetVectorValueType()), + m_dimension(p_index->GetFeatureDim()), m_index(p_index) { - m_inputVectorSize = p_index->m_pQuantizer ? p_index->m_pQuantizer->GetNumSubvectors() : SPTAG::GetValueTypeSize(m_inputValueType) * m_dimension; + m_inputVectorSize = p_index->m_pQuantizer ? 
p_index->m_pQuantizer->GetNumSubvectors() + : SPTAG::GetValueTypeSize(m_inputValueType) * m_dimension; } - AnnIndex::~AnnIndex() { } - - -bool -AnnIndex::BuildSPANN(bool p_normalized) +bool AnnIndex::BuildSPANN(bool p_normalized) { if (nullptr == m_index) { m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); } - if (nullptr == m_index) return false; + if (nullptr == m_index) + return false; return (SPTAG::ErrorCode::Success == m_index->BuildIndex(p_normalized)); } -bool -AnnIndex::BuildSPANNWithMetaData(ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, bool p_normalized) +bool AnnIndex::BuildSPANNWithMetaData(ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, bool p_normalized) { if (nullptr == m_index) { m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); } - if (nullptr == m_index) return false; + if (nullptr == m_index) + return false; - std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; - if (!SPTAG::MetadataSet::GetMetadataOffsets(p_meta.Data(), p_meta.Length(), offsets, p_num + 1, '\n')) return false; + std::uint64_t *offsets = new std::uint64_t[p_num + 1]{0}; + if (!SPTAG::MetadataSet::GetMetadataOffsets(p_meta.Data(), p_meta.Length(), offsets, p_num + 1, '\n')) + return false; - m_index->SetMetadata((new SPTAG::MemMetadataSet(p_meta, ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num, + m_index->SetMetadata((new SPTAG::MemMetadataSet( + p_meta, ByteArray((std::uint8_t *)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num, m_index->m_iDataBlockSize, m_index->m_iDataCapacity, m_index->m_iMetaRecordSize))); - if (p_withMetaIndex) m_index->BuildMetaMapping(false); + if (p_withMetaIndex) + m_index->BuildMetaMapping(false); return (SPTAG::ErrorCode::Success == m_index->BuildIndex(p_normalized)); } - -bool -AnnIndex::Build(ByteArray p_data, SizeType p_num, bool p_normalized) +bool AnnIndex::Build(ByteArray p_data, SizeType 
p_num, bool p_normalized) { if (nullptr == m_index) { @@ -84,12 +75,12 @@ AnnIndex::Build(ByteArray p_data, SizeType p_num, bool p_normalized) { return false; } - return (SPTAG::ErrorCode::Success == m_index->BuildIndex(p_data.Data(), (SPTAG::SizeType)p_num, (SPTAG::DimensionType)m_dimension, p_normalized)); + return (SPTAG::ErrorCode::Success == m_index->BuildIndex(p_data.Data(), (SPTAG::SizeType)p_num, + (SPTAG::DimensionType)m_dimension, p_normalized)); } - -bool -AnnIndex::BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, bool p_normalized) +bool AnnIndex::BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, + bool p_normalized) { if (nullptr == m_index) { @@ -99,53 +90,52 @@ AnnIndex::BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, { return false; } - + auto vectorType = m_index->m_pQuantizer ? SPTAG::VectorValueType::UInt8 : m_inputValueType; auto vectorSize = m_index->m_pQuantizer ? 
m_index->m_pQuantizer->GetNumSubvectors() : m_dimension; - std::shared_ptr vectors(new SPTAG::BasicVectorSet(p_data, - vectorType, - static_cast(vectorSize), - static_cast(p_num))); - - std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; - if (!SPTAG::MetadataSet::GetMetadataOffsets(p_meta.Data(), p_meta.Length(), offsets, p_num + 1, '\n')) return false; - std::shared_ptr meta(new SPTAG::MemMetadataSet(p_meta, ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num, + std::shared_ptr vectors(new SPTAG::BasicVectorSet( + p_data, vectorType, static_cast(vectorSize), static_cast(p_num))); + + std::uint64_t *offsets = new std::uint64_t[p_num + 1]{0}; + if (!SPTAG::MetadataSet::GetMetadataOffsets(p_meta.Data(), p_meta.Length(), offsets, p_num + 1, '\n')) + return false; + std::shared_ptr meta(new SPTAG::MemMetadataSet( + p_meta, ByteArray((std::uint8_t *)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num, m_index->m_iDataBlockSize, m_index->m_iDataCapacity, m_index->m_iMetaRecordSize)); return (SPTAG::ErrorCode::Success == m_index->BuildIndex(vectors, meta, p_withMetaIndex, p_normalized)); } - -void -AnnIndex::SetBuildParam(const char* p_name, const char* p_value, const char* p_section) +void AnnIndex::SetBuildParam(const char *p_name, const char *p_value, const char *p_section) { - if (nullptr == m_index) + if (nullptr == m_index) { - if (SPTAG::IndexAlgoType::Undefined == m_algoType || - SPTAG::VectorValueType::Undefined == m_inputValueType) + if (SPTAG::IndexAlgoType::Undefined == m_algoType || SPTAG::VectorValueType::Undefined == m_inputValueType) { - return; + return; } m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); - } m_index->SetParameter(p_name, p_value, p_section); } - -void -AnnIndex::SetSearchParam(const char* p_name, const char* p_value, const char* p_section) +void AnnIndex::SetSearchParam(const char *p_name, const char *p_value, const char 
*p_section) { - if (nullptr != m_index) m_index->SetParameter(p_name, p_value, p_section); + if (nullptr != m_index) + m_index->SetParameter(p_name, p_value, p_section); } +std::shared_ptr AnnIndex::GetIterator(ByteArray p_target) +{ + if (nullptr != m_index) + return m_index->GetIterator(p_target.Data()); + return nullptr; +} -bool -AnnIndex::LoadQuantizer(const char* p_quantizerFile) +bool AnnIndex::LoadQuantizer(const char *p_quantizerFile) { if (nullptr == m_index) { - if (SPTAG::IndexAlgoType::Undefined == m_algoType || - SPTAG::VectorValueType::Undefined == m_inputValueType) + if (SPTAG::IndexAlgoType::Undefined == m_algoType || SPTAG::VectorValueType::Undefined == m_inputValueType) { return false; } @@ -160,16 +150,41 @@ AnnIndex::LoadQuantizer(const char* p_quantizerFile) return ret; } +void AnnIndex::SetQuantizerADC(bool p_adc) +{ + if (nullptr != m_index) + return m_index->SetQuantizerADC(p_adc); +} -void -AnnIndex::SetQuantizerADC(bool p_adc) +ByteArray AnnIndex::QuantizeVector(ByteArray p_data, int p_num) { - if (nullptr != m_index) return m_index->SetQuantizerADC(p_adc); + if (nullptr != m_index && m_index->GetQuantizer() != nullptr) + { + size_t outsize = m_index->GetQuantizer()->GetNumSubvectors() * (size_t)p_num; + std::uint8_t *outdata = new std::uint8_t[outsize]; + if (SPTAG::ErrorCode::Success != + m_index->QuantizeVector(p_data.Data(), p_num, ByteArray(outdata, outsize, false))) + return ByteArray::c_empty; + return ByteArray(outdata, outsize, false); + } + return ByteArray::c_empty; } +ByteArray AnnIndex::ReconstructVector(ByteArray p_data, int p_num) +{ + if (nullptr != m_index && m_index->GetQuantizer() != nullptr) + { + size_t outsize = m_index->GetQuantizer()->ReconstructSize() * (size_t)p_num; + std::uint8_t *outdata = new std::uint8_t[outsize]; + if (SPTAG::ErrorCode::Success != + m_index->ReconstructVector(p_data.Data(), p_num, ByteArray(outdata, outsize, false))) + return ByteArray::c_empty; + return ByteArray(outdata, outsize, false); 
+ } + return ByteArray::c_empty; +} -std::shared_ptr -AnnIndex::Search(ByteArray p_data, int p_resultNum) +std::shared_ptr AnnIndex::Search(ByteArray p_data, int p_resultNum) { std::shared_ptr results = std::make_shared(p_data.Data(), p_resultNum, false); @@ -180,8 +195,7 @@ AnnIndex::Search(ByteArray p_data, int p_resultNum) return std::move(results); } -std::shared_ptr -AnnIndex::SearchWithMetaData(ByteArray p_data, int p_resultNum) +std::shared_ptr AnnIndex::SearchWithMetaData(ByteArray p_data, int p_resultNum) { std::shared_ptr results = std::make_shared(p_data.Data(), p_resultNum, true); @@ -192,10 +206,11 @@ AnnIndex::SearchWithMetaData(ByteArray p_data, int p_resultNum) return std::move(results); } -std::shared_ptr -AnnIndex::BatchSearch(ByteArray p_data, int p_vectorNum, int p_resultNum, bool p_withMetaData) +std::shared_ptr AnnIndex::BatchSearch(ByteArray p_data, int p_vectorNum, int p_resultNum, + bool p_withMetaData) { - std::shared_ptr results = std::make_shared(p_data.Data(), p_vectorNum * p_resultNum, p_withMetaData); + std::shared_ptr results = + std::make_shared(p_data.Data(), p_vectorNum * p_resultNum, p_withMetaData); if (nullptr != m_index) { m_index->SearchIndex(p_data.Data(), p_vectorNum, p_resultNum, p_withMetaData, results->GetResults()); @@ -203,43 +218,22 @@ AnnIndex::BatchSearch(ByteArray p_data, int p_vectorNum, int p_resultNum, bool p return std::move(results); } -bool -AnnIndex::ReadyToServe() const +bool AnnIndex::ReadyToServe() const { return m_index != nullptr; } - -void -AnnIndex::UpdateIndex() +void AnnIndex::UpdateIndex() { m_index->UpdateIndex(); } - -bool -AnnIndex::Save(const char* p_savefile) const +bool AnnIndex::Save(const char *p_savefile) const { return SPTAG::ErrorCode::Success == m_index->SaveIndex(p_savefile); } - -AnnIndex -AnnIndex::Load(const char* p_loaderFile) -{ - std::shared_ptr vecIndex; - auto ret = SPTAG::VectorIndex::LoadIndex(p_loaderFile, vecIndex); - if (SPTAG::ErrorCode::Success != ret || nullptr == 
vecIndex) - { - return AnnIndex(0); - } - - return AnnIndex(vecIndex); -} - - -bool -AnnIndex::Add(ByteArray p_data, SizeType p_num, bool p_normalized) +bool AnnIndex::Add(ByteArray p_data, SizeType p_num, bool p_normalized) { if (nullptr == m_index) { @@ -250,12 +244,13 @@ AnnIndex::Add(ByteArray p_data, SizeType p_num, bool p_normalized) return false; } - return (SPTAG::ErrorCode::Success == m_index->AddIndex(p_data.Data(), (SPTAG::SizeType)p_num, (SPTAG::DimensionType)m_dimension, nullptr, false, p_normalized)); + return (SPTAG::ErrorCode::Success == m_index->AddIndex(p_data.Data(), (SPTAG::SizeType)p_num, + (SPTAG::DimensionType)m_dimension, nullptr, false, + p_normalized)); } - -bool -AnnIndex::AddWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, bool p_normalized) +bool AnnIndex::AddWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bool p_withMetaIndex, + bool p_normalized) { if (nullptr == m_index) { @@ -266,20 +261,18 @@ AnnIndex::AddWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num, bo return false; } - std::shared_ptr vectors(new SPTAG::BasicVectorSet(p_data, - m_inputValueType, - static_cast(m_dimension), - static_cast(p_num))); + std::shared_ptr vectors(new SPTAG::BasicVectorSet( + p_data, m_inputValueType, static_cast(m_dimension), static_cast(p_num))); - std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; - if (!SPTAG::MetadataSet::GetMetadataOffsets(p_meta.Data(), p_meta.Length(), offsets, p_num + 1, '\n')) return false; - std::shared_ptr meta(new SPTAG::MemMetadataSet(p_meta, ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num)); + std::uint64_t *offsets = new std::uint64_t[p_num + 1]{0}; + if (!SPTAG::MetadataSet::GetMetadataOffsets(p_meta.Data(), p_meta.Length(), offsets, p_num + 1, '\n')) + return false; + std::shared_ptr meta(new SPTAG::MemMetadataSet( + p_meta, ByteArray((std::uint8_t *)offsets, (p_num + 1) * 
sizeof(std::uint64_t), true), (SPTAG::SizeType)p_num)); return (SPTAG::ErrorCode::Success == m_index->AddIndex(vectors, meta, p_withMetaIndex, p_normalized)); } - -bool -AnnIndex::Delete(ByteArray p_data, SizeType p_num) +bool AnnIndex::Delete(ByteArray p_data, SizeType p_num) { if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) { @@ -289,23 +282,105 @@ AnnIndex::Delete(ByteArray p_data, SizeType p_num) return (SPTAG::ErrorCode::Success == m_index->DeleteIndex(p_data.Data(), (SPTAG::SizeType)p_num)); } - -bool -AnnIndex::DeleteByMetaData(ByteArray p_meta) +bool AnnIndex::DeleteByMetaData(ByteArray p_meta) { - if (nullptr == m_index) return false; - + if (nullptr == m_index) + return false; + return (SPTAG::ErrorCode::Success == m_index->DeleteIndex(p_meta)); } +uint64_t AnnIndex::CalculateBufferSize() +{ + if (nullptr == m_index) + return 0; + + std::shared_ptr> buffersize = m_index->CalculateBufferSize(); + uint64_t total = sizeof(int) + sizeof(uint64_t) * buffersize->size(); + for (uint64_t bs : *buffersize) + { + total += bs; + } + return total; +} + +ByteArray AnnIndex::Dump(ByteArray p_blobs) +{ + if (nullptr == m_index) + return ByteArray::c_empty; + + std::shared_ptr> buffersize = m_index->CalculateBufferSize(); + std::uint8_t *ptr = p_blobs.Data(), *pdata = ptr + sizeof(int) + sizeof(uint64_t) * buffersize->size(); + *((int *)ptr) = (int)(buffersize->size()); + ptr += sizeof(int); + + std::vector indexBlobs; + for (size_t i = 0; i < buffersize->size(); i++) + { + *((uint64_t *)ptr) = buffersize->at(i); + ptr += sizeof(uint64_t); + indexBlobs.push_back(SPTAG::ByteArray(pdata, buffersize->at(i), false)); + pdata += buffersize->at(i); + } + + std::string config; + if (SPTAG::ErrorCode::Success != m_index->SaveIndex(config, indexBlobs)) + { + return ByteArray::c_empty; + } + std::uint8_t *newdata = new std::uint8_t[config.size()]; + memcpy(newdata, config.c_str(), config.size()); + return ByteArray(newdata, 
config.size(), false); +} + +AnnIndex AnnIndex::LoadFromDump(ByteArray p_config, ByteArray p_blobs) +{ + if (p_config.Length() == 0) + return AnnIndex(0); + + std::uint8_t *ptr = p_blobs.Data(); + int streamNum = *((int *)ptr); + ptr += sizeof(int); + std::uint8_t *pdata = ptr + sizeof(uint64_t) * streamNum; + + std::vector p_indexBlobs; + for (int i = 0; i < streamNum; i++) + { + std::uint64_t streamSize = *((uint64_t *)ptr); + ptr += sizeof(uint64_t); + p_indexBlobs.push_back(SPTAG::ByteArray((std::uint8_t *)pdata, streamSize, false)); + pdata += streamSize; + } + + std::shared_ptr vecIndex; + std::string config((char *)p_config.Data(), p_config.Length()); + if (SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(config, p_indexBlobs, vecIndex) || + nullptr == vecIndex) + { + return AnnIndex(0); + } + return AnnIndex(vecIndex); +} + +AnnIndex AnnIndex::Load(const char *p_loaderFile) +{ + std::shared_ptr vecIndex; + auto ret = SPTAG::VectorIndex::LoadIndex(p_loaderFile, vecIndex); + if (SPTAG::ErrorCode::Success != ret || nullptr == vecIndex) + { + return AnnIndex(0); + } + + return AnnIndex(vecIndex); +} -AnnIndex -AnnIndex::Merge(const char* p_indexFilePath1, const char* p_indexFilePath2) +AnnIndex AnnIndex::Merge(const char *p_indexFilePath1, const char *p_indexFilePath2) { std::shared_ptr vecIndex, addIndex; if (SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(p_indexFilePath1, vecIndex) || SPTAG::ErrorCode::Success != SPTAG::VectorIndex::LoadIndex(p_indexFilePath2, addIndex) || - SPTAG::ErrorCode::Success != vecIndex->MergeIndex(addIndex.get(), std::atoi(vecIndex->GetParameter("NumberOfThreads").c_str()), nullptr)) + SPTAG::ErrorCode::Success != + vecIndex->MergeIndex(addIndex.get(), std::atoi(vecIndex->GetParameter("NumberOfThreads").c_str()), nullptr)) return AnnIndex(0); return AnnIndex(vecIndex); diff --git a/bdev.json b/bdev.json new file mode 100644 index 000000000..5d970de6d --- /dev/null +++ b/bdev.json @@ -0,0 +1,17 @@ +{ + 
"subsystems": [ + { + "subsystem": "bdev", + "config": [ + { + "method": "bdev_nvme_attach_controller", + "params": { + "trtype": "pcie", + "name": "Nvme1", + "traddr": "0000:5c:00.0" + } + } + ] + } + ] +} diff --git a/build_spresh.sh b/build_spresh.sh new file mode 100755 index 000000000..72e087bf1 --- /dev/null +++ b/build_spresh.sh @@ -0,0 +1,468 @@ +#!/bin/bash -x + +if [ "$1" == "init_env" ]; then +sudo apt install -y cmake gcc-9 g++-9 libjemalloc-dev libsnappy-dev libgflags-dev pkg-config swig libboost-all-dev libtbb-dev libisal-dev gnuplot +git clone https://github.com/SPFresh/SPFresh +cd SPFresh +git submodule update --init --recursive +cd ThirdParty/spdk +sudo bash ./scripts/pkgdep.sh +CC=gcc-9 ./configure +CC=gcc-9 make -j +cd ../isal-l_crypto +./autogen.sh +./configure +make -j +cd .. +git clone https://github.com/PtilopsisL/rocksdb +cd rocksdb +mkdir build && cd build +cmake -DUSE_RTTI=1 -DWITH_JEMALLOC=1 -DWITH_SNAPPY=1 -DCMAKE_C_COMPILER=gcc-9 -DCMAKE_CXX_COMPILER=g++-9 -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-fPIC" .. +make -j +sudo make install +cd ../../.. +mkdir build && cd build +cmake -DCMAKE_BUILD_TYPE=Release -DGPU=OFF .. +# add #include in SPFresh/AnnService/inc/Helper/ConcurrentSet.h and SPFresh/AnnService/inc/Helper/Logging.h +make -j +cd .. + +cp Script_AE/bdev.json . 
+sudo nvme format /dev/nvme0n1 +sudo ./ThirdParty/spdk/scripts/setup.sh +# it will print out: +#1462:00:00.0 (1414 b111): nvme -> uio_pci_generic +#01ca:00:00.0 (1414 b111): nvme -> uio_pci_generic +# fill the 1462:00:00.0 into bdev.json and use the traddr in the PCI_ALLOWED=1462.00.00.0 in the run_update commend +else +echo "go to SPFresh Directory" +#cd SPFresh +fi + +cd Release +dataset="sift" +datatype='UInt8' +dim=128 +basefile="base.1B.u8bin" +queryfile="query.public.10K.u8bin" + +storage=FILEIO +ssdpath=/mnt_ssd/data2/cheqi/tmp/ +checkpointpath=/mnt_ssd/data1/cheqi/tmp/ +testscale="1m" +updateto="2m" +testscale_number=1000000 +updateto_number=2000000 +query_number=10000 +batch_size=10000 + +if [ "$1" == "create_dataset" ]; then +mkdir -p ${dataset}1b +cd ${dataset}1b +if [ "$dataset" == "sift" ]; then + echo "begin download $dataset..." + + if [ ! -f "$basefile" ]; then + wget https://dl.fbaipublicfiles.com/billion-scale-ann-benchmarks/bigann/$basefile + fi + if [ ! -f "$queryfile" ]; then + wget https://dl.fbaipublicfiles.com/billion-scale-ann-benchmarks/bigann/$queryfile + fi +else + #TODO: download spacev dataset + echo "not support $dataset..." 
+fi + +pip3 install numpy +echo 'import numpy as np +import argparse +import struct + +def process_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--src", help="The input file (.fvecs)") + parser.add_argument("--dst", help="The output file (.fvecs)") + parser.add_argument("--topk", type=int, help="The number of element to pick up") + return parser.parse_args() + + +if __name__ == "__main__": + args = process_args() + + # Read topk vector one by one + vecs = "" + row_bin = ""; + dim_bin = ""; + with open(args.src, "rb") as f: + + row_bin = f.read(4) + assert row_bin != b"" + row, = struct.unpack("i", row_bin) + + dim_bin = f.read(4) + assert dim_bin != b"" + dim, = struct.unpack("i", dim_bin) + + vecs = f.read(args.topk * dim) + + with open(args.dst, "wb") as f: + f.write(struct.pack("i", args.topk)) + f.write(dim_bin) + f.write(vecs) +' > generate_dataset.py +echo "python3 generate_dataset.py --src $basefile --dst $dataset.$testscale.bin --topk $testscale_number" +python3 generate_dataset.py --src $basefile --dst $dataset.$testscale.bin --topk $testscale_number +echo "python3 generate_dataset.py --src $basefile --dst $dataset.$updateto.bin --topk $updateto_number" +python3 generate_dataset.py --src $basefile --dst $dataset.$updateto.bin --topk $updateto_number + +toolpath=.. 
+setname="6c VectorPath=${dataset}${testscale}_update_set" +truthname="18c TruthPath=${dataset}${testscale}_update_truth" +deletesetname="${dataset}${testscale}_update_set" +reservesetname="${dataset}${testscale}_update_reserve" +currentsetname="${dataset}${testscale}_update_current" +echo "[Base] +ValueType=$datatype +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=$dim +VectorPath=$dataset.$testscale.bin +VectorType=DEFAULT +VectorSize=$testscale_number +VectorDelimiter= +QueryPath=$queryfile +QueryType=DEFAULT +QuerySize=$query_number +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=$query_number +WarmupDelimiter= +TruthPath=${dataset}${testscale}_truth +TruthType=DEFAULT +GenerateTruth=true + +[SearchSSDIndex] +ResultNum=100 +NumberOfThreads=16 +" > genTruth.ini +$toolpath/ssdserving genTruth.ini + +echo "[Base] +ValueType=$datatype +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=$dim +VectorPath=${dataset}${testscale}_update_set88 +VectorType=DEFAULT +VectorSize=$testscale_number +VectorDelimiter= +QueryPath=$queryfile +QueryType=DEFAULT +QuerySize=$query_number +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=$query_number +WarmupDelimiter= +TruthPath=${dataset}${testscale}_update_truth88 +TruthType=DEFAULT +GenerateTruth=true + +[SearchSSDIndex] +ResultNum=100 +NumberOfThreads=16" > genTruth.ini +for i in {0..99} +do + echo "start batch $i..." 
+ $toolpath/usefultool --GenTrace true --vectortype $datatype --VectorPath $dataset.$updateto.bin --filetype DEFAULT --UpdateSize $batch_size --BaseNum $testscale_number --ReserveNum $testscale_number --CurrentListFileName ${dataset}${testscale}_update_current --ReserveListFileName ${dataset}${testscale}_update_reserve --TraceFileName ${dataset}${testscale}_update_trace --NewDataSetFileName ${dataset}${testscale}_update_set -d $dim --Batch $i -f DEFAULT + newsetname=$setname$i + newtruthname=$truthname$i + newdeletesetname=$deletesetname$i + newreservesetname=$reservesetname$i + newcurrentsetname=$currentsetname$i + sed -i "$newsetname" genTruth.ini + sed -i "$newtruthname" genTruth.ini + $toolpath/ssdserving genTruth.ini + $toolpath/usefultool --ConvertTruth true --vectortype $datatype --VectorPath $dataset.$updateto.bin --filetype DEFAULT --UpdateSize $batch_size --BaseNum $testscale_number --ReserveNum $testscale_number --CurrentListFileName ${dataset}${testscale}_update_current --ReserveListFileName ${dataset}${testscale}_update_reserve --TraceFileName ${dataset}${testscale}_update_trace --NewDataSetFileName ${dataset}${testscale}_update_set -d $dim --Batch $i -f DEFAULT --truthPath ${dataset}${testscale}_update_truth --truthType DEFAULT --querySize $query_number --resultNum 100 +done +cd .. 
+fi + +if [ "$1" == "build_index" ]; then +sudo rm -rf $ssdpath/* + +echo "[Base] +ValueType=$datatype +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=$dim +VectorPath=${dataset}1b/$dataset.$testscale.bin +VectorType=DEFAULT +VectorSize=$testscale_number +VectorDelimiter= +QueryPath=${dataset}1b/$queryfile +QueryType=DEFAULT +QuerySize=$query_number +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=$query_number +WarmupDelimiter= +TruthPath=${dataset}1b/${dataset}${testscale}_truth +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_$datatype\_L2_base_DEFUALT.bin +HeadVectors=head_vectors_$datatype\_L2_base_DEFUALT.bin +IndexDirectory=store_${dataset}${testscale}/ +HeadIndexFolder=head_index + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=16 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.1 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +NumberOfThreads=16 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=true +InternalResultNum=64 +NumberOfThreads=16 +ReplicaCount=8 +PostingPageLimit=4 +OutputEmptyReplicaID=1 +TmpDir=store_${dataset}${testscale}/tmpdir +Storage=${storage} +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=false +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=2 +SearchTimes=1 +Update=true +SteadyState=true +Days=100 +InsertThreadNum=1 +AppendThreadNum=1 +ReassignThreadNum=0 +TruthFilePrefix=${dataset}1b/ +FullVectorPath=${dataset}1b/$dataset.$updateto.bin +DisableReassign=false +ReassignK=64 +LatencyLimit=50.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=true +MergeThreshold=10 +UpdateFilePrefix=${dataset}1b/${dataset}${testscale}_update_trace +DeleteQPS=800 +ShowUpdateProgress=false +Sampling=4 +BufferLength=6 +InPlace=true +LoadAllVectors=true 
+PersistentBufferPath=${checkpointpath}/bf +SsdInfoFile=${ssdpath}/postingSizeRecords +SpdkMappingPath=${ssdpath}/spdkmapping +EndVectorNum=2000000 + +[SearchSSDIndex] +isExecute=true +BuildSsdIndex=false +SearchThreadNum=2 +" > build_SPANN_store_${dataset}${testscale}.ini +./ssdserving build_SPANN_store_${dataset}${testscale}.ini +echo "[Index] +IndexAlgoType=SPANN +ValueType=$datatype + +[Base] +ValueType=$datatype +DistCalcMethod=L2 +IndexAlgoType=BKT +Dim=$dim +VectorPath=${dataset}1b/$dataset.$testscale.bin +VectorType=DEFAULT +VectorSize=$testscale_number +VectorDelimiter= +QueryPath=${dataset}1b/$queryfile +QueryType=DEFAULT +QuerySize=$query_number +QueryDelimiter= +WarmupPath= +WarmupType=DEFAULT +WarmupSize=$query_number +WarmupDelimiter= +TruthPath=${dataset}1b/ +TruthType=DEFAULT +GenerateTruth=false +HeadVectorIDs=head_vectors_ID_$datatype\_L2_base_DEFUALT.bin +HeadVectors=head_vectors_$datatype\_L2_base_DEFUALT.bin +IndexDirectory=store_${dataset}${testscale}/ +HeadIndexFolder=head_index + + +[SelectHead] +isExecute=false +TreeNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +SamplesNumber=1000 +NumberOfThreads=16 +SaveBKT=false +AnalyzeOnly=false +CalcStd=true +SelectDynamically=true +NoOutput=false +SelectThreshold=12 +SplitFactor=9 +SplitThreshold=18 +Ratio=0.15 +RecursiveCheckSmallCluster=true +PrintSizeCount=true + +[BuildHead] +isExecute=false +TreeFilePath=tree.bin +GraphFilePath=graph.bin +VectorFilePath=vectors.bin +DeleteVectorFilePath=deletes.bin +EnableBfs=0 +BKTNumber=1 +BKTKmeansK=32 +BKTLeafSize=8 +Samples=1000 +BKTLambdaFactor=100.000000 +TPTNumber=32 +TPTLeafSize=2000 +NumTopDimensionTpTreeSplit=5 +NeighborhoodSize=32 +GraphNeighborhoodScale=2.000000 +GraphCEFScale=2.000000 +RefineIterations=2 +EnableRebuild=0 +CEF=1000 +AddCEF=500 +MaxCheckForRefineGraph=8192 +RNGFactor=1.000000 +GPUGraphType=2 +GPURefineSteps=0 +GPURefineDepth=30 +GPULeafSize=500 +HeadNumGPUs=1 +TPTBalanceFactor=2 +NumberOfThreads=16 +DistCalcMethod=L2 
+DeletePercentageForRefine=0.400000 +AddCountForRebuild=1000 +MaxCheck=4096 +ThresholdOfNumberOfContinuousNoBetterPropagation=3 +NumberOfInitialDynamicPivots=50 +NumberOfOtherDynamicPivots=4 +HashTableExponent=2 +DataBlockSize=1048576 +DataCapacity=2147483647 +MetaRecordSize=10 + +[BuildSSDIndex] +isExecute=true +BuildSsdIndex=false +NumberOfThreads=16 +InternalResultNum=64 +ReplicaCount=8 +PostingPageLimit=4 +OutputEmptyReplicaID=1 +TmpDir=store_${dataset}${testscale}/tmpdir +Storage=FILEIO +SpdkBatchSize=64 +ExcludeHead=false +UseDirectIO=false +ResultNum=10 +SearchInternalResultNum=64 +SearchThreadNum=2 +SearchTimes=1 +Update=true +SteadyState=true +Days=100 +InsertThreadNum=1 +AppendThreadNum=1 +ReassignThreadNum=0 +TruthFilePrefix=${dataset}1b/${dataset}${testscale}_update_truth_after +FullVectorPath=${dataset}1b/$dataset.$updateto.bin +DisableReassign=false +ReassignK=64 +LatencyLimit=50.0 +CalTruth=true +SearchPostingPageLimit=4 +MaxDistRatio=1000000 +SearchDuringUpdate=true +MergeThreshold=10 +UpdateFilePrefix=${dataset}1b/${dataset}${testscale}_update_trace +DeleteQPS=800 +ShowUpdateProgress=false +Sampling=4 +BufferLength=6 +InPlace=true +LoadAllVectors=true +PersistentBufferPath=${checkpointpath}/bf +SsdInfoFile=${ssdpath}/postingSizeRecords +SpdkMappingPath=${ssdpath}/spdkmapping +SearchResult=${dataset}1b/result_spfresh_balance +EndVectorNum=2000000" > store_${dataset}${testscale}/indexloader.ini +fi + +if [ "$1" == "run_update" ]; then +rm -rf ${dataset}1b/result_spfresh_balance* +#SPDK version +if [ "$storage" == "SPDKIO" ]; then + echo "Run SPDKIO..." + PCI_ALLOWED="1462:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=../bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 sudo -E ./spfresh store_${dataset}${testscale} |tee log_spfresh.log +else + echo "RUN FILEIO..." 
+ PCI_ALLOWED="1462:00:00.0" SPFRESH_SPDK_USE_SSD_IMPL=1 SPFRESH_SPDK_CONF=../bdev.json SPFRESH_SPDK_BDEV=Nvme0n1 SPFRESH_FILE_IO_USE_CACHE=False SPFRESH_FILE_IO_THREAD_NUM=16 SPFRESH_FILE_IO_USE_LOCK=False SPFRESH_FILE_IO_LOCK_SIZE=262144 sudo -E ./spfresh store_${dataset}${testscale} |tee log_spfresh.log +fi +fi + +if [ "$1" == "plot_result" ]; then +cp ../Script_AE/Figure6/process_spfresh.py . +python3 process_spfresh.py log_spfresh.log overall_performance_${dataset}_spfresh_result.csv + +mkdir -p spfresh_result +cp -rf ${dataset}1b/result_spfresh_balance* spfresh_result + +resultnamePrefix=/spfresh_result/ +i=-1 +for FILE in `ls -v1 ./spfresh_result/` +do + if [ $i -eq -1 ]; + then + ./usefultool --CallRecall true --resultNum 10 --queryPath ${dataset}1b/$queryfile --searchResult $PWD$resultnamePrefix$FILE --truthType DEFAULT --truthPath ${dataset}1b/${dataset}${testscale}_truth --VectorPath ${dataset}1b/$dataset.$updateto.bin --vectortype $datatype -d $dim -f DEFAULT |tee log_spfresh_$i + else + ./usefultool --CallRecall true --resultNum 10 --queryPath ${dataset}1b/$queryfile --searchResult $PWD$resultnamePrefix$FILE --truthType DEFAULT --truthPath ${dataset}1b/${dataset}${testscale}_update_truth_after$i --VectorPath ${dataset}1b/$dataset.$updateto.bin --vectortype $datatype -d $dim -f DEFAULT |tee log_spfresh_$i + fi + let "i=i+1" +done +cp ../Script_AE/Figure6/OverallPerformance_merge_result.py . +cp ../Script_AE/Figure6/overall_performance_spacev_new.p . 
+python3 OverallPerformance_merge_result.py log_spfresh_ log_spfresh_ log_spfresh_ overall_performance_${dataset}_spfresh_result.csv overall_performance_${dataset}_spfresh_result.csv overall_performance_${dataset}_spfresh_result.csv +gnuplot overall_performance_spacev_new.p +fi diff --git a/docs/Tutorial.ipynb b/docs/Tutorial.ipynb index 3c423ab71..0cc1c0927 100644 --- a/docs/Tutorial.ipynb +++ b/docs/Tutorial.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -61,7 +61,7 @@ " 'vectors.bin']" ] }, - "execution_count": 25, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -87,16 +87,17 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[458, 474, 120, 431, 383, 686, 927, 522, 795, 167]\n", - "[10.889548301696777, 11.595512390136719, 11.68349838256836, 12.185091972351074, 12.233480453491211, 12.484166145324707, 12.542947769165039, 12.677164077758789, 12.810869216918945, 12.867197036743164]\n", - "[b'458\\n', b'474\\n', b'120\\n', b'431\\n', b'383\\n', b'686\\n', b'927\\n', b'522\\n', b'795\\n', b'167\\n']\n" + "[450, 247, 150, 832, 626, 828, 253, 471, 243, 365]\n", + "[11.603649139404297, 12.164414405822754, 12.225785255432129, 12.336504936218262, 12.416375160217285, 12.481977462768555, 12.573633193969727, 12.616722106933594, 12.708147048950195, 12.821012496948242]\n", + "[b'450\\n', b'247\\n', b'150\\n', b'832\\n', b'626\\n', b'828\\n', b'253\\n', b'471\\n', b'243\\n', b'365\\n']\n", + "[(450, 11.603649139404297, b'450\\n'), (247, 12.164414405822754, b'247\\n'), (150, 12.225785255432129, b'150\\n'), (832, 12.336504936218262, b'832\\n'), (626, 12.416375160217285, b'626\\n'), (828, 12.481977462768555, b'828\\n'), 
(253, 12.573633193969727, b'253\\n'), (471, 12.616722106933594, b'471\\n'), (243, 12.708147048950195, b'243\\n'), (365, 12.821012496948242, b'365\\n')]\n" ] } ], @@ -110,7 +111,23 @@ "result = index.SearchWithMetaData(q, 10) # Search k=3 nearest vectors for query vector q\n", "print (result[0]) # nearest k vector ids\n", "print (result[1]) # nearest k vector distances\n", - "print (result[2]) # nearest k vector metadatas" + "print (result[2]) # nearest k vector metadatas\n", + "\n", + "results = []\n", + "iterator = index.GetIterator(q)\n", + "while True:\n", + " result = iterator.Next(10)\n", + " #print (result[0]) # vector ids of the batch index scan candidates\n", + " #print (result[1]) # vector distances of the batch index scan candidates\n", + " #print (result[2]) # vector metadatas of the batch index scan candidates\n", + " #print (result[3]) # relaxed monotonicity of the batch index scan candidates\n", + " for j in range(len(result[0])):\n", + " results.append((result[0][j], result[1][j], result[2][j]))\n", + " if len(result[0]) < 10 or result[3][0] == True: break\n", + "iterator.Close()\n", + "\n", + "results.sort(key = lambda item: item[1])\n", + "print (results[0:10])\n" ] }, { @@ -122,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -160,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -176,7 +193,7 @@ " 'vectors.bin']" ] }, - "execution_count": 28, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -202,16 +219,16 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[458, 120, 474, 431, 383, 927, 686, 795, 522, 322, 167, 69]\n", - "[10.78742790222168, 11.62553882598877, 11.796710014343262, 12.205741882324219, 12.333723068237305, 12.543768882751465, 12.558581352233887, 
12.744610786437988, 12.799510955810547, 12.994494438171387, 13.008194923400879, 13.038790702819824]\n", - "[b'458\\n', b'120\\n', b'474\\n', b'431\\n', b'383\\n', b'927\\n', b'686\\n', b'795\\n', b'522\\n', b'322\\n', b'167\\n', b'69\\n']\n" + "[450, 247, 832, 626, 150, 253, 471, 828, 243, 207, 365, 823]\n", + "[11.497368812561035, 12.21573257446289, 12.328060150146484, 12.369790077209473, 12.420766830444336, 12.47612476348877, 12.48221206665039, 12.526467323303223, 12.802266120910645, 12.819355010986328, 12.876933097839355, 12.913871765136719]\n", + "[b'450\\n', b'247\\n', b'832\\n', b'626\\n', b'150\\n', b'253\\n', b'471\\n', b'828\\n', b'243\\n', b'207\\n', b'365\\n', b'823\\n']\n" ] } ], @@ -227,7 +244,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -242,7 +259,7 @@ " 'SPTAGHeadVectors.bin']" ] }, - "execution_count": 30, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -284,16 +301,16 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[458, 474, 120, 431, 383, 686, 927, 522, 795, 167, 457, 322]\n", - "[10.889548301696777, 11.595512390136719, 11.68349838256836, 12.185091972351074, 12.233480453491211, 12.484166145324707, 12.542947769165039, 12.677164077758789, 12.810869216918945, 12.867197036743164, 12.98002815246582, 13.027528762817383]\n", - "[b'458\\n', b'474\\n', b'120\\n', b'431\\n', b'383\\n', b'686\\n', b'927\\n', b'522\\n', b'795\\n', b'167\\n', b'457\\n', b'322\\n']\n" + "[450, 247, 150, 832, 626, 828, 253, 471, 243, 365, 207, 823]\n", + "[11.603649139404297, 12.164414405822754, 12.225785255432129, 12.336504936218262, 12.416375160217285, 12.481977462768555, 12.573633193969727, 12.616722106933594, 12.708147048950195, 12.821012496948242, 12.878414154052734, 13.01927375793457]\n", + "[b'450\\n', b'247\\n', b'150\\n', b'832\\n', b'626\\n', b'828\\n', 
b'253\\n', b'471\\n', b'243\\n', b'365\\n', b'207\\n', b'823\\n']\n" ] } ], @@ -307,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -323,7 +340,7 @@ " 'SPTAGHeadVectors.bin']" ] }, - "execution_count": 32, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -366,16 +383,16 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[458, 120, 474, 431, 383, 927, 686, 795, 522, 322, 167, 69]\n", - "[10.78742790222168, 11.62553882598877, 11.796710014343262, 12.205741882324219, 12.333723068237305, 12.543768882751465, 12.558581352233887, 12.744610786437988, 12.799510955810547, 12.994494438171387, 13.008194923400879, 13.038790702819824]\n", - "[b'458\\n', b'120\\n', b'474\\n', b'431\\n', b'383\\n', b'927\\n', b'686\\n', b'795\\n', b'522\\n', b'322\\n', b'167\\n', b'69\\n']\n" + "[450, 247, 832, 626, 150, 253, 471, 828, 243, 207, 365, 823]\n", + "[11.497368812561035, 12.21573257446289, 12.328060150146484, 12.369790077209473, 12.420766830444336, 12.47612476348877, 12.48221206665039, 12.526467323303223, 12.802266120910645, 12.819355010986328, 12.876933097839355, 12.913871765136719]\n", + "[b'450\\n', b'247\\n', b'832\\n', b'626\\n', b'150\\n', b'253\\n', b'471\\n', b'828\\n', b'243\\n', b'207\\n', b'365\\n', b'823\\n']\n" ] } ], @@ -510,7 +527,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.9.3" }, "mimetype": "text/x-python", "name": "python", diff --git a/docs/WindowsInstallation.md b/docs/WindowsInstallation.md index f5775eb16..70c276e60 100644 --- a/docs/WindowsInstallation.md +++ b/docs/WindowsInstallation.md @@ -22,6 +22,10 @@ Boost's precompiled binaries are available at https://sourceforge.net/projects/b 2. Launch the installation 3. 
Add the folder path to PATH environment variable, for instance `C:\Sptag\boost_1_67_0\` +On the off chance of needing to build boost, because you encountered this message during the Build step `Could NOT find Boost (missing: system thread serialization wserialization regex filesystem)`, try the following +1. Locate and run `bootstrap.bat`, this creates a file called `b2.exe` +2. Run the generated exe (in this case `b2.exe`), be patient this may take +1 hr +3. Once complete, the built libraries will be located in `\stage\lib` (or something along the lines) ## Build diff --git a/docs/examples/requirements.txt b/docs/examples/requirements.txt index f333acfe6..2aef7a37e 100644 --- a/docs/examples/requirements.txt +++ b/docs/examples/requirements.txt @@ -1,5 +1,5 @@ -numpy==1.21.0 +numpy==1.24.2 matplotlib==2.2.2 -Keras==2.2.4 -Pillow==9.0.1 -scikit_learn==0.21.3 +Keras==2.11.0 +Pillow==10.0.1 +scikit_learn==0.24.2 diff --git a/setup.py b/setup.py index 69d435028..a601a46ad 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,11 @@ python_version = "%d.%d" % (sys.version_info.major, sys.version_info.minor) print ("Python version:%s" % python_version) +if 'bdist_wheel' in sys.argv: + if not any(arg.startswith('--python-tag') for arg in sys.argv): + sys.argv.extend(['--python-tag', 'py%d%d'%(sys.version_info.major, sys.version_info.minor)]) +print (sys.argv) + def _setup(): setuptools.setup( name = 'sptag', @@ -59,7 +64,7 @@ def _setup(): ], packages = _find_python_packages(), - python_requires = '>=3.7', + python_requires = '>=3.6', install_requires = ['numpy'], cmdclass = { @@ -77,12 +82,21 @@ def _find_python_packages(): elif os.path.exists(os.path.join('x64', 'Release')): shutil.copytree(os.path.join('x64', 'Release'), 'sptag') - if os.path.exists('lib'): shutil.rmtree('lib') - os.mkdir('lib') - os.mkdir(os.path.join('lib', 'net472')) + if not os.path.exists('lib'): os.mkdir('lib') + if not os.path.exists('lib\\net472'): os.mkdir('lib\\net472') for file in 
glob.glob(r'x64\\Release\\Microsoft.ANN.SPTAGManaged.*'): print (file) shutil.copy(file, "lib\\net472\\") + sfiles = '' + for framework in ['net5.0', 'net472']: + if os.path.exists("lib\\%s\\Microsoft.ANN.SPTAGManaged.dll" % framework): + sfiles += '' % (framework, framework) + sfiles += '' % (framework, framework) + if os.path.exists('x64\\Release\\libzstd.dll'): + sfiles += '' + if os.path.exists('x64\\Release\\Ijwhost.dll'): + sfiles += '' + sfiles += '' f = open('sptag.nuspec', 'w') spec = ''' @@ -99,13 +113,48 @@ def _find_python_packages(): Copyright @ Microsoft - - +%s -''' % (nuget_release) +''' % (nuget_release, sfiles) f.write(spec) f.close() + + fwinrt = open('sptag.winrt.nuspec', 'w') + spec = ''' + + + MSSPTAG.WinRT + %s + MSSPTAG.WinRT + cheqi,haidwa,mingqli + cheqi,haidwa,mingqli + false + https://github.com/microsoft/SPTAG + https://github.com/microsoft/SPTAG + SPTAG (Space Partition Tree And Graph) is a library for large scale vector approximate nearest neighbor search scenario released by Microsoft Research (MSR) and Microsoft Bing. + Copyright @ Microsoft + + + + + + + + + + + + + + + + + +''' % (nuget_release) + fwinrt.write(spec) + fwinrt.close() + f = open(os.path.join('sptag', '__init__.py'), 'w') f.close() return ['sptag'] @@ -136,4 +185,4 @@ def run(self): shutil.rmtree('dist', ignore_errors=True) if __name__ == '__main__': - _setup() \ No newline at end of file + _setup()