Skip to content

Commit

Permalink
1. Expose QueryBuilder as lmql.QueryBuilder and move the implementati…
Browse files Browse the repository at this point in the history
…on to lmql/api

2. Wrap the output of `QueryBuilder` as `QueryExecution` class which has `async run()` and `run_sync()` methods. Usage example: `result = lmql.QueryBuilder().set_prompt("What is the capital of France? [ANSWER]").set_model("gpt2").build().run_sync()`
  • Loading branch information
Saibo-creator committed Mar 20, 2024
1 parent 0c3ec95 commit 2d6c818
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 12 deletions.
6 changes: 3 additions & 3 deletions docs/docs/language/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,16 @@ distribution
Instead of writing the query in string, you could also write it in a more programmatic way with query builder.
```python
import lmql
from lmql.language.query_builder import QueryBuilder

query = (QueryBuilder()
query = (lmql.QueryBuilder()
.set_decoder('argmax')
.set_prompt('What is the capital of France? [ANSWER]')
.set_model('gpt2')
.set_distribution('ANSWER', '["A", "B"]')
.build())

lmql.run_sync(query)
query.run_sync()
# You can also run it asynchronously with query.run_async() and await the result
```


Expand Down
5 changes: 4 additions & 1 deletion src/lmql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,7 @@
from lmql.runtime.lmql_runtime import (LMQLQueryFunction, compiled_query, tag)

# event loop utils
from lmql.runtime.loop import main
from lmql.runtime.loop import main

# query builder
from lmql.api.query_builder import QueryBuilder
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
from lmql.api import run, run_sync


class QueryExecution:
def __init__(self, query_string):
self.query_string = query_string

async def run(self, *args, **kwargs):
# This method should asynchronously execute the query_string
return await run(self.query_string, *args, **kwargs)

def run_sync(self, *args, **kwargs):
# This method should synchronously execute the query_string
return run_sync(self.query_string, *args, **kwargs)


class QueryBuilder:
def __init__(self):
self.decoder = None
Expand Down Expand Up @@ -56,5 +72,7 @@ def build(self):
variable, expr = self.distribution_expr
components.append(f'distribution {variable} in {expr}')

return ' '.join(components)
query_string = ' '.join(components)
# Return an instance of QueryExecution instead of a string
return QueryExecution(query_string)

13 changes: 6 additions & 7 deletions src/lmql/tests/test_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

from lmql.tests.expr_test_utils import run_all_tests

from lmql.language.query_builder import QueryBuilder


def test_query_builder():
# Example usage:
prompt = (QueryBuilder()
query = (lmql.QueryBuilder()
.set_decoder('argmax')
.set_prompt('What is the capital of France? [ANSWER]')
.set_model('gpt2')
Expand All @@ -18,12 +17,12 @@ def test_query_builder():

expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2'

assert expected==prompt, f"Expected: {expected}, got: {prompt}"
out = lmql.run_sync(prompt,)
assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}"
out = query.run_sync()

def test_query_builder_with_dist():

prompt = (QueryBuilder()
query = (lmql.QueryBuilder()
.set_decoder('argmax')
.set_prompt('What is the capital of France? [ANSWER]')
.set_model('gpt2')
Expand All @@ -32,8 +31,8 @@ def test_query_builder_with_dist():

expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" distribution ANSWER in ["Paris", "London"]'

assert expected==prompt, f"Expected: {expected}, got: {prompt}"
out = lmql.run_sync(prompt,)
assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}"
out = query.run_sync()

if __name__ == "__main__":
run_all_tests(globals())

0 comments on commit 2d6c818

Please sign in to comment.