diff --git a/inltk/codemixed_util.py b/inltk/codemixed_util.py
new file mode 100644
index 0000000..662308b
--- /dev/null
+++ b/inltk/codemixed_util.py
@@ -0,0 +1,76 @@
+import os
+
+import torch
+import torch.optim as optim
+
+import random
+from fastai import *
+from fastai.text import *
+from fastai.callbacks import *
+from transformers import PreTrainedModel, PreTrainedTokenizer, PretrainedConfig
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
+
+class TransformersBaseTokenizer(BaseTokenizer):
+    """Wrapper around PreTrainedTokenizer to be compatible with fast.ai"""
+    def __init__(self, pretrained_tokenizer: PreTrainedTokenizer, model_type = 'bert', **kwargs):
+        self._pretrained_tokenizer = pretrained_tokenizer
+        self.max_seq_len = pretrained_tokenizer.max_len
+        self.model_type = model_type
+
+    def __call__(self, *args, **kwargs):
+        return self
+
+    def tokenizer(self, t:str) -> List[str]:
+        """Limits the maximum sequence length and adds the special tokens"""
+        CLS = self._pretrained_tokenizer.cls_token
+        SEP = self._pretrained_tokenizer.sep_token
+        if self.model_type in ['roberta']:
+            tokens = self._pretrained_tokenizer.tokenize(t, add_prefix_space=True)[:self.max_seq_len - 2]
+            tokens = [CLS] + tokens + [SEP]
+        else:
+            tokens = self._pretrained_tokenizer.tokenize(t)[:self.max_seq_len - 2]
+            if self.model_type in ['xlnet']:
+                tokens = tokens + [SEP] + [CLS]
+            else:
+                tokens = [CLS] + tokens + [SEP]
+        return tokens
+
+class TransformersVocab(Vocab):
+    def __init__(self, pretrained_model_name: str):
+        super(TransformersVocab, self).__init__(itos = [])
+        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
+
+    def numericalize(self, t:Collection[str]) -> List[int]:
+        "Convert a list of tokens `t` to their ids."
+        return self.tokenizer.convert_tokens_to_ids(t)
+        #return self.tokenizer.encode(t)
+
+    def textify(self, nums:Collection[int], sep=' ') -> List[str]:
+        "Convert a list of `nums` to their tokens."
+        nums = np.array(nums).tolist()
+        return sep.join(self.tokenizer.convert_ids_to_tokens(nums)) if sep is not None else self.tokenizer.convert_ids_to_tokens(nums)
+    def __getstate__(self):
+        return {'itos':self.itos, 'tokenizer':self.tokenizer}
+
+    def __setstate__(self, state:dict):
+        self.itos = state['itos']
+        self.tokenizer = state['tokenizer']
+        self.stoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.itos)})
+
+
+class CustomTransformerModel(nn.Module):
+    def __init__(self, transformer_model: PreTrainedModel):
+        super(CustomTransformerModel,self).__init__()
+        self.transformer = transformer_model
+        self.pad_idx = AutoTokenizer.from_pretrained('ai4bharat/indic-bert').pad_token_id
+    def forward(self, input_ids, attention_mask=None):
+
+        # attention_mask
+        # Mask to avoid performing attention on padding token indices.
+        # Mask values selected in ``[0, 1]``:
+        # ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        attention_mask = (input_ids!=self.pad_idx).type(input_ids.type())
+
+        logits = self.transformer(input_ids,
+                                  attention_mask = attention_mask)[0]
+        return logits
diff --git a/inltk/config.py b/inltk/config.py
index 1c58fbc..45424e8 100644
--- a/inltk/config.py
+++ b/inltk/config.py
@@ -88,5 +88,7 @@ def get_config():
         'all_languages_identifying_model_url': 'https://www.dropbox.com/s/a06fa0zlr7bfif0/export.pkl?raw=1',
         'all_languages_identifying_tokenizer_name': 'tokenizer.model',
         'all_languages_identifying_tokenizer_url':
-            'https://www.dropbox.com/s/t4mypdd8aproj88/all_language.model?raw=1'
+            'https://www.dropbox.com/s/t4mypdd8aproj88/all_language.model?raw=1',
+        'codemixed_identifying_model_name': 'export.pkl',
+        'codemixed_identifying_model_url': 'https://www.dropbox.com/s/tlhnkbqffqb832a/export.pkl?raw=1'
     }
diff --git a/inltk/download_assets.py b/inltk/download_assets.py
index 9790b6d..3d7a3e1 100644
--- a/inltk/download_assets.py
+++ b/inltk/download_assets.py
@@ -46,10 +46,13 @@ def verify_language(language_code: str):
 async def check_all_languages_identifying_model():
     config = AllLanguageConfig.get_config()
     if (path/'models'/'all'/f'{config["all_languages_identifying_model_name"]}').exists() and \
-            (path/'models'/'all'/f'{config["all_languages_identifying_tokenizer_name"]}').exists():
+            (path/'models'/'all'/f'{config["all_languages_identifying_tokenizer_name"]}').exists() and \
+            (path/'models'/'codemixed'/f'{config["codemixed_identifying_model_name"]}').exists():
         return True
     done = await download_file(config["all_languages_identifying_model_url"], path/'models'/'all',
                                config["all_languages_identifying_model_name"])
+    done = await download_file(config["codemixed_identifying_model_url"], path/'models'/'codemixed',
+                               config["codemixed_identifying_model_name"])
     done = await download_file(config["all_languages_identifying_tokenizer_url"], path/'models'/'all',
                                config["all_languages_identifying_tokenizer_name"])
     return done
diff --git a/inltk/inltk.py b/inltk/inltk.py
index 9997eed..55d43f6 100644
--- a/inltk/inltk.py
+++ b/inltk/inltk.py
@@ -11,6 +11,8 @@
 from inltk.tokenizer import LanguageTokenizer
 from inltk.const import tokenizer_special_cases
 from inltk.utils import cos_sim, reset_models, is_english
+from inltk.utils import *
+from inltk.codemixed_util import *
 
 if not sys.warnoptions:
     warnings.simplefilter("ignore")
@@ -56,17 +58,43 @@ def predict_next_words(input: str, n_words: int, language_code: str, randomness=
     output = output.replace(special_str, '\n')
     return output
 
 
 def tokenize(input: str, language_code: str):
     check_input_language(language_code)
     tok = LanguageTokenizer(language_code)
     output = tok.tokenizer(input)
     return output
 
-
-def identify_language(input: str):
+def identify_codemixed(input: str):
+    """Classify romanized/code-mixed text as 'en', 'hi-en', 'ta-en' or 'ml-en'."""
+    asyncio.set_event_loop(asyncio.new_event_loop())
+    loop = asyncio.get_event_loop()
+    tasks = [asyncio.ensure_future(check_all_languages_identifying_model())]
+    done = loop.run_until_complete(asyncio.gather(*tasks))[0]
+    loop.close()
+    defaults.device = torch.device('cpu')
+    path = Path(__file__).parent
+    try:
+        learn = load_learner(path / 'models' / 'codemixed')
+        output = learn.predict(input)
+        map_dict = {
+            '1': 'en',
+            '2': 'hi-en',
+            '3': 'ta-en',
+            '4': 'ml-en'
+        }
+        return map_dict[str(output[0])]
+    except AttributeError:
+        print("Probably you haven't imported the required classes. Try running 'from inltk.codemixed_util import *'")
+        print()
+        raise
+
+def identify_language(input: str, check_codemixed=False):
     if is_english(input):
-        return 'en'
+        if check_codemixed:
+            return identify_codemixed(input)
+        else:
+            return 'en'
     asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()
     tasks = [asyncio.ensure_future(check_all_languages_identifying_model())]
@@ -88,6 +115,7 @@ def remove_foreign_languages(input: str, host_language_code: str):
 
 def reset_language_identifying_models():
     reset_models('all')
+    reset_models('codemixed')
 
 
 def get_embedding_vectors(input: str, language_code: str):