Add scripts, command options for large, xlarge model and enhance README.md #5

Open · wants to merge 12 commits into master
57 changes: 54 additions & 3 deletions README.md
@@ -61,7 +61,7 @@ modify the `share_type` parameter:
When loading the `config`, specify the `share_type` parameter as follows:

```python
config = BertConfig.from_pretrained(bert_config_file,share_type=share_type)
config = AlbertConfig.from_pretrained(bert_config_file,share_type=share_type)
```
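As a quick end-to-end illustration (an editorial sketch, not part of the diff): the class names below follow `model.modeling_albert` as imported by `convert_albert_tf_checkpoint_to_pytorch.py` in this PR, and `share_type='all'` is assumed to be a valid value because it is that script's default.

```python
# Minimal sketch: load an ALBERT config with an explicit share_type and build
# the pre-training model. The config path matches the one used later in this
# README; 'all' mirrors the conversion script's default --share_type.
from model.modeling_albert import AlbertConfig, AlbertForPreTraining

config = AlbertConfig.from_pretrained("configs/albert_config_large.json", share_type="all")
model = AlbertForPreTraining(config)
```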
## Download Pre-trained Models of Chinese

@@ -103,15 +103,66 @@ config = BertConfig.from_pretrained(bert_config_file,share_type=share_type)

## Downstream Task Fine-tuning

1. Download a pre-trained ALBERT model
1. Download a pre-trained ALBERT model, for example albert_large_zh.zip, and unzip it into the ~/tmp directory:
```
$ tree ~/tmp/
/home/dell/tmp/
└── albert_large_zh
├── albert_config_large.json
├── albert_model.ckpt.data-00000-of-00001
├── albert_model.ckpt.index
├── albert_model.ckpt.meta
├── checkpoint
└── vocab.txt
```


2. Run `python convert_albert_tf_checkpoint_to_pytorch.py` to convert the TF model weights to PyTorch model weights (by default `share_type=all`):
```
$ python convert_albert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path ~/tmp/albert_large_zh/ \
--bert_config_file configs/albert_config_large.json \
--pytorch_dump_path pretrain/pytorch/pytorch_model.bin
```
See convert.sh for reference.



3. Download the corresponding dataset, for example the [LCQMC](https://drive.google.com/open?id=1HXYMqsXjmA5uIfu_SFqP7r_vZZG-m_H0) dataset, which contains training, validation, and test sets. The training set contains 240,000 colloquial Chinese sentence pairs labeled 1 or 0, where 1 means the two sentences are semantically similar and 0 means they are not. Unzip the downloaded files into dataset/lcqmc/ (a minimal loading sketch follows below):
```
$ tree dataset/lcqmc/
dataset/lcqmc/
├── dev.txt
├── __init__.py
├── test.txt
└── train.txt
```

3. Download the corresponding dataset, for example the [LCQMC](https://drive.google.com/open?id=1HXYMqsXjmA5uIfu_SFqP7r_vZZG-m_H0) dataset, which contains training, validation, and test sets. The training set contains 240,000 colloquial Chinese sentence pairs labeled 1 or 0; 1 means the two sentences are semantically similar, 0 means they are not.
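For orientation, a minimal loading sketch for the files listed above (an editorial sketch; it assumes each line of train.txt/dev.txt/test.txt holds two tab-separated sentences followed by a 0/1 label, so adjust if the actual format differs):

```python
# Hedged sketch: assumes "sentence1<TAB>sentence2<TAB>label" per line,
# with label in {0, 1}; lines that do not match are skipped.
def read_lcqmc(path):
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split("\t")
            if len(parts) != 3:
                continue
            text_a, text_b, label = parts
            pairs.append((text_a, text_b, int(label)))
    return pairs

train_examples = read_lcqmc("dataset/lcqmc/train.txt")
print(f"loaded {len(train_examples)} training pairs")
```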

4. Run `python run_classifier.py --do_train` to start fine-tuning:
```
python run_classifier.py \
--arch albert_large \
--albert_config_path configs/albert_config_large.json \
--bert_dir pretrain/pytorch/albert_large_zh \
--train_batch_size 24 \
--num_train_epochs 10 \
--do_train
```
See train.sh for reference.



5. Run `python run_classifier.py --do_test` to run evaluation on the test set:
```
python run_classifier.py \
--arch albert_large \
--albert_config_path configs/albert_config_large.json \
--bert_dir pretrain/pytorch/albert_large_zh \
--do_test
```

See test.sh for reference.

## Results

6 changes: 3 additions & 3 deletions callback/optimizater.py
@@ -690,7 +690,7 @@ def step(self, closure=None):
#
class Lamb(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
It has been proposed in `Large Batch Optimization for Deep Learning: Training ALBERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
@@ -702,7 +702,7 @@ class Lamb(Optimizer):
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1, which turns this into
Adam. Useful for comparison purposes.
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
.. _Large Batch Optimization for Deep Learning: Training ALBERT in 76 minutes:
https://arxiv.org/abs/1904.00962
Example:
>>> model = ResNet()
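(Editorial aside: the docstring example is cut off here by the collapsed diff. Below is a hedged usage sketch for the `Lamb` optimizer defined in callback/optimizater.py; the learning-rate and weight-decay values are illustrative, and any constructor arguments beyond the `params`, `weight_decay`, and `adam` named in the docstring are assumptions.)

```python
import torch
import torch.nn as nn

from callback.optimizater import Lamb  # the class shown in this diff

# Hedged sketch: a tiny model and a single optimization step with Lamb.
model = nn.Linear(10, 2)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01, adam=False)

x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```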
@@ -1445,7 +1445,7 @@ def step(self, closure=None):
return loss

class BertAdam(Optimizer):
"""Implements BERT version of Adam algorithm with weight decay fix.
"""Implements ALBERT version of Adam algorithm with weight decay fix.
Params:
lr: learning rate
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
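(Editorial aside: similarly, a hedged usage sketch for `BertAdam`, based only on the parameters visible in the docstring above; `t_total` is referenced by the `warmup` description, and the concrete values are illustrative assumptions.)

```python
import torch.nn as nn

from callback.optimizater import BertAdam  # the class shown in this diff

# Hedged sketch: warmup is given as a fraction of t_total, per the docstring.
model = nn.Linear(10, 2)
num_train_steps = 1000  # illustrative
optimizer = BertAdam(model.parameters(), lr=2e-5, warmup=0.1, t_total=num_train_steps)
```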
99 changes: 1 addition & 98 deletions common/tools.py
@@ -3,19 +3,13 @@
import torch
import numpy as np
import json
import pickle

import torch.nn as nn
from collections import OrderedDict
from pathlib import Path
import logging

logger = logging.getLogger()
def print_config(config):
info = "Running with the following configs:\n"
for k, v in config.items():
info += f"\t{k} : {str(v)}\n"
print("\n" + info + "\n")
return

def init_logger(log_file=None, log_file_level=logging.NOTSET):
'''
@@ -57,97 +57,6 @@ def seed_everything(seed=1029):
torch.backends.cudnn.deterministic = True


def prepare_device(n_gpu_use):
"""
setup GPU device if available, move model into configured device
# If n_gpu_use is a number, use range to generate the list of device ids
# If the input is a list, list[0] is used as the controller by default
"""
if not n_gpu_use:
device_type = 'cpu'
else:
n_gpu_use = n_gpu_use.split(",")
device_type = f"cuda:{n_gpu_use[0]}"
n_gpu = torch.cuda.device_count()
if len(n_gpu_use) > 0 and n_gpu == 0:
logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
device_type = 'cpu'
if len(n_gpu_use) > n_gpu:
msg = f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are available on this machine."
logger.warning(msg)
n_gpu_use = range(n_gpu)
device = torch.device(device_type)
list_ids = n_gpu_use
return device, list_ids


def model_device(n_gpu, model):
'''
Determine whether to run on CPU or GPU
Supports multi-GPU on a single machine
:param n_gpu:
:param model:
:return:
'''
device, device_ids = prepare_device(n_gpu)
if len(device_ids) > 1:
logger.info(f"current {len(device_ids)} GPUs")
model = torch.nn.DataParallel(model, device_ids=device_ids)
if len(device_ids) == 1:
os.environ['CUDA_VISIBLE_DEVICES'] = str(device_ids[0])
model = model.to(device)
return model, device


def restore_checkpoint(resume_path, model=None):
'''
Load a model from a checkpoint
:param resume_path:
:param model:
:param optimizer:
:return:
Note: this cannot be used as-is to load a BERT model; it needs adjustment.
Use the module's own Bert_model.from_pretrained(state_dict=your saved state_dict) instead
'''
if isinstance(resume_path, Path):
resume_path = str(resume_path)
checkpoint = torch.load(resume_path)
best = checkpoint['best']
start_epoch = checkpoint['epoch'] + 1
states = checkpoint['state_dict']
if isinstance(model, nn.DataParallel):
model.module.load_state_dict(states)
else:
model.load_state_dict(states)
return [model,best,start_epoch]


def save_pickle(data, file_path):
'''
Save data to a pickle file
:param data:
:param file_name:
:param pickle_path:
:return:
'''
if isinstance(file_path, Path):
file_path = str(file_path)
with open(file_path, 'wb') as f:
pickle.dump(data, f)


def load_pickle(input_file):
'''
Load data from a pickle file
:param pickle_path:
:param file_name:
:return:
'''
with open(str(input_file), 'rb') as f:
data = pickle.load(f)
return data


def save_json(data, file_path):
'''
Save data to a json file
82 changes: 82 additions & 0 deletions convert.sh
@@ -0,0 +1,82 @@
#!/bin/sh
#/************************************************************************************
#***
#*** Copyright 2019 Dell([email protected]), All Rights Reserved.
#***
#*** File Author: Dell, 2019-10-18 13:36:25
#***
#************************************************************************************/
#

usage()
{
echo "Usage: $0 [options] commands"
echo "Options:"
echo " --base Convert base model"
echo " --large Convert large model"
echo " --xlarge Convert xlarge model"
exit 1
}

base_model()
{
INPUT_DIR=~/tmp/albert_base_zh
OUTPUT_DIR=pretrain/pytorch/albert_base_zh

mkdir -p ${OUTPUT_DIR}

python convert_albert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path ${INPUT_DIR} \
--bert_config_file ${INPUT_DIR}/albert_config_base.json \
--pytorch_dump_path ${OUTPUT_DIR}/pytorch_model.bin

cp -v ${INPUT_DIR}/albert_config_base.json ${OUTPUT_DIR}
}

large_model()
{
INPUT_DIR=~/tmp/albert_large_zh
OUTPUT_DIR=pretrain/pytorch/albert_large_zh

mkdir -p ${OUTPUT_DIR}

python convert_albert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path ${INPUT_DIR} \
--bert_config_file ${INPUT_DIR}/albert_config_large.json \
--pytorch_dump_path ${OUTPUT_DIR}/pytorch_model.bin

cp -v ${INPUT_DIR}/albert_config_large.json ${OUTPUT_DIR}
}

xlarge_model()
{
INPUT_DIR=~/tmp/albert_xlarge_zh
OUTPUT_DIR=pretrain/pytorch/albert_xlarge_zh

mkdir -p ${OUTPUT_DIR}

python convert_albert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path ${INPUT_DIR} \
--bert_config_file ${INPUT_DIR}/albert_config_xlarge.json \
--pytorch_dump_path ${OUTPUT_DIR}/pytorch_model.bin

cp -v ${INPUT_DIR}/albert_config_xlarge.json ${OUTPUT_DIR}
}

[ "$*" = "" ] && usage

case $1 in
--base)
base_model
;;
--large)
large_model
;;
--xlarge)
xlarge_model
;;
*)
usage
;;
esac

14 changes: 8 additions & 6 deletions convert_albert_tf_checkpoint_to_pytorch.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""
"""Convert ALBERT checkpoint."""

from __future__ import absolute_import
from __future__ import division
@@ -21,16 +21,18 @@
import argparse
import torch

from model.modeling_albert import BertConfig, BertForPreTraining, load_tf_weights_in_albert
from model.modeling_albert import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert

import logging
logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,share_type, pytorch_dump_path):

# Initialise PyTorch model
config = BertConfig.from_pretrained(bert_config_file,share_type=share_type)
config = AlbertConfig.from_pretrained(bert_config_file,share_type=share_type)
# print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
model = AlbertForPreTraining(config)

# Load weights from tf checkpoint
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
@@ -52,7 +54,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,share_
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
help = "The config json file corresponding to the pre-trained ALBERT model. \n"
"This specifies the model architecture.")
parser.add_argument('--share_type',
default='all',
Expand All @@ -78,4 +80,4 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,share_
--pytorch_dump_path=./pretrain/pytorch/albert_xlarge_zh/pytorch_model.bin \
--share_type=all

'''
'''
1 change: 0 additions & 1 deletion dataset/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion dataset/lcqmc/__init__.py

This file was deleted.
