Skip to content

Commit

Permalink
Refactored and Fixed Parallel Upsert in PostgreSQL Embeddings (#213)
Browse files Browse the repository at this point in the history
* Fixing Issues

* Refactoring Code

* Fixing Issues
  • Loading branch information
EmadHanif01 committed Sep 3, 2023
1 parent 1ce01cd commit 7b6daec
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,9 @@
@Service
public class BgeSmallClient {

private BgeSmallEndpoint endpoint;

private static volatile ZooModel<String, float[]> bgeSmallEn;

public BgeSmallEndpoint getEndpoint() {
return endpoint;
}

public void setEndpoint(BgeSmallEndpoint endpoint) {
this.endpoint = endpoint;
}

public EdgeChain<BgeSmallResponse> createEmbeddings(String input) {
public EdgeChain<BgeSmallResponse> createEmbeddings(String input, BgeSmallEndpoint endpoint) {

return new EdgeChain<>(
Observable.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,24 @@
@Service
public class MiniLMClient {

private MiniLMEndpoint endpoint;

private static volatile ZooModel<String, float[]> allMiniL6V2;
private static volatile ZooModel<String, float[]> allMiniL12V2;

private static volatile ZooModel<String, float[]> paraphraseMiniLML3v2;

private static volatile ZooModel<String, float[]> multiQAMiniLML6CosV1;

public MiniLMEndpoint getEndpoint() {
return endpoint;
}

public void setEndpoint(MiniLMEndpoint endpoint) {
this.endpoint = endpoint;
}

public EdgeChain<MiniLMResponse> createEmbeddings(String input) {
public EdgeChain<MiniLMResponse> createEmbeddings(String input, MiniLMEndpoint endpoint) {

return new EdgeChain<>(
Observable.create(
emitter -> {
try {

if (this.endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L6_V2)) {
if (endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L6_V2)) {

Predictor<String, float[]> predictor =
loadAllMiniL6V2(this.endpoint.getMiniLMModel()).newPredictor();
loadAllMiniL6V2(endpoint.getMiniLMModel()).newPredictor();

float[] predict = predictor.predict(input);

Expand All @@ -59,10 +49,10 @@ public EdgeChain<MiniLMResponse> createEmbeddings(String input) {

emitter.onNext(new MiniLMResponse(floatList));
emitter.onComplete();
} else if (this.endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L12_V2)) {
} else if (endpoint.getMiniLMModel().equals(MiniLMModel.ALL_MINILM_L12_V2)) {

Predictor<String, float[]> predictor =
loadAllMiniL12V2(this.endpoint.getMiniLMModel()).newPredictor();
loadAllMiniL12V2(endpoint.getMiniLMModel()).newPredictor();

float[] predict = predictor.predict(input);

Expand All @@ -73,11 +63,9 @@ public EdgeChain<MiniLMResponse> createEmbeddings(String input) {

emitter.onNext(new MiniLMResponse(floatList));
emitter.onComplete();
} else if (this.endpoint
.getMiniLMModel()
.equals(MiniLMModel.PARAPHRASE_MINILM_L3_V2)) {
} else if (endpoint.getMiniLMModel().equals(MiniLMModel.PARAPHRASE_MINILM_L3_V2)) {
Predictor<String, float[]> predictor =
loadParaphraseMiniLML3v2(this.endpoint.getMiniLMModel()).newPredictor();
loadParaphraseMiniLML3v2(endpoint.getMiniLMModel()).newPredictor();

float[] predict = predictor.predict(input);

Expand All @@ -92,7 +80,7 @@ public EdgeChain<MiniLMResponse> createEmbeddings(String input) {

System.out.println("d");
ZooModel<String, float[]> zooModel =
loadMultiQAMiniLML6CosV1(this.endpoint.getMiniLMModel());
loadMultiQAMiniLML6CosV1(endpoint.getMiniLMModel());

Predictor<String, float[]> predictor = zooModel.newPredictor();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,23 @@ public void createTable(PostgresEndpoint postgresEndpoint) {
}

public List<String> batchInsertMetadata(String metadataTableName, List<String> metadataList) {
List<String> uuidList = new ArrayList<>();

String[] sql = new String[metadataList.size()];
Set<String> uuidSet = new HashSet<>();

for (int i = 0; i < metadataList.size(); i++) {
UUID uuid = UuidCreator.getTimeOrderedEpoch();

sql[i] =
String.format(
"INSERT INTO %s (metadata_id, metadata) VALUES ('%s', '%s');",
metadataTableName, uuid, metadataList.get(i));
uuidList.add(uuid.toString());
UUID metadataId =
jdbcTemplate.queryForObject(
String.format(
"INSERT INTO %s (metadata_id, metadata) VALUES ('%s', '%s') RETURNING metadata_id;",
metadataTableName, UuidCreator.getTimeOrderedEpoch(), metadataList.get(i)),
UUID.class);

if (metadataId != null) {
uuidSet.add(metadataId.toString());
}
}
jdbcTemplate.batchUpdate(sql);

return uuidList;
return new ArrayList<>(uuidSet);
}

@Transactional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,31 +83,37 @@ public List<String> batchUpsertEmbeddings(
String filename,
String namespace) {

List<String> uuidList = new ArrayList<>();
Set<String> uuidSet = new HashSet<>();

for (int i = 0; i < wordEmbeddingsList.size(); i++) {

UUID id =
jdbcTemplate.queryForObject(
String.format(
"INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename) VALUES"
+ " ('%s', '%s', '%s', '%s', '%s', '%s') ON CONFLICT (raw_text) DO UPDATE"
+ " SET embedding = EXCLUDED.embedding RETURNING id;",
tableName,
UuidCreator.getTimeOrderedEpoch(),
wordEmbeddingsList.get(i).getId(),
Arrays.toString(FloatUtils.toFloatArray(wordEmbeddingsList.get(i).getValues())),
LocalDateTime.now(),
namespace,
filename),
UUID.class);

if (Objects.nonNull(id)) {
uuidList.add(id.toString());
WordEmbeddings wordEmbeddings = wordEmbeddingsList.get(i);

if (wordEmbeddings != null && wordEmbeddings.getValues() != null) {

float[] floatArray = FloatUtils.toFloatArray(wordEmbeddings.getValues());

UUID id =
jdbcTemplate.queryForObject(
String.format(
"INSERT INTO %s (id, raw_text, embedding, timestamp, namespace, filename) VALUES"
+ " ('%s', '%s', '%s', '%s', '%s', '%s') ON CONFLICT (raw_text) DO UPDATE"
+ " SET embedding = EXCLUDED.embedding RETURNING id;",
tableName,
UuidCreator.getTimeOrderedEpoch(),
wordEmbeddings.getId(),
Arrays.toString(floatArray),
LocalDateTime.now(),
namespace,
filename),
UUID.class);

if (id != null) {
uuidSet.add(id.toString());
}
}
}

return uuidList;
return new ArrayList<>(uuidSet);
}

@Transactional
Expand Down Expand Up @@ -155,7 +161,7 @@ public List<Map<String, Object>> query(
tableName,
namespace,
PostgresDistanceMetric.getDistanceMetric(metric),
Arrays.toString(FloatUtils.toFloatArray(values)),
embeddings,
topK));

} else if (metric.equals(PostgresDistanceMetric.COSINE)) {
Expand All @@ -169,7 +175,7 @@ public List<Map<String, Object>> query(
tableName,
namespace,
PostgresDistanceMetric.getDistanceMetric(metric),
Arrays.toString(FloatUtils.toFloatArray(values)),
embeddings,
topK));
} else {
return jdbcTemplate.queryForList(
Expand All @@ -181,7 +187,7 @@ public List<Map<String, Object>> query(
tableName,
namespace,
PostgresDistanceMetric.getDistanceMetric(metric),
Arrays.toString(FloatUtils.toFloatArray(values)),
embeddings,
topK));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,8 @@ public class OpenAiClient {
private final Logger logger = LoggerFactory.getLogger(getClass());
private final RestTemplate restTemplate = new RestTemplate();

private OpenAiEndpoint endpoint;

public OpenAiEndpoint getEndpoint() {
return endpoint;
}

public void setEndpoint(OpenAiEndpoint endpoint) {
this.endpoint = endpoint;
}

public EdgeChain<ChatCompletionResponse> createChatCompletion(ChatCompletionRequest request) {
public EdgeChain<ChatCompletionResponse> createChatCompletion(
ChatCompletionRequest request, OpenAiEndpoint endpoint) {

return new EdgeChain<>(
Observable.create(
Expand Down Expand Up @@ -74,7 +65,7 @@ public EdgeChain<ChatCompletionResponse> createChatCompletion(ChatCompletionRequ
}

public EdgeChain<ChatCompletionResponse> createChatCompletionStream(
ChatCompletionRequest request) {
ChatCompletionRequest request, OpenAiEndpoint endpoint) {

try {
logger.info("Logging ChatCompletion Stream....");
Expand Down Expand Up @@ -105,7 +96,8 @@ public EdgeChain<ChatCompletionResponse> createChatCompletionStream(
}
}

public EdgeChain<CompletionResponse> createCompletion(CompletionRequest request) {
public EdgeChain<CompletionResponse> createCompletion(
CompletionRequest request, OpenAiEndpoint endpoint) {
return new EdgeChain<>(
Observable.create(
emitter -> {
Expand All @@ -131,7 +123,8 @@ public EdgeChain<CompletionResponse> createCompletion(CompletionRequest request)
endpoint);
}

public EdgeChain<OpenAiEmbeddingResponse> createEmbeddings(OpenAiEmbeddingRequest request) {
public EdgeChain<OpenAiEmbeddingResponse> createEmbeddings(
OpenAiEmbeddingRequest request, OpenAiEndpoint endpoint) {
return new EdgeChain<>(
Observable.create(
emitter -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ public class BgeSmallController {
@PostMapping
public Single<BgeSmallResponse> embeddings(@RequestBody BgeSmallEndpoint bgeSmallEndpoint) {

this.bgeSmallClient.setEndpoint(bgeSmallEndpoint);

EdgeChain<BgeSmallResponse> edgeChain =
this.bgeSmallClient.createEmbeddings(bgeSmallEndpoint.getRawText());
this.bgeSmallClient.createEmbeddings(bgeSmallEndpoint.getRawText(), bgeSmallEndpoint);

if (Objects.nonNull(env.getProperty("postgres.db.host"))) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ public class MiniLMController {
@PostMapping
public Single<MiniLMResponse> embeddings(@RequestBody MiniLMEndpoint miniLMEndpoint) {

this.miniLMClient.setEndpoint(miniLMEndpoint);

EdgeChain<MiniLMResponse> edgeChain =
this.miniLMClient.createEmbeddings(miniLMEndpoint.getRawText());
this.miniLMClient.createEmbeddings(miniLMEndpoint.getRawText(), miniLMEndpoint);

if (Objects.nonNull(env.getProperty("postgres.db.host"))) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,8 @@ public Single<ChatCompletionResponse> chatCompletion(@RequestBody OpenAiEndpoint
.user(openAiEndpoint.getUser())
.build();

this.openAiClient.setEndpoint(openAiEndpoint);

EdgeChain<ChatCompletionResponse> edgeChain =
openAiClient.createChatCompletion(chatCompletionRequest);
openAiClient.createChatCompletion(chatCompletionRequest, openAiEndpoint);

if (Objects.nonNull(env.getProperty("postgres.db.host"))) {

Expand Down Expand Up @@ -140,17 +138,14 @@ public SseEmitter chatCompletionStream(@RequestBody OpenAiEndpoint openAiEndpoin
.logitBias(openAiEndpoint.getLogitBias())
.user(openAiEndpoint.getUser())
.build();

this.openAiClient.setEndpoint(openAiEndpoint);

SseEmitter emitter = new SseEmitter();
ExecutorService executorService = Executors.newSingleThreadExecutor();

executorService.execute(
() -> {
try {
EdgeChain<ChatCompletionResponse> edgeChain =
openAiClient.createChatCompletionStream(chatCompletionRequest);
openAiClient.createChatCompletionStream(chatCompletionRequest, openAiEndpoint);

AtomInteger chunks = AtomInteger.of(0);

Expand Down Expand Up @@ -262,9 +257,8 @@ public Single<CompletionResponse> completion(@RequestBody OpenAiEndpoint openAiE
.temperature(openAiEndpoint.getTemperature())
.build();

this.openAiClient.setEndpoint(openAiEndpoint);

EdgeChain<CompletionResponse> edgeChain = openAiClient.createCompletion(completionRequest);
EdgeChain<CompletionResponse> edgeChain =
openAiClient.createCompletion(completionRequest, openAiEndpoint);

return edgeChain.toSingle();
}
Expand All @@ -273,11 +267,10 @@ public Single<CompletionResponse> completion(@RequestBody OpenAiEndpoint openAiE
public Single<OpenAiEmbeddingResponse> embeddings(@RequestBody OpenAiEndpoint openAiEndpoint)
throws SQLException {

this.openAiClient.setEndpoint(openAiEndpoint);

EdgeChain<OpenAiEmbeddingResponse> edgeChain =
openAiClient.createEmbeddings(
new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getRawText()));
new OpenAiEmbeddingRequest(openAiEndpoint.getModel(), openAiEndpoint.getRawText()),
openAiEndpoint);

if (Objects.nonNull(env.getProperty("postgres.db.host"))) {

Expand Down

0 comments on commit 7b6daec

Please sign in to comment.