Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add binary format support with IVF method in Faiss Engine #1784

Open
wants to merge 1 commit into
base: feature/binary-format
Choose a base branch
from

Conversation

junqiu-lei
Copy link
Member

@junqiu-lei junqiu-lei commented Jul 2, 2024

Resolved comments in PR heemin32#2, which was friendly to check the file diffs when #1781 wasn't merged. Because #1781 now is merged, I rebased junqiu-lei:binary-ivf against opensearch-project:feature/binary-format

Description

This PR will support using binary format with Faiss IVF method, it mainly have changes:

  • introduce data_type field when train model
  • support train model with binary format index
  • create binary format IVF target index with with binary format model
  • query IVF index with binary format

JNI layer related refactor works will be complete in another PR tracked by #1846

Example workflow

1. Create binary format train index

PUT /train-index HTTP/1.1
Host: localhost:9200
Content-Type: application/json
Content-Length: 481

{
  "settings": {
    "number_of_shards": 3,
    "number_of_replicas": 0,
    "index.knn": true
  },
  "mappings": {
    "properties": {
      "train-field": {
        "type": "knn_vector",
        "dimension": 8,
        "data_type": "binary",
        "method": {
          "name": "hnsw",
          "space_type": "hammingbit",
          "engine": "faiss",
          "parameters": {
            "ef_construction": 128,
            "m": 24
          }
        }
      }
    }
  }
}

2. Ingest tran index

POST /_bulk HTTP/1.1
Host: localhost:9200
Content-Type: application/json
Content-Length: 7067

{ "index": { "_index": "train-index", "_id": "1" } }
{ "train-field": [1]}
{ "index": { "_index": "train-index", "_id": "2" } }
{ "train-field": [2]}
{ "index": { "_index": "train-index", "_id": "3" } }
{ "train-field": [4]}
{ "index": { "_index": "train-index", "_id": "4" } }
........

3. Create train model

POST /_plugins/_knn/models/my-model/_train HTTP/1.1
Host: localhost:9200
Content-Type: application/json
Content-Length: 313

{
  "training_index": "train-index",
  "training_field": "train-field",
  "dimension": 8,
  "description": "My model description",
  "data_type": "binary",
  "method": {
    "name": "ivf",
    "engine": "faiss",
    "space_type": "hammingbit",
    "parameters": {
      "nlist": 4,
      "nprobes":2 
    }
  }
}

4. Create IVF binary format target index

PUT /target-index HTTP/1.1
Host: localhost:9200
Content-Type: application/json
Content-Length: 242

{
  "settings": {
    "number_of_shards": 1,
    "number_of_replicas": 1,
    "index.knn": true
  },
  "mappings": {
    "properties": {
      "target-field": {
        "type": "knn_vector",
        "model_id": "my-model"
      }
    }
  }
}

5. Bulk target index

POST /_bulk HTTP/1.1
Host: localhost:9200
Content-Type: application/json
Content-Length: 931

{ "index": { "_index": "target-index", "_id": "1" } }
{ "target-field": [2]}
{ "index": { "_index": "target-index", "_id": "2" } }
{ "target-field": [3]}
.......

6. Query target index

GET /target-index/_search HTTP/1.1
Host: localhost:9200
Content-Type: application/json
Content-Length: 110

{
  "query": {
    "knn": {
      "target-field": {
        "vector": [10],
        "k": 5
      }
    }
  }
}

7. Query result

{
    "took": 2898,
    "timed_out": false,
    "_shards": {
        "total": 1,
        "successful": 1,
        "skipped": 0,
        "failed": 0
    },
    "hits": {
        "total": {
            "value": 5,
            "relation": "eq"
        },
        "max_score": 1.0,
        "hits": [
            {
                "_index": "target-index",
                "_id": "9",
                "_score": 1.0,
                "_source": {
                    "target-field": [
                        10
                    ]
                }
            },
            {
                "_index": "target-index",
                "_id": "1",
                "_score": 0.5,
                "_source": {
                    "target-field": [
                        2
                    ]
                }
            },
            {
                "_index": "target-index",
                "_id": "7",
                "_score": 0.5,
                "_source": {
                    "target-field": [
                        8
                    ]
                }
            },
            {
                "_index": "target-index",
                "_id": "10",
                "_score": 0.5,
                "_source": {
                    "target-field": [
                        11
                    ]
                }
            },
            {
                "_index": "target-index",
                "_id": "2",
                "_score": 0.33333334,
                "_source": {
                    "target-field": [
                        3
                    ]
                }
            }
        ]
    }
}

Issues Resolved

part of #1767

Check List

  • New functionality includes testing.
    • All tests pass
  • New functionality has been documented.
    • New functionality has javadoc added
  • Commits are signed as per the DCO using --signoff

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.

@junqiu-lei junqiu-lei added Features Introduces a new unit of functionality that satisfies a requirement feature branch labels Jul 2, 2024
@heemin32 heemin32 force-pushed the feature/binary-format branch 8 times, most recently from dafd79b to a913082 Compare July 3, 2024 16:40
@junqiu-lei junqiu-lei force-pushed the binary-ivf branch 2 times, most recently from 22d044a to 629757b Compare July 9, 2024 00:07
@junqiu-lei junqiu-lei changed the title Add binary format support for Faiss IVF train model and create index api Add binary format support for Faiss IVF Jul 9, 2024
@junqiu-lei junqiu-lei force-pushed the binary-ivf branch 2 times, most recently from f726d81 to 648f342 Compare July 9, 2024 00:49
@junqiu-lei junqiu-lei changed the title Add binary format support for Faiss IVF Add binary format support with IVF method in Faiss Engine Jul 9, 2024
@junqiu-lei junqiu-lei marked this pull request as ready for review July 10, 2024 17:31
@junqiu-lei junqiu-lei self-assigned this Jul 17, 2024
@junqiu-lei junqiu-lei requested a review from heemin32 July 17, 2024 15:59
floats.stream().map(ArrayUtils::toPrimitive).toArray(float[][]::new)
)
);
long memoryAddress = trainingDataAllocation.getMemoryAddress();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a unit test to validate the code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

@@ -248,6 +263,28 @@ public static ValidationException validateKnnField(
return exception;
}

// Return if vector data type does not need to be checked
if (expectedVectorDataType == null) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to write a unit test. You may find an issue with current implementation.
It is okay to return single exception but the issue here is you are returning when there is no exception. Then the rest of the validation code won't work.
Also, please check if your code works when expectedDimension is < 0 or null. In such case, your validation won't get executed.

Copy link

@Vikasht34 Vikasht34 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have Unit test for all c++ functions written.

jni/src/faiss_wrapper.cpp Show resolved Hide resolved
jni/src/faiss_wrapper.cpp Show resolved Hide resolved
jni/src/faiss_wrapper.cpp Show resolved Hide resolved
jni/src/faiss_wrapper.cpp Show resolved Hide resolved
src/main/java/org/opensearch/knn/index/IndexUtil.java Outdated Show resolved Hide resolved
src/main/java/org/opensearch/knn/index/IndexUtil.java Outdated Show resolved Hide resolved
long memoryAddress = trainingDataAllocation.getMemoryAddress();

if (IndexUtil.isBinaryIndex(this.vectorDataType)) {
byte[][] byteArray = floats.stream()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need to do this conversion , if it's Binary Index then training would be also on Byte array right , why customer will provide Float array to train Binary Index?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to have FloatTrainingDataConsumer and ByteTrainingDataConsumer to handle train data consuming separately.

@junqiu-lei
Copy link
Member Author

junqiu-lei commented Jul 18, 2024

Can we have Unit test for all c++ functions written.

will include jni tests along with refactor work


private boolean isBinaryField(FieldInfo field) {
if (field.attributes().containsKey(MODEL_ID)) {
Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model can be quite big because it contains the binary blob. Can we read from model metadata instead via ModelDao?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated code to use model only once in the addKNNBinaryField function

}

@Override
public List<byte[]> getTotalAddedVectors() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this method is called below. Where the implementation is?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method is used in TestFloatTrainingDataConsumer

    // create test float training data consumer class extending FloatTrainingDataConsumer
    private static class TestFloatTrainingDataConsumer extends FloatTrainingDataConsumer {
        @Getter
        private List<Float[]> totalAddedVectors = new ArrayList<>();

        public TestFloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) {
            super(trainingDataAllocation);
        }

        @Override
        public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
            List<Float[]> vectors = extractVectorsFromHits(searchResponse, vectorsToAdd, fieldName);
            totalAddedVectors.addAll(vectors);
            setTotalVectorsCountAdded(getTotalVectorsCountAdded() + vectors.size());
            accept(vectors);
        }
    }

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is wired to have a method only to be used in testing... Can we do differently?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried but couldn't find other way to read the whole vectors

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to. For unit test, all you need to verify is that JNILayer method is called properly with expected parameter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

}

@Override
public void accept(List<byte[]> byteVectors) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why public?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previously implemented from Consumer, which are public, now removed parent implement from Consumer and changed it to protected

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be even private. I don't see it being used outside of this class.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accept method is used in some tests, but I removed other methods.

}

@Override
public List<byte[]> extractVectorsFromHits(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why public?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

}

@Override
public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think only this method is needs to be public?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

*/
public class TrainingDataConsumer implements Consumer<List<Float[]>> {
public abstract class TrainingDataConsumer<T> implements Consumer<List<T>> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it have to implement Consumer?
I think accept method is called only inside processTrainingVectors

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated without implement from Consumer

*/
public class TrainingDataConsumer implements Consumer<List<Float[]>> {
public abstract class TrainingDataConsumer<T> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we don't need T here.
I see only public abstract void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName); is actually used.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

floats.stream().map(ArrayUtils::toPrimitive).toArray(float[][]::new)
)
);
protected abstract void accept(List<T> vectors);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove accept, extractVectorsFromHits, getTotalAddedVectors which is not being called outside of this class.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed extractVectorsFromHits, getTotalAddedVectors and kept accept to expose for some tests

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature branch Features Introduces a new unit of functionality that satisfies a requirement v2.16.0
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants