Merge pull request #1 from JyutdictEB/master
v0.0 first commit
TsinamLeung committed May 4, 2022
2 parents 4c9f4f9 + edc1f74 commit f01235b
Showing 3 changed files with 318 additions and 0 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -0,0 +1,15 @@
# Introduction

A model that classifies text as Cantonese or Mandarin.

Download the pretrained model from the Releases page.
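
A minimal usage sketch (assuming the release archive is unpacked to `./yue-classifier-can` and a `bert-base-chinese` tokenizer is available locally or from the Hub; the label meanings follow the training labels in `train.ipynb`, where `oral.txt` is 1 and `Literature.txt` is 0):

```python
import torch
from transformers import BertTokenizerFast, AutoModelForSequenceClassification

# Paths are assumptions: point them at the unpacked release and your tokenizer.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
model = AutoModelForSequenceClassification.from_pretrained("./yue-classifier-can", num_labels=2)
model.eval()

def predict(text):
    # The classifier was fine-tuned with max_length=64, so truncate to the same length.
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=64)
    with torch.no_grad():
        logits = model(**enc).logits
    return int(logits.argmax(-1))  # 1 ≈ Cantonese (oral), 0 ≈ Mandarin (Literature)

print(predict("你哋食咗飯未呀？"))
```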

# Requirements

- pytorch == 1.9.0
- transformers > 4.0

# Datasets

- [Cantonese conversational corpus (粵語對話語料)](https://github.com/CanCLID/sentences)
- [PTT Gossiping-board Chinese Q&A corpus (PTT 八卦版問答中文語料)](https://github.com/zake7749/Gossiping-Chinese-Corpus)
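
`train.ipynb` expects two plain-text files, `oral.txt` (Cantonese, label 1) and `Literature.txt` (written Chinese, label 0), with one sentence per line. A preparation sketch follows; the source file names are placeholders, not the actual layout of the repositories above:

```python
# Sketch: flatten downloaded corpora into the one-sentence-per-line files train.ipynb reads.
def flatten(src_path, dst_path):
    with open(src_path, encoding="utf-8") as src, open(dst_path, "w", encoding="utf-8") as dst:
        for line in src:
            line = line.strip()
            if line:
                dst.write(line + "\n")

flatten("cantonese_sentences.txt", "oral.txt")       # placeholder source file
flatten("mandarin_sentences.txt", "Literature.txt")  # placeholder source file
```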
90 changes: 90 additions & 0 deletions predict.ipynb
@@ -0,0 +1,90 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from transformers import BertTokenizerFast\n",
"from transformers import AutoModelForSequenceClassification\n",
"import torch\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"./yue-classifier-can\", num_labels=2)\n",
"\n",
"tokenizer = BertTokenizerFast.from_pretrained('../bert-base-chinese')\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def predict(text):\n",
" # model trained with max_length = 64\n",
" token = tokenizer(text,return_tensors='pt',max_length=128)\n",
" inputs = token['input_ids']\n",
" mask = token['attention_mask']\n",
" outputs = model(input_ids=inputs, attention_mask=mask)\n",
" logits = outputs.logits\n",
" return logits.argmax(-1).tolist()[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#output labels to files\n",
"with open('../yue-cmn-classification-task-main/test.text.txt') as f:\n",
" lines = [l.strip() for l in f.readlines()]\n",
"labels = [str(1 - predict(l)) for l in lines]\n",
"with open('./output_label_base.txt',mode='w') as f:\n",
" f.writelines(\"\\n\".join(labels))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# predict to stdout\n",
"with open('daaizungDimpingRaw.txt') as f:\n",
" lines = [l.strip() for l in f.readlines()]\n",
"for l in lines:\n",
" if predict(l) == 0:\n",
" print(l)"
]
}
],
"metadata": {
"interpreter": {
"hash": "d3709d0369c0f205075bbf2caa11ca9e3ee451d42fb4d7ba05e7bde4a2f4f2a6"
},
"kernelspec": {
"display_name": "Python 3.9.12 ('pytorch')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
213 changes: 213 additions & 0 deletions train.ipynb
@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liang\\anaconda3\\envs\\pytorch\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"\n",
"from transformers import BertTokenizerFast\n",
"import numpy as np\n",
"import torch\n",
"learning_rate = 0.000005\n",
"batch_size = 192\n",
"n_epochs = 10\n",
"\n",
"# if you want to use online BERT model,try replace it to 'ckip/bert-base-chinese'\n",
"tokenizer = BertTokenizerFast.from_pretrained('../bert-base-chinese')\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"def text2token(txt):\n",
" d = tokenizer(txt,padding=\"max_length\", truncation=True,return_tensors='pt',max_length=64)\n",
" d.update((k,v[0]) for k,v in d.items())\n",
" return d\n",
"LabeledData = namedtuple(\"LabeledData\",['token','label'])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_raw_to_data(path : str,label:int):\n",
" label_tensor = torch.tensor(label)\n",
" ret = []\n",
" with open(path) as f:\n",
" for line in f.readlines():\n",
" ret.append(LabeledData(text2token(line.strip()),label_tensor))\n",
" return ret\n",
"\n",
"t_c = load_raw_to_data('oral.txt',1)\n",
"t_m = load_raw_to_data('Literature.txt',0)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"import torch\n",
"train_dataloader = DataLoader([*t_c[:-200],*t_m[:-200]], shuffle=True, batch_size=batch_size)\n",
"eval_dataloader = DataLoader([*t_c[-200:],*t_m[-200:]], batch_size=batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at ../bert-base-cantonese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../bert-base-cantonese and are newly initialized: ['classifier.weight', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"../bert-base-cantonese\", num_labels=2)\n",
"model.to(device)\n",
"pass"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"from torch.optim import AdamW\n",
"from transformers import get_scheduler\n",
"import torch\n",
"\n",
"optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
"\n",
"num_training_steps = n_epochs * len(train_dataloader)\n",
"lr_scheduler = get_scheduler(\n",
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps\n",
")\n",
"\n",
"def bce_loss(logits, labels):\n",
" logits = nn.functional.softmax(logits)\n",
" term_0 = logits[:, 0] * ~labels\n",
" term_1 = logits[:, 1] * labels\n",
" loss = -(term_0 + term_1).mean()\n",
" return loss\n",
"\n",
"def forward(params, key, inputs, labels, mask):\n",
" outputs = model(input_ids=inputs, attention_mask=mask, params=params, train=True, dropout_rng=key)\n",
" logits = outputs.logits\n",
" loss = bce_loss(logits, labels)\n",
" return loss\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0: 100%|██████████| 57/57 [19:08<00:00, 20.15s/it, loss=0.252]\n",
"Epoch 1: 100%|██████████| 57/57 [19:17<00:00, 20.31s/it, loss=0.114] \n",
"Epoch 2: 100%|██████████| 57/57 [19:29<00:00, 20.52s/it, loss=0.066] \n",
"Epoch 3: 100%|██████████| 57/57 [19:13<00:00, 20.24s/it, loss=0.0175]\n",
"Epoch 4: 100%|██████████| 57/57 [19:37<00:00, 20.65s/it, loss=0.0144]\n",
"Epoch 5: 100%|██████████| 57/57 [19:31<00:00, 20.56s/it, loss=0.0156]\n",
"Epoch 6: 100%|██████████| 57/57 [19:28<00:00, 20.51s/it, loss=0.0264]\n",
"Epoch 7: 100%|██████████| 57/57 [19:31<00:00, 20.55s/it, loss=0.00772]\n",
"Epoch 8: 100%|██████████| 57/57 [19:32<00:00, 20.57s/it, loss=0.127] \n",
"Epoch 9: 100%|██████████| 57/57 [19:41<00:00, 20.73s/it, loss=0.0131] \n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"model.train()\n",
"for epoch in range(n_epochs):\n",
" loop = tqdm(train_dataloader, leave=True)\n",
" for batch in loop:\n",
" input_ids = batch.token['input_ids'].to(device)\n",
" attention_mask = batch.token['attention_mask'].to(device)\n",
" label = batch.label.to(device)\n",
" outputs = model(labels=label,input_ids=input_ids,attention_mask=attention_mask)\n",
" \n",
" loss = outputs.loss\n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" loop.set_description(f'Epoch {epoch}')\n",
" loop.set_postfix(loss=loss.item())\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"model.save_pretrained(\"yue-classifier-can\")"
]
}
],
"metadata": {
"interpreter": {
"hash": "527a93331b4b1a8345148922acc34427fb7591433d63b66d32040b6fbbc6d593"
},
"kernelspec": {
"display_name": "Python 3.9.7 ('pytorch')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
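
train.ipynb builds an eval_dataloader from the 200 held-out examples per class but never runs an evaluation pass. Below is a minimal accuracy-check sketch, not part of this commit, assuming the notebook's `model`, `eval_dataloader`, and `device` are in scope:

```python
# Hypothetical evaluation cell for train.ipynb (uses its model, eval_dataloader, device).
import torch

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch in eval_dataloader:
        input_ids = batch.token['input_ids'].to(device)
        attention_mask = batch.token['attention_mask'].to(device)
        labels = batch.label.to(device)
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        correct += (logits.argmax(-1) == labels).sum().item()
        total += labels.size(0)
print(f"held-out accuracy: {correct / total:.3f}")
```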
