import argparse
import json
import os
import os.path as osp
import ssl
import urllib.request
from typing import Tuple

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

from models.modify_llama import MyLlamaForCausalLM


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path", type=str, default="models/llama/llama-7b"
    )
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
    parser.add_argument("--dataset_name", type=str, default="wikitext")
    parser.add_argument("--task", type=str, default="wikitext-2-raw-v1")
    parser.add_argument(
        "--split", type=str, default="test", choices=["validation", "test"]
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/debug",
    )
    parser.add_argument("--enable_start_recent_kv_cache", action="store_true")
    parser.add_argument("--start_size", type=int, default=1)
    parser.add_argument("--recent_size", type=int, default=255)
    parser.add_argument("--enable_pos_shift", action="store_true")
    parser.add_argument("--num_eval_tokens", type=int, default=None)
    args = parser.parse_args()
    return args
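
# Example invocation (a sketch; `eval.py` is a hypothetical entry point that
# calls parse_args, and the flag values are illustrative, not the defaults):
#   python eval.py --model_name_or_path models/llama/llama-7b \
#       --enable_start_recent_kv_cache --start_size 4 --recent_size 252 \
#       --enable_pos_shift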


def load(model_name_or_path: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    print(f"Loading model from {model_name_or_path} ...")
    # Note: tensor parallelism is known to cause bugs when running Falcon.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    # Use the patched Llama implementation instead of the stock class:
    # model = AutoModelForCausalLM.from_pretrained(
    model = MyLlamaForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # Ensure a pad token exists: fall back to EOS, then to token id 0.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
    model.eval()
    return model, tokenizer
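
# Example usage (a sketch; the prompt is illustrative):
#   model, tokenizer = load("models/llama/llama-7b")
#   inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
#   with torch.no_grad():
#       logits = model(**inputs).logits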


def download_url(url: str, folder="folder"):
    """
    Downloads the content of a URL to a folder. Modified from
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    Args:
        url (string): The URL of the target file.
        folder (string): The target folder.

    Returns:
        string: File path of the downloaded file.
    """
    file = url.rpartition("/")[2]
    # Strip any query string from the file name (unless the name itself
    # starts with "?").
    file = file if file[0] == "?" else file.split("?")[0]
    path = osp.join(folder, file)

    if osp.exists(path):
        print(f"File {file} exists, using existing file.")
        return path

    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    ctx = ssl._create_unverified_context()
    data = urllib.request.urlopen(url, context=ctx)
    with open(path, "wb") as f:
        f.write(data.read())
    return path
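
# Example usage (a sketch; the URL is illustrative):
#   path = download_url(
#       "https://example.com/datasets/sample.jsonl?raw=true", folder="data"
#   )
#   # -> "data/sample.jsonl" (query string stripped from the file name)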


def load_jsonl(file_path):
    list_data_dict = []
    with open(file_path, "r") as f:
        for line in f:
            list_data_dict.append(json.loads(line))
    return list_data_dict
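
# Example usage (a sketch; the path assumes a file fetched via download_url):
#   samples = load_jsonl("data/sample.jsonl")
#   # -> list of dicts, one per JSON line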


def enable_streaming_llm(model, start_size, recent_size):
    if "llama" in model.config.model_type:
        # Llama KV caches are laid out as (batch, heads, seq, head_dim),
        # so the sequence dimension is 2 for both keys and values.
        k_seq_dim = v_seq_dim = 2
        from models.modify_llama import (
            my_attn_forward,
        )

        # Swap in the custom attention forward from models.modify_llama.
        my_attn_forward(model)
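
# End-to-end usage (a sketch; the wiring below is illustrative of how these
# helpers fit together, not code from this module):
#   args = parse_args()
#   model, tokenizer = load(args.model_name_or_path)
#   if args.enable_start_recent_kv_cache:
#       enable_streaming_llm(model, args.start_size, args.recent_size)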