Add binary format support with IVF method in Faiss Engine #1784
base: feature/binary-format
floats.stream().map(ArrayUtils::toPrimitive).toArray(float[][]::new)
)
);
long memoryAddress = trainingDataAllocation.getMemoryAddress();
Please add a unit test to validate the code.
ack
src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java
src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java
@@ -248,6 +263,28 @@ public static ValidationException validateKnnField(
return exception;
}

// Return if vector data type does not need to be checked
if (expectedVectorDataType == null) {
Try to write a unit test. You may find an issue with the current implementation.
It is okay to return a single exception, but the issue here is that you are returning when there is no exception, so the rest of the validation code never runs.
Also, please check whether your code works when expectedDimension is < 0 or null. In that case, your validation won't be executed.
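The early-return concern raised in this comment can be illustrated with a minimal, self-contained sketch. The names used here (`FieldSpec`, `FieldValidationSketch.validate`) are hypothetical stand-ins and not the plugin's `validateKnnField` signature; the assumption is simply that a null expected value means "skip this check". The point is that an optional check is skipped with a guard rather than by leaving the method, so the checks after it still run.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a k-NN field mapping; not the plugin's real type.
final class FieldSpec {
    final String dataType;
    final int dimension;

    FieldSpec(String dataType, int dimension) {
        this.dataType = dataType;
        this.dimension = dimension;
    }
}

final class FieldValidationSketch {
    // Returns every validation error found; an empty list means the field is valid.
    // A null expectedDataType, or a null or negative expectedDimension, means
    // "do not check this property" (an assumption made for this sketch).
    static List<String> validate(FieldSpec field, String expectedDataType, Integer expectedDimension) {
        List<String> errors = new ArrayList<>();

        // Guard the optional check instead of returning early, so the
        // dimension check below still runs when the data type is not checked.
        if (expectedDataType != null && !expectedDataType.equals(field.dataType)) {
            errors.add("expected data type " + expectedDataType + " but found " + field.dataType);
        }

        // Same pattern: skipping the dimension check must not skip anything else.
        if (expectedDimension != null && expectedDimension >= 0 && expectedDimension != field.dimension) {
            errors.add("expected dimension " + expectedDimension + " but found " + field.dimension);
        }

        return errors;
    }
}
```

With this shape, a unit test can pass `null` for one expected value and still observe a failure reported by the other check, which is the case the comment asks to cover.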
Can we have unit tests for all the C++ functions written?
src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
long memoryAddress = trainingDataAllocation.getMemoryAddress();

if (IndexUtil.isBinaryIndex(this.vectorDataType)) {
    byte[][] byteArray = floats.stream()
Why do we need this conversion? If it's a binary index, training would also be on a byte array, right? Why would a customer provide a float array to train a binary index?
Refactored into FloatTrainingDataConsumer and ByteTrainingDataConsumer so that float and byte training data are consumed separately.
Will include JNI tests along with the refactor work.

private boolean isBinaryField(FieldInfo field) {
    if (field.attributes().containsKey(MODEL_ID)) {
        Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID));
Model can be quite big because it contains the binary blob. Can we read from model metadata instead via ModelDao?
Updated the code to use the model only once, in the addKNNBinaryField function.
}

@Override
public List<byte[]> getTotalAddedVectors() {
I see this method is called below. Where is the implementation?
This method is used in TestFloatTrainingDataConsumer:
// create test float training data consumer class extending FloatTrainingDataConsumer
private static class TestFloatTrainingDataConsumer extends FloatTrainingDataConsumer {
    @Getter
    private List<Float[]> totalAddedVectors = new ArrayList<>();

    public TestFloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) {
        super(trainingDataAllocation);
    }

    @Override
    public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
        List<Float[]> vectors = extractVectorsFromHits(searchResponse, vectorsToAdd, fieldName);
        totalAddedVectors.addAll(vectors);
        setTotalVectorsCountAdded(getTotalVectorsCountAdded() + vectors.size());
        accept(vectors);
    }
}
It is weird to have a method that is only used in testing... Can we do this differently?
I tried but couldn't find another way to read all of the vectors.
You don't need to. For a unit test, all you need to verify is that the JNILayer method is called properly with the expected parameters.
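A minimal sketch of that testing approach, written without any mocking library so it stays self-contained. Everything named here (`ByteVectorTransfer`, `RecordingTransfer`, `ByteConsumerSketch`) is hypothetical and only stands in for the JNI call and the consumer; in the plugin itself a mocking framework would likely play the role of the hand-rolled fake. The idea is that the test inspects what was passed to the native-transfer seam, so the consumer needs no test-only accessor such as getTotalAddedVectors.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical seam standing in for the JNI call that copies vectors into native memory.
interface ByteVectorTransfer {
    long transfer(long memoryAddress, byte[][] vectors);
}

// Hand-rolled fake that records each call instead of touching native code.
final class RecordingTransfer implements ByteVectorTransfer {
    final List<Long> addresses = new ArrayList<>();
    final List<byte[][]> batches = new ArrayList<>();

    @Override
    public long transfer(long memoryAddress, byte[][] vectors) {
        addresses.add(memoryAddress);
        batches.add(vectors);
        return memoryAddress; // pretend the allocation stays at the same address
    }
}

// Simplified consumer: forwards each batch to the seam and keeps a running count.
final class ByteConsumerSketch {
    private final ByteVectorTransfer transfer;
    private long memoryAddress;
    private int totalVectorsAdded;

    ByteConsumerSketch(ByteVectorTransfer transfer, long initialAddress) {
        this.transfer = transfer;
        this.memoryAddress = initialAddress;
    }

    void processTrainingVectors(List<byte[]> vectors) {
        // Convert the batch to a 2-D array, hand it to the seam, and remember the new address.
        memoryAddress = transfer.transfer(memoryAddress, vectors.toArray(new byte[0][]));
        totalVectorsAdded += vectors.size();
    }

    int getTotalVectorsAdded() {
        return totalVectorsAdded;
    }
}
```

A test then builds the consumer with a `RecordingTransfer`, calls `processTrainingVectors` once, and asserts on `addresses` and `batches` rather than on state kept inside the consumer.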
updated
}

@Override
public void accept(List<byte[]> byteVectors) {
Why public?
accept previously came from the Consumer interface, whose methods are public. The parent class no longer implements Consumer, and accept is now protected.
It can be even private. I don't see it being used outside of this class.
The accept method is used in some tests, but I removed the other methods.
}

@Override
public List<byte[]> extractVectorsFromHits(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
Why public?
Updated
}

@Override
public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
I think only this method needs to be public?
updated
*/
- public class TrainingDataConsumer implements Consumer<List<Float[]>> {
+ public abstract class TrainingDataConsumer<T> implements Consumer<List<T>> {
Why does it have to implement Consumer?
I think the accept method is called only inside processTrainingVectors.
Updated so that it no longer implements Consumer.
*/
- public class TrainingDataConsumer implements Consumer<List<Float[]>> {
+ public abstract class TrainingDataConsumer<T> {
I feel we don't need T here.
I see that only public abstract void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName); is actually used.
updated
floats.stream().map(ArrayUtils::toPrimitive).toArray(float[][]::new)
)
);
protected abstract void accept(List<T> vectors);
We can remove accept, extractVectorsFromHits, and getTotalAddedVectors, which are not called outside of this class.
Removed extractVectorsFromHits and getTotalAddedVectors, and kept accept so it stays exposed to some tests.
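The shape this thread converges on (an abstract base with no type parameter, a single public processTrainingVectors, and separate float and byte subclasses) might look roughly like the sketch below. It is an illustration under stated assumptions, not the merged code: the class names carry a `Sketch` suffix, plain lists of numbers stand in for `SearchResponse` hits so it compiles without OpenSearch on the classpath, the `accepted` lists stand in for the native-memory transfer, and the `forDataType` factory is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Simplified base class: no type parameter and one public entry point.
// The real method takes (SearchResponse, int, String); lists of numbers stand in for hits here.
abstract class TrainingConsumerSketch {
    private int totalVectorsCountAdded;

    public abstract void processTrainingVectors(List<? extends List<? extends Number>> hits, int vectorsToAdd);

    public int getTotalVectorsCountAdded() {
        return totalVectorsCountAdded;
    }

    protected void addToTotal(int count) {
        totalVectorsCountAdded += count;
    }

    // Clamp the requested count to what is actually available.
    protected static int limit(List<?> hits, int vectorsToAdd) {
        return Math.max(0, Math.min(vectorsToAdd, hits.size()));
    }

    // Hypothetical factory: picks the consumer from the data type name.
    static TrainingConsumerSketch forDataType(String dataType) {
        return "binary".equals(dataType) ? new ByteTrainingSketch() : new FloatTrainingSketch();
    }
}

// Float path: each hit becomes a float[].
final class FloatTrainingSketch extends TrainingConsumerSketch {
    final List<float[]> accepted = new ArrayList<>();

    @Override
    public void processTrainingVectors(List<? extends List<? extends Number>> hits, int vectorsToAdd) {
        int limit = limit(hits, vectorsToAdd);
        for (List<? extends Number> hit : hits.subList(0, limit)) {
            float[] vector = new float[hit.size()];
            for (int i = 0; i < vector.length; i++) {
                vector[i] = hit.get(i).floatValue();
            }
            accepted.add(vector);
        }
        addToTotal(limit);
    }
}

// Byte path: each hit becomes a byte[], with no detour through floats.
final class ByteTrainingSketch extends TrainingConsumerSketch {
    final List<byte[]> accepted = new ArrayList<>();

    @Override
    public void processTrainingVectors(List<? extends List<? extends Number>> hits, int vectorsToAdd) {
        int limit = limit(hits, vectorsToAdd);
        for (List<? extends Number> hit : hits.subList(0, limit)) {
            byte[] vector = new byte[hit.size()];
            for (int i = 0; i < vector.length; i++) {
                vector[i] = hit.get(i).byteValue();
            }
            accepted.add(vector);
        }
        addToTotal(limit);
    }
}
```

Callers only ever see `processTrainingVectors` and the running count, which is why the generic parameter and the `Consumer` interface become unnecessary on the base class.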
Signed-off-by: Junqiu Lei <[email protected]>
Resolved the comments from PR heemin32#2, which made it easier to review the file diffs while #1781 was not yet merged. Now that #1781 is merged, I rebased junqiu-lei:binary-ivf against opensearch-project:feature/binary-format.
Description
This PR adds support for the binary format with the Faiss IVF method. Its main changes:
data_type field when training a model
JNI-layer refactoring work will be completed in another PR, tracked by #1846.
Example workflow
1. Create binary format train index
2. Ingest train index
3. Create train model
4. Create IVF binary format target index
5. Bulk target index
6. Query target index
7. Query result
Issues Resolved
part of #1767
Check List
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.