Skip to content

Commit

Permalink
Merge pull request #334 from Saibo-creator/query_builder
Browse files Browse the repository at this point in the history
add a basic QueryBuilder, test and doc
  • Loading branch information
lbeurerkellner committed May 9, 2024
2 parents eaf03be + 2d6c818 commit fdb773c
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 1 deletion.
17 changes: 17 additions & 0 deletions docs/docs/language/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,23 @@ distribution
ANSWER in ["A", "B"]
```

Instead of writing the query as a string, you can also construct it programmatically with the query builder.
```python
import lmql

query = (lmql.QueryBuilder()
.set_decoder('argmax')
.set_prompt('What is the capital of France? [ANSWER]')
.set_model('gpt2')
.set_distribution('ANSWER', '["A", "B"]')
.build())

query.run_sync()
# To run it asynchronously instead, use `await query.run()` inside a coroutine
```



:::

### Decoder Clause
Expand Down
Empty file added src/__init__.py
Empty file.
5 changes: 4 additions & 1 deletion src/lmql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,7 @@
from lmql.runtime.lmql_runtime import (LMQLQueryFunction, compiled_query, tag)

# event loop utils
from lmql.runtime.loop import main
from lmql.runtime.loop import main

# query builder
from lmql.api.query_builder import QueryBuilder
78 changes: 78 additions & 0 deletions src/lmql/api/query_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from lmql.api import run, run_sync


class QueryExecution:
    """Executable handle around a finished LMQL query string.

    Instances are what ``QueryBuilder.build()`` returns. The raw query text
    stays available as ``query_string`` so callers can inspect or log it
    before (or instead of) running it.
    """

    def __init__(self, query_string):
        """Store the query text to execute.

        Args:
            query_string: Complete LMQL query source code.
        """
        self.query_string = query_string

    async def run(self, *args, **kwargs):
        """Execute the query asynchronously.

        All positional and keyword arguments are forwarded unchanged to
        ``lmql.api.run`` after the query string.

        Returns:
            Whatever ``lmql.api.run`` returns for this query.
        """
        return await run(self.query_string, *args, **kwargs)

    async def run_async(self, *args, **kwargs):
        """Alias of :meth:`run` under an explicitly asynchronous name.

        Provided so the asynchronous entry point has a name symmetric to
        :meth:`run_sync`. It delegates to ``self.run`` so subclasses that
        override ``run`` are honoured.
        """
        return await self.run(*args, **kwargs)

    def run_sync(self, *args, **kwargs):
        """Execute the query synchronously.

        All positional and keyword arguments are forwarded unchanged to
        ``lmql.api.run_sync`` after the query string.

        Returns:
            Whatever ``lmql.api.run_sync`` returns for this query.
        """
        return run_sync(self.query_string, *args, **kwargs)


class QueryBuilder:
    """Fluent builder that assembles an LMQL query string clause by clause.

    Every ``set_*`` method returns the builder itself so calls can be
    chained. ``build()`` joins the configured clauses with single spaces in
    the fixed order decoder, prompt, ``from``, ``where``, ``distribution``
    and wraps the result in a ``QueryExecution``. Clauses that were never
    set (or were set to an empty value) are left out.
    """

    # Decoder names accepted by set_decoder().
    _VALID_DECODERS = ('argmax', 'sample', 'beam', 'beam_var', 'var', 'best_k')

    def __init__(self):
        self.decoder = None  # (name, kwargs) tuple once set_decoder() ran
        self.prompt = None
        self.model = None
        self.where = None
        self.distribution_expr = None  # (variable, expr) tuple once set

    def set_decoder(self, decoder='argmax', **kwargs):
        """Select the decoder and its optional keyword arguments.

        Args:
            decoder: One of the names in ``_VALID_DECODERS``.
            **kwargs: Decoder arguments, rendered as ``name=value`` pairs in
                the decoder clause, e.g. ``set_decoder('beam', n=2)``.

        Raises:
            ValueError: If ``decoder`` is not a known decoder name.
        """
        if decoder not in self._VALID_DECODERS:
            raise ValueError(f"Invalid decoder: {decoder}")
        self.decoder = (decoder, kwargs)
        return self

    def set_prompt(self, prompt="What is the capital of France? [ANSWER]"):
        """Set the prompt text; it is emitted wrapped in double quotes."""
        self.prompt = prompt
        return self

    def set_model(self, model="gpt2"):
        """Set the model identifier emitted in the ``from`` clause."""
        self.model = model
        return self

    def set_where(self, where="len(TOKENS(ANSWER)) < 10"):
        """
        Add a where clause to the query
        If a where clause already exists, the new clause is appended with an 'and'
        If the user wants to use 'or', they need to put or in the where clause
        such as: "len(TOKENS(ANSWER)) < 10 or len(TOKENS(ANSWER)) > 2"

        Clauses are joined textually without adding parentheses, so a clause
        with a top-level 'or' should be parenthesised by the caller when it
        is combined with further clauses.
        """
        self.where = where if self.where is None else f"{self.where} and {where}"
        return self

    def set_distribution(self, variable="ANSWER", expr='["A", "B"]'):
        """Set the distribution clause as ``distribution <variable> in <expr>``.

        Args:
            variable: Name of the query variable.
            expr: Source text of the candidate values, inserted verbatim.
        """
        self.distribution_expr = (variable, expr)
        return self

    def _render_query(self):
        """Return the query source text for the clauses configured so far."""
        components = []

        if self.decoder:
            decoder_str, decoder_kwargs = self.decoder
            if decoder_kwargs:
                # Render keyword arguments as `name=value` pairs, e.g.
                # `beam(n=2)`. Interpolating the dict itself would emit its
                # repr (`beam({'n': 2})`), which is not a decoder call.
                rendered = ", ".join(
                    f"{key}={value!r}" for key, value in decoder_kwargs.items()
                )
                decoder_str += f"({rendered})"
            components.append(decoder_str)

        if self.prompt:
            components.append(f'"{self.prompt}"')

        if self.model:
            components.append(f'from "{self.model}"')

        if self.where:
            components.append(f'where {self.where}')

        if self.distribution_expr:
            variable, expr = self.distribution_expr
            components.append(f'distribution {variable} in {expr}')

        return ' '.join(components)

    def build(self):
        """Assemble the query and return it as a ``QueryExecution``.

        The returned object exposes the text as ``query_string``.
        """
        # Return an instance of QueryExecution instead of a string
        return QueryExecution(self._render_query())

Empty file added src/lmql/tests/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions src/lmql/tests/test_query_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import lmql
import numpy as np

from lmql.tests.expr_test_utils import run_all_tests



def test_query_builder():
    """Check that repeated where clauses are and-joined, then run the query."""
    builder = lmql.QueryBuilder()
    builder.set_decoder('argmax')
    builder.set_prompt('What is the capital of France? [ANSWER]')
    builder.set_model('gpt2')
    # Each additional set_where() call extends the existing constraint.
    builder.set_where('len(TOKENS(ANSWER)) < 10')
    builder.set_where('len(TOKENS(ANSWER)) > 2')
    query = builder.build()

    expected = (
        'argmax "What is the capital of France? [ANSWER]" from "gpt2" '
        'where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2'
    )
    actual = query.query_string
    assert expected == actual, f"Expected: {expected}, got: {actual}"

    # Smoke check: the generated query must also execute without raising.
    query.run_sync()

def test_query_builder_with_dist():
    """Check the rendered distribution clause, then run the query."""
    builder = lmql.QueryBuilder()
    builder.set_decoder('argmax')
    builder.set_prompt('What is the capital of France? [ANSWER]')
    builder.set_model('gpt2')
    builder.set_distribution('ANSWER', '["Paris", "London"]')
    query = builder.build()

    expected = (
        'argmax "What is the capital of France? [ANSWER]" from "gpt2" '
        'distribution ANSWER in ["Paris", "London"]'
    )
    actual = query.query_string
    assert expected == actual, f"Expected: {expected}, got: {actual}"

    # Smoke check: the generated query must also execute without raising.
    query.run_sync()

# Allow running this test module directly as a script: hand the module's
# globals to run_all_tests (imported from lmql.tests.expr_test_utils).
if __name__ == "__main__":
    run_all_tests(globals())

0 comments on commit fdb773c

Please sign in to comment.