Better support for prompt_edit_token_weights parsing #22

Open
mix1009 opened this issue Oct 22, 2022 · 0 comments

mix1009 commented Oct 22, 2022

Instead of counting token indices to pass into prompt_edit_token_weights, it would be easier to reference tokens by word.
parse_edit_weights converts weights given as words, word lists, or int indices into weights keyed by int token indices:

```python
prompt = 'the quick brown fox jumps over the lazy dog'
parse_edit_weights(prompt, None, [('brown', -1), (2, 0.5), (['lazy', 'dog'], -1.5)])
```

The returned result is `[(3, -1), (2, 0.5), (8, -1.5), (9, -1.5)]`. Token index 0 is the start-of-text token, so 'brown' (the third word) maps to token index 3.
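The index arithmetic in this example can be checked without loading a tokenizer. The following is a minimal sketch that assumes each word in this prompt is a single CLIP token (as the result above implies), so a plain whitespace split stands in for sep_token:

```python
prompt = 'the quick brown fox jumps over the lazy dog'

# Token index 0 holds <|startoftext|>, so the i-th word (0-based) sits at
# token index i + 1. Assumes each word maps to exactly one CLIP token.
positions = {}
for token_index, word in enumerate(prompt.split(), start=1):
    # setdefault keeps the first occurrence, mirroring list.index()
    positions.setdefault(word, token_index)

print(positions['brown'], positions['lazy'], positions['dog'])  # 3 8 9
print(positions['the'])  # 1: the second "the" (token index 7) is never used
```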

Here's the code:

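```python
# Assumes clip_tokenizer is the CLIPTokenizer already loaded in the notebook, e.g.
# clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def sep_token(prompt):
    """Return the decoded text of each prompt token, excluding start/end tokens."""
    tokens = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length,
                            truncation=True, return_tensors="pt").input_ids[0]
    words = []
    # Index 0 is <|startoftext|>, so start at 1 and stop at <|endoftext|>.
    for index in range(1, len(tokens)):
        word = clip_tokenizer.decode(tokens[index:index+1])
        if not word or word == '<|endoftext|>':
            break
        words.append(word)
    return words

def parse_edit_weights(prompt, prompt_edit, edit_weights):
    """Convert (word | int index | list/tuple of them, weight) pairs into
    (int token index, weight) pairs for prompt_edit_token_weights."""
    # Words are looked up in the edited prompt when one is given.
    tokens = sep_token(prompt_edit if prompt_edit else prompt)

    prompt_edit_token_weights = []
    for tl, w in edit_weights:
        # Accept a single word/index as well as a list or tuple of them.
        if not isinstance(tl, (list, tuple)):
            tl = [tl]
        for t in tl:
            if isinstance(t, str):
                try:
                    # +1 because token index 0 is <|startoftext|>.
                    # list.index() returns only the first occurrence of the word.
                    idx = tokens.index(t) + 1
                except ValueError:
                    print(f"error: {t!r} is not a token in the prompt")
                    continue
            elif isinstance(t, int):
                idx = t
            else:
                # Reject anything else (e.g. a float) instead of using an unset idx.
                raise TypeError(f"expected str or int, got {t!r}")
            prompt_edit_token_weights.append((idx, w))

    return prompt_edit_token_weights
```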
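Because `tokens.index()` returns only the first match, the second 'the' in the example prompt (token index 7) can't be targeted by word, and a word that CLIP splits into several sub-word tokens won't match any single decoded token. One possible extension is sketched below as a hypothetical helper, `word_to_token_indices` (not part of the code above). It groups BPE tokens into words using the `</w>` end-of-word suffix in CLIP's vocabulary and records every occurrence. Its input would come from something like `clip_tokenizer.convert_ids_to_tokens(...)` with the start/end tokens dropped; it assumes plain ASCII prompts, and the sub-word split in the usage line is illustrative only.

```python
from collections import defaultdict


def word_to_token_indices(bpe_tokens):
    """Map each word to every list of 1-based token indices where it occurs.

    Hypothetical helper. Assumes `bpe_tokens` excludes <|startoftext|> and
    <|endoftext|>, and that the last sub-token of each word ends with '</w>'.
    """
    occurrences = defaultdict(list)
    pieces, indices = [], []
    for token_index, tok in enumerate(bpe_tokens, start=1):
        if tok.endswith('</w>'):
            # End of a word: join its sub-tokens and record their indices.
            pieces.append(tok[:-len('</w>')])
            indices.append(token_index)
            occurrences[''.join(pieces)].append(indices)
            pieces, indices = [], []
        else:
            # Word continues in the next token.
            pieces.append(tok)
            indices.append(token_index)
    return dict(occurrences)


print(word_to_token_indices(['the</w>', 'fluff', 'iest</w>', 'dog</w>', 'the</w>']))
# {'the': [[1], [5]], 'fluffiest': [[2, 3]], 'dog': [[4]]}
```

A caller could then pick a specific occurrence (for example the second 'the') and apply the weight to every index in that group.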
