forked from cognitivecomputations/laserRMT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rmt_laser_dpo.py
282 lines (251 loc) · 11.3 KB
/
rmt_laser_dpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gc
from lib.utils.load_benchmark_dataset import get_benchmark_data
from lib.utils.assets import PromptTemplate
from lib.utils.prompt_template import get_llm_prompt
import torch.nn.functional as F
from src.AutoModelForSentenceEmbedding import (
AutoModelForSentenceEmbedding,
get_cosine_embeddings,
)
class ModelModifier:
    """Apply Marchenko-Pastur style low-rank denoising to selected model layers.

    Layers are addressed by a ``(layer_type, layer_number)`` pair such as
    ``("self_attn.q_proj", 3)`` and identified internally by the id string
    ``f"{layer_type}+{layer_number}"``. A module matches a pair when
    ``layer_type`` is a substring of its qualified name, ``str(layer_number)``
    is a whole dot-separated component of that name (so layer 1 does not also
    match layers 10-19, 21, ...), and it has a 2-D ``weight``.

    NOTE(review): in models whose module names contain other numeric
    components (e.g. expert indices), those components can also match.

    Workflow: ``assess_layers_snr`` ranks layers by singular-value SNR,
    ``select_layers_for_modification`` picks the lowest, and
    ``test_and_modify_layers`` reduces them one at a time, keeping a change
    only if ``calculate_model_performance`` strictly improves.
    """

    def __init__(
        self,
        model_name,
        prompt_template: PromptTemplate = PromptTemplate.chatml,
        input_length=512,
        output_length=512,
    ):
        """Load the causal LM, its tokenizer and the sentence-embedding model.

        Args:
            model_name: Model id or path passed to ``from_pretrained``.
            prompt_template: Stored as ``self.prompt_template``; not read by any
                other method of this class.
            input_length: Stored as ``self.input_length``; not read by any other
                method of this class.
            output_length: Used as ``max_new_tokens`` during generation and as
                the padded token length when embedding answers.
        """
        self.model_name = model_name
        self.prompt_template = prompt_template
        # Whole model in bfloat16 on CUDA device 0.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map={"": 0}
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.layer_snr = {}  # "type+number" -> SNR computed for that layer
        self.modified_layers = set()  # "type+number" ids currently reduced
        self.original_weights = {}  # module name -> weight before reduction
        self.input_length = input_length
        self.output_length = output_length
        self.embeddings_model = AutoModelForSentenceEmbedding(
            model_name, self.tokenizer
        )

    @staticmethod
    def _module_matches(name, layer_type, layer_number):
        """Return True if module ``name`` belongs to ``layer_type``/``layer_number``."""
        # Compare the number against whole name components, not substrings,
        # so "1" does not match "model.layers.10.mlp".
        return layer_type in name and str(layer_number) in name.split(".")

    def _iter_layer_modules(self, layer_type, layer_number):
        """Yield ``(name, module)`` for matching modules that own a 2-D weight."""
        for name, module in self.model.named_modules():
            weight = getattr(module, "weight", None)
            if (
                isinstance(weight, torch.Tensor)
                and weight.dim() == 2
                and self._module_matches(name, layer_type, layer_number)
            ):
                yield name, module

    def calculate_snr_for_layer(self, layer_type, layer_number):
        """Return the singular-value SNR of the first module matching the layer.

        Singular values above ``marchenko_pastur_threshold`` count as signal and
        the rest as noise; the SNR is ``sum(signal) / sum(noise)`` (``inf`` when
        the noise sum is zero).

        Returns:
            A 0-dim tensor or ``float("inf")``; ``None`` if no module matches.
        """
        for _, module in self._iter_layer_modules(layer_type, layer_number):
            n, m = module.weight.shape
            # SVD in float64 (the model weights are loaded as bfloat16).
            S = torch.linalg.svdvals(module.weight.double()).detach().cpu()
            sigma_estimated = self.estimate_sigma_with_full_iqr(S)
            mp_threshold = self.marchenko_pastur_threshold(sigma_estimated, n, m)
            signal = S[S > mp_threshold].sum()
            noise = S[S <= mp_threshold].sum()
            snr = signal / noise if noise != 0 else float("inf")
            del S
            torch.cuda.empty_cache()  # release the float64 copy / SVD workspace
            gc.collect()
            return snr
        return None

    def update_model_reduce_layer(self, layer_type, layer_number):
        """Replace the layer's weights with an MP-thresholded low-rank version.

        The original weights of every reduced module are stashed in
        ``self.original_weights`` so ``restore_model_original_layer`` can undo
        the change.

        Returns:
            True if at least one module was reduced; False if the layer was
            already modified or no module matches it.
        """
        layer_id = f"{layer_type}+{layer_number}"
        if layer_id in self.modified_layers:
            print(f"Layer {layer_id} has already been modified. Skipping.")
            return False

        # Materialize first: the loop reassigns module parameters.
        targets = list(self._iter_layer_modules(layer_type, layer_number))
        if not targets:
            print(f"No module found for layer {layer_id}. Skipping.")
            return False

        for name, module in targets:
            print(f"Reconstructing layer: {name}")
            original_dtype = module.weight.dtype
            self.original_weights[name] = module.weight.detach().clone()
            weights = module.weight.double()
            # torch.linalg.svd returns (U, S, Vh); Vh is already transposed.
            U, S, Vh = torch.linalg.svd(weights, full_matrices=False)
            sigma_estimated_full_iqr = self.estimate_sigma_with_full_iqr(S)
            n, m = weights.shape
            mp_threshold_full_iqr = self.marchenko_pastur_threshold(
                sigma_estimated_full_iqr, n, m
            )
            # S is sorted in descending order, so the k values above the
            # threshold are exactly the first k; zero out the rest.
            k = (S > mp_threshold_full_iqr).sum().item()
            S_reduced = torch.zeros_like(S)
            S_reduced[:k] = S[:k]
            print(f"Reduced from {S.shape} to {k}")
            reconstructed_weights = U @ torch.diag(S_reduced) @ Vh
            module.weight = torch.nn.Parameter(reconstructed_weights.to(original_dtype))
            del U, S, Vh, S_reduced, weights, reconstructed_weights
            torch.cuda.empty_cache()

        self.modified_layers.add(layer_id)
        return True

    @staticmethod
    def marchenko_pastur_threshold(sigma, n, m):
        """Return the cut-off ``sigma * (1 + sqrt(beta))`` for an ``n x m`` matrix.

        ``beta = min(n, m) / max(n, m)`` is the matrix aspect ratio.
        """
        beta = min(n, m) / max(n, m)
        return sigma * (1 + np.sqrt(beta))

    @staticmethod
    def estimate_sigma_with_full_iqr(S):
        """Estimate a standard deviation from the interquartile range of ``S``.

        For a normal distribution Q1 and Q3 lie 0.6745 sigma either side of the
        median, so ``IQR = Q3 - Q1 = 1.349 * sigma`` and ``sigma ~= IQR / 1.349``.
        """
        q75 = torch.quantile(S, 0.75)
        q25 = torch.quantile(S, 0.25)
        return (q75 - q25) / 1.349

    def restore_model_original_layer(self, layer_type, layer_number):
        """Undo ``update_model_reduce_layer`` for the given layer.

        Every matching module with stashed weights gets them back and its stash
        entry is dropped; if anything was restored, the layer id is removed
        from ``self.modified_layers``.
        """
        layer_id = f"{layer_type}+{layer_number}"
        restored_any = False
        for name, module in list(self._iter_layer_modules(layer_type, layer_number)):
            original = self.original_weights.pop(name, None)
            if original is None:
                print(f"No original weights saved for layer: {name}", flush=True)
                continue
            module.weight = torch.nn.Parameter(original)
            print(f"Restored original weights for layer: {name}", flush=True)
            restored_any = True
        if restored_any:
            self.modified_layers.discard(layer_id)

    def _embed_text(self, text):
        """Embed ``text`` with the sentence-embedding model.

        The text is padded to ``self.output_length`` tokens before embedding.
        """
        enc = self.tokenizer(
            [text],
            return_tensors="pt",
            padding="max_length",
            max_length=self.output_length,
        ).to("cuda")
        return self.embeddings_model(**enc)

    def _score_sample(self, sample):
        """Return ``cos(response, chosen) - cos(response, rejected)`` for one sample."""
        # Only scalar scores are kept, so no autograd graph is needed.
        with torch.no_grad():
            prompt = get_llm_prompt(sample.instruction, sample.prompt)
            prompt_enc = self.tokenizer([prompt], return_tensors="pt").to("cuda")
            model_output = self.model.generate(
                **prompt_enc,
                max_new_tokens=self.output_length,
                use_cache=False,
                output_hidden_states=False,
                output_attentions=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            chosen_embs = self._embed_text(sample.chosen)
            rejected_embs = self._embed_text(sample.rejected)
            # generate() echoes the prompt; keep only the newly generated tokens.
            prompt_length = prompt_enc["input_ids"].shape[1]
            response = self.tokenizer.decode(
                model_output[0][prompt_length:], skip_special_tokens=True
            )
            response_embs = self._embed_text(response)
            gain = get_cosine_embeddings(response_embs, chosen_embs).item()
            loss = get_cosine_embeddings(response_embs, rejected_embs).item()
        return gain - loss

    def calculate_model_performance(
        self,
        datasets=("orca_dpo", "ultrafeedback"),  # "openhermes"
        n_samples=128,
        input_length=512,
        output_length=512,
    ):
        """Score the model on preference data; higher is better.

        For every sample the model generates a response, which is compared in
        the sentence-embedding space to the sample's chosen and rejected answers.
        The score is ``cos(response, chosen) - cos(response, rejected)``.

        Args:
            datasets: Dataset names passed to ``get_benchmark_data``.
            n_samples: Samples requested per dataset.
            input_length: Passed to ``get_benchmark_data``.
            output_length: Passed to ``get_benchmark_data``. Generation and
                answer embedding use ``self.output_length`` instead.

        Returns:
            Mean score over the samples actually evaluated (0.0 if none).
        """
        score_accumulated = 0.0
        n_evaluated = 0
        for dataset in datasets:
            benchmark_dataset = get_benchmark_data(
                dataset, n_samples, input_length, output_length
            )
            print("Calculating performance for dataset:", dataset)
            for index, sample in enumerate(benchmark_dataset.data):
                print(f"{index}/{n_samples}")
                score_accumulated += self._score_sample(sample)
                n_evaluated += 1
                torch.cuda.empty_cache()
        return score_accumulated / n_evaluated if n_evaluated else 0.0

    def assess_layers_snr(self, layer_types, layer_numbers):
        """Compute and store in ``self.layer_snr`` the SNR of each requested layer.

        Each ``(layer_type, layer_number)`` pair is evaluated once; pairs with
        no matching module are skipped.
        """
        for layer_number in layer_numbers:
            for layer_type in layer_types:
                if next(self._iter_layer_modules(layer_type, layer_number), None) is None:
                    continue
                layer_name = f"{layer_type}+{layer_number}"
                print("*" * 50, flush=True)
                print(
                    f"Calculating Signal to Noise Ratio at layer {layer_name}",
                    flush=True,
                )
                snr = self.calculate_snr_for_layer(layer_type, layer_number)
                self.layer_snr[layer_name] = snr
                print(
                    f"Signal to Noise Ratio at layer {layer_name} = {snr}",
                    flush=True,
                )
                print("*" * 50, flush=True)

    def select_layers_for_modification(self, k):
        """Return the ids of the ``k`` layers with the lowest SNR."""
        ranked = sorted(self.layer_snr, key=self.layer_snr.__getitem__)
        return ranked[:k]

    def test_and_modify_layers(self, candidate_layers):
        """Greedily reduce candidate layers, keeping only changes that help.

        Each candidate (a ``"type+number"`` id) is reduced and the model is
        re-scored. The change is reverted unless the score strictly improves on
        the best score so far.
        """
        best_performance = self.calculate_model_performance()
        print(f"Initial Model Performance: {best_performance}")
        for layer in candidate_layers:
            layer_type, layer_number = layer.split("+", 1)
            # Nothing changed (already reduced or no match). Re-scoring would be
            # wasted work, and "reverting" could undo an earlier accepted change.
            if not self.update_model_reduce_layer(
                layer_type=layer_type, layer_number=layer_number
            ):
                continue
            new_performance = self.calculate_model_performance()
            print(
                f"Tested Model Performance after modifying {layer}: {new_performance}"
            )
            if new_performance <= best_performance:
                self.restore_model_original_layer(
                    layer_type=layer_type, layer_number=layer_number
                )
                print(
                    f"Reverted changes in {layer} due to lack of improvement.",
                    flush=True,
                )
            else:
                best_performance = new_performance
                print(
                    f"Modification kept for {layer}. New baseline performance: {best_performance}",
                    flush=True,
                )

    def save_model(self, save_dir):
        """Save the (possibly modified) model and its tokenizer to ``save_dir``."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)