From 0c3ec957786f2ced0188b6a98616fe49aaf73958 Mon Sep 17 00:00:00 2001 From: Saibo Geng Date: Thu, 29 Feb 2024 18:06:39 +0100 Subject: [PATCH 1/2] add a basic QueryBuilder, test and doc --- docs/docs/language/reference.md | 17 ++++++++ src/__init__.py | 0 src/lmql/language/query_builder.py | 60 ++++++++++++++++++++++++++++ src/lmql/tests/__init__.py | 0 src/lmql/tests/test_query_builder.py | 39 ++++++++++++++++++ 5 files changed, 116 insertions(+) create mode 100644 src/__init__.py create mode 100644 src/lmql/language/query_builder.py create mode 100644 src/lmql/tests/__init__.py create mode 100644 src/lmql/tests/test_query_builder.py diff --git a/docs/docs/language/reference.md b/docs/docs/language/reference.md index d83a8c85..93f77592 100644 --- a/docs/docs/language/reference.md +++ b/docs/docs/language/reference.md @@ -201,6 +201,23 @@ distribution ANSWER in ["A", "B"] ``` +Instead of writing the query in string, you could also write it in a more programmatic way with query builder. +```python +import lmql +from lmql.language.query_builder import QueryBuilder + +query = (QueryBuilder() + .set_decoder('argmax') + .set_prompt('What is the capital of France? 
[ANSWER]') + .set_model('gpt2') + .set_distribution('ANSWER', '["A", "B"]') + .build()) + +lmql.run_sync(query) +``` + + + ::: ### Decoder Clause diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lmql/language/query_builder.py b/src/lmql/language/query_builder.py new file mode 100644 index 00000000..412e3792 --- /dev/null +++ b/src/lmql/language/query_builder.py @@ -0,0 +1,60 @@ +class QueryBuilder: + def __init__(self): + self.decoder = None + self.prompt = None + self.model = None + self.where = None + self.distribution_expr = None + + def set_decoder(self, decoder='argmax', **kwargs): + if decoder not in ['argmax', 'sample', 'beam', 'beam_var', 'var', 'best_k']: + raise ValueError(f"Invalid decoder: {decoder}") + self.decoder = (decoder, kwargs) + return self + + def set_prompt(self, prompt="What is the capital of France? [ANSWER]"): + self.prompt = prompt + return self + + def set_model(self, model="gpt2"): + self.model = model + return self + + def set_where(self, where="len(TOKENS(ANSWER)) < 10"): + """ + Add a where clause to the query + If a where clause already exists, the new clause is appended with an 'and' + If the user wants to use 'or', they need to put or in the where clause + such as: "len(TOKENS(ANSWER)) < 10 or len(TOKENS(ANSWER)) > 2" + """ + self.where = where if self.where is None else f"{self.where} and {where}" + return self + + def set_distribution(self, variable="ANSWER", expr='["A", "B"]'): + self.distribution_expr = (variable, expr) + return self + + def build(self): + components = [] + + if self.decoder: + decoder_str = self.decoder[0] + if self.decoder[1]: # If keyword arguments are provided + decoder_str += f"({self.decoder[1]})" + components.append(decoder_str) + + if self.prompt: + components.append(f'"{self.prompt}"') + + if self.model: + components.append(f'from "{self.model}"') + + if self.where: + components.append(f'where {self.where}') + + if self.distribution_expr: + 
variable, expr = self.distribution_expr + components.append(f'distribution {variable} in {expr}') + + return ' '.join(components) + diff --git a/src/lmql/tests/__init__.py b/src/lmql/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/lmql/tests/test_query_builder.py b/src/lmql/tests/test_query_builder.py new file mode 100644 index 00000000..520b062c --- /dev/null +++ b/src/lmql/tests/test_query_builder.py @@ -0,0 +1,39 @@ +import lmql +import numpy as np + +from lmql.tests.expr_test_utils import run_all_tests + +from lmql.language.query_builder import QueryBuilder + + +def test_query_builder(): + # Example usage: + prompt = (QueryBuilder() + .set_decoder('argmax') + .set_prompt('What is the capital of France? [ANSWER]') + .set_model('gpt2') + .set_where('len(TOKENS(ANSWER)) < 10') + .set_where('len(TOKENS(ANSWER)) > 2') + .build()) + + expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2' + + assert expected==prompt, f"Expected: {expected}, got: {prompt}" + out = lmql.run_sync(prompt,) + +def test_query_builder_with_dist(): + + prompt = (QueryBuilder() + .set_decoder('argmax') + .set_prompt('What is the capital of France? [ANSWER]') + .set_model('gpt2') + .set_distribution('ANSWER', '["Paris", "London"]') + .build()) + + expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" distribution ANSWER in ["Paris", "London"]' + + assert expected==prompt, f"Expected: {expected}, got: {prompt}" + out = lmql.run_sync(prompt,) + +if __name__ == "__main__": + run_all_tests(globals()) \ No newline at end of file From 2d6c818641f8f7e46930d123598a4e7766cd5a9e Mon Sep 17 00:00:00 2001 From: Saibo Geng Date: Wed, 20 Mar 2024 10:31:43 +0100 Subject: [PATCH 2/2] 1. Expose QueryBuilder as lmql.QueryBuilder and move the implementation to lmql/api 2. Wrap the output of `QueryBuilder` as `QueryExecution` class which has `async run()` and `run_sync()` methods. 
Usage example: `result = lmql.QueryBuilder().set_prompt("What is the capital of France? [ANSWER]").set_model("gpt2").build().run_sync()` --- docs/docs/language/reference.md | 6 +++--- src/lmql/__init__.py | 5 ++++- src/lmql/{language => api}/query_builder.py | 20 +++++++++++++++++++- src/lmql/tests/test_query_builder.py | 13 ++++++------- 4 files changed, 32 insertions(+), 12 deletions(-) rename src/lmql/{language => api}/query_builder.py (75%) diff --git a/docs/docs/language/reference.md b/docs/docs/language/reference.md index 93f77592..5d19674a 100644 --- a/docs/docs/language/reference.md +++ b/docs/docs/language/reference.md @@ -204,16 +204,16 @@ distribution Instead of writing the query in string, you could also write it in a more programmatic way with query builder. ```python import lmql -from lmql.language.query_builder import QueryBuilder -query = (QueryBuilder() +query = (lmql.QueryBuilder() .set_decoder('argmax') .set_prompt('What is the capital of France? [ANSWER]') .set_model('gpt2') .set_distribution('ANSWER', '["A", "B"]') .build()) -lmql.run_sync(query) +query.run_sync() +# You can also run it asynchronously by awaiting query.run() ``` diff --git a/src/lmql/__init__.py b/src/lmql/__init__.py index b6a8866c..419dc3ab 100644 --- a/src/lmql/__init__.py +++ b/src/lmql/__init__.py @@ -31,4 +31,7 @@ from lmql.runtime.lmql_runtime import (LMQLQueryFunction, compiled_query, tag) # event loop utils -from lmql.runtime.loop import main \ No newline at end of file +from lmql.runtime.loop import main + +# query builder +from lmql.api.query_builder import QueryBuilder \ No newline at end of file diff --git a/src/lmql/language/query_builder.py b/src/lmql/api/query_builder.py similarity index 75% rename from src/lmql/language/query_builder.py rename to src/lmql/api/query_builder.py index 412e3792..d4b1d75f 100644 --- a/src/lmql/language/query_builder.py +++ b/src/lmql/api/query_builder.py @@ -1,3 +1,19 @@ +from lmql.api import run, run_sync + + 
+class QueryExecution: + def __init__(self, query_string): + self.query_string = query_string + + async def run(self, *args, **kwargs): + # This method should asynchronously execute the query_string + return await run(self.query_string, *args, **kwargs) + + def run_sync(self, *args, **kwargs): + # This method should synchronously execute the query_string + return run_sync(self.query_string, *args, **kwargs) + + class QueryBuilder: def __init__(self): self.decoder = None @@ -56,5 +72,7 @@ def build(self): variable, expr = self.distribution_expr components.append(f'distribution {variable} in {expr}') - return ' '.join(components) + query_string = ' '.join(components) + # Return an instance of QueryExecution instead of a string + return QueryExecution(query_string) diff --git a/src/lmql/tests/test_query_builder.py b/src/lmql/tests/test_query_builder.py index 520b062c..1a73152e 100644 --- a/src/lmql/tests/test_query_builder.py +++ b/src/lmql/tests/test_query_builder.py @@ -3,12 +3,11 @@ from lmql.tests.expr_test_utils import run_all_tests -from lmql.language.query_builder import QueryBuilder def test_query_builder(): # Example usage: - prompt = (QueryBuilder() + query = (lmql.QueryBuilder() .set_decoder('argmax') .set_prompt('What is the capital of France? [ANSWER]') .set_model('gpt2') @@ -18,12 +17,12 @@ def test_query_builder(): expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2' - assert expected==prompt, f"Expected: {expected}, got: {prompt}" - out = lmql.run_sync(prompt,) + assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}" + out = query.run_sync() def test_query_builder_with_dist(): - prompt = (QueryBuilder() + query = (lmql.QueryBuilder() .set_decoder('argmax') .set_prompt('What is the capital of France? [ANSWER]') .set_model('gpt2') @@ -32,8 +31,8 @@ def test_query_builder_with_dist(): expected = 'argmax "What is the capital of France? 
[ANSWER]" from "gpt2" distribution ANSWER in ["Paris", "London"]' - assert expected==prompt, f"Expected: {expected}, got: {prompt}" - out = lmql.run_sync(prompt,) + assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}" + out = query.run_sync() if __name__ == "__main__": run_all_tests(globals()) \ No newline at end of file