Answer endpoint
Roaoch committed Jun 25, 2024
1 parent 5340fbe commit 59b4efe
Showing 2 changed files with 10 additions and 45 deletions.
27 changes: 10 additions & 17 deletions src/cyberclaasic.py
@@ -4,9 +4,7 @@

 import pandas as pd

-from src.discriminator import DiscriminatorModel
-
-from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel, GenerationConfig
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, GPT2LMHeadModel, GenerationConfig

 import numpy as np

@@ -22,7 +20,8 @@ def __init__(

         self.tokenizer = AutoTokenizer.from_pretrained('Roaoch/CyberClassic-Generator')
         self.generator: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained('Roaoch/CyberClassic-Generator')
-        self.discriminator = DiscriminatorModel.from_pretrained('Roaoch/CyberClassic-Discriminator')
+        self.discriminator_tokenizer = AutoTokenizer.from_pretrained('Roaoch/CyberClassic-Discriminator')
+        self.discriminator = AutoModelForSequenceClassification.from_pretrained('Roaoch/CyberClassic-Discriminator')

         self.generation_config = GenerationConfig(
             max_new_tokens=max_length,

@@ -33,33 +32,27 @@
             pad_token_id=self.tokenizer.pad_token_id
         )

-    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        last_hidden_state = self.generator(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)['hidden_states'][-1]
-        weights_for_non_padding = attention_mask * torch.arange(start=1, end=last_hidden_state.shape[1] + 1).unsqueeze(0)
-        sum_embeddings = torch.sum(last_hidden_state * weights_for_non_padding.unsqueeze(-1), dim=1)
-        num_of_none_padding_tokens = torch.sum(weights_for_non_padding, dim=-1).unsqueeze(-1)
-        return sum_embeddings / num_of_none_padding_tokens
-
     def generate(self) -> str:
         starts = self.startings['text'].values[np.random.randint(0, len(self.startings), 4)].tolist()
         tokens = self.tokenizer(starts, return_tensors='pt', padding=True, truncation=True)
         generated = self.generator.generate(**tokens, generation_config=self.generation_config)

-        input_emb = self.encode(input_ids=generated, attention_mask=torch.full(generated.size(), 1))
-        score = self.discriminator(input_emb)
-        score = torch.abs(score - 0.889)
-        index = int(torch.argmin(score))
-
         decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)

+        score = self.discriminator(decoded)
+        index = int(torch.argmax(score))
+
         return decoded[index]

     def answer(self, promt: str) -> str:
+        promt = promt + ' .'
+        length = len(promt)
+
         promt_tokens = self.tokenizer(promt, return_tensors='pt')
         output = self.generator.generate(
             **promt_tokens,
             generation_config=self.generation_config,
         )

         decoded = self.tokenizer.batch_decode(output)
-        return decoded[0]
+        return decoded[0][length:].strip()
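
A side note on the rewritten scoring path: AutoModelForSequenceClassification consumes token ids rather than raw strings, so the self.discriminator(decoded) call above presumably relies on the newly loaded discriminator_tokenizer being applied somewhere upstream. A minimal sketch of how that scoring step is typically wired (score_texts is a hypothetical helper, not part of this commit; a single-logit classification head is assumed):

    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    def score_texts(texts):
        # Hypothetical helper, not part of this commit.
        tokenizer = AutoTokenizer.from_pretrained('Roaoch/CyberClassic-Discriminator')
        model = AutoModelForSequenceClassification.from_pretrained('Roaoch/CyberClassic-Discriminator')
        # Tokenize the decoded candidates for the discriminator.
        inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            logits = model(**inputs).logits  # shape: (batch, num_labels)
        # Assuming a single-logit head, squeeze to one score per candidate,
        # matching the torch.argmax selection used in generate().
        return logits.squeeze(-1)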
28 changes: 0 additions & 28 deletions src/discriminator.py

This file was deleted.
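
For reference, the new answer method appends ' .' to the prompt and then slices the echoed prompt off the decoded output by character count. A hypothetical usage sketch (the class name CyberClassic and the max_length keyword are assumptions; the actual class lives in src/cyberclaasic.py):

    from src.cyberclaasic import CyberClassic  # class name is an assumption

    bot = CyberClassic(max_length=64)
    reply = bot.answer('To be or not to be')
    # answer() returns only the continuation: the first len(prompt) + 2
    # characters (the prompt plus the appended ' .') are sliced off and
    # the remainder is stripped before being returned.
    print(reply)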
