forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
EmbeddingBag.cpp
883 lines (789 loc) · 32.8 KB
/
EmbeddingBag.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <TH/THBlasUtils.h>
#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#else
#include <caffe2/perfkernels/embedding_lookup_idx.h>
#endif
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <tuple>
#include <vector>
namespace {
// Integer codes for the reduction applied inside each bag. Every kernel in
// this file compares its `mode` argument against these values.
constexpr int MODE_SUM = 0;
constexpr int MODE_MEAN = 1;
constexpr int MODE_MAX = 2;
} // namespace
namespace at {
namespace native {
// Declared here so the per-sample-weights backward can call it as
//   dot_impl<scalar_t>(n, x, incx, y, incy).
// Judging by that call site it returns the strided dot product of two length-n
// vectors; the definition is not part of this file's visible portion.
template <typename scalar_t>
scalar_t dot_impl(
    int64_t n,
    scalar_t* x,
    int64_t incx,
    scalar_t* y,
    int64_t incy);
// Turns bag start offsets into a per-position bag id: afterwards offset2bag[p]
// is the bag that position p belongs to.
//
// `offset2bag` must arrive zero-filled; every caller in this file allocates it
// with indices.size(0) + 1 slots and trims the spare slot afterwards, so an
// offset equal to indices.size(0) is still a valid index_add_ target.
// `indices` is not read.
//
// Worked example, offsets = [0, 2, 4] and a 5-slot offset2bag:
//   mark every bag start    -> [1 0 1 0 1]
//   un-mark position 0      -> [0 0 1 0 1]
//   prefix sum              -> [0 0 1 1 2]
static void make_offset2bag(const Tensor &offsets, const Tensor &indices, Tensor& offset2bag) {
  // index_add_ accumulates, so an offset repeated by empty bags bumps the bag
  // id once per bag.
  auto start_marks = at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  offset2bag.index_add_(0, offsets, start_marks);
  // The forward TORCH_CHECKs offsets[0] == 0; position 0 opens bag 0 and must
  // not count as a boundary.
  offset2bag[0] -= 1;
  offset2bag = offset2bag.cumsum(0);
}
namespace {
bool isFastPathIndexSelect(const Tensor& src, Tensor& output) {
return src.scalar_type() == kFloat && src.stride(1) == 1 && output.stride(1) == 1;
}
bool isFastPathIndexSelectScale(const Tensor& src, const Tensor& scale, Tensor& output) {
return src.scalar_type() == kFloat && src.stride(1) == 1 && output.stride(1) == 1 && scale.stride(0) == 1;
}
// Fused gather + scatter-add: for every position k,
//   output[add_indices[k], :] += src[select_indices[k], :]
// without materializing the gathered rows in a temporary tensor.
// This generic version honours arbitrary strides of `src` and `output`. The
// trailing offsets / include_last_offset parameters exist only so the
// signature matches the float specialization, and are ignored here.
template<typename T>
void index_select_add(const Tensor &select_indices,
                      const Tensor &add_indices,
                      const Tensor &src,
                      Tensor &output,
                      const Tensor& /*offsets*/,
                      bool /*include_last_offset*/) {
  AT_ASSERT(select_indices.numel() == add_indices.numel());
  const int64_t* dst_rows = add_indices.data_ptr<int64_t>();
  const int64_t* src_rows = select_indices.data_ptr<int64_t>();
  T* src_base = src.data_ptr<T>();
  T* out_base = output.data_ptr<T>();
  const auto count = add_indices.numel();
  const int64_t width = src.size(1);
  const auto src_row_stride = src.stride(0);
  const auto src_col_stride = src.stride(1);
  const auto out_row_stride = output.stride(0);
  const auto out_col_stride = output.stride(1);
  for (int64_t k = 0; k < count; ++k) {
    T* from = src_base + src_row_stride * src_rows[k];
    T* to = out_base + out_row_stride * dst_rows[k];
    THBlas_axpy<T>(width, 1, from, src_col_stride, to, out_col_stride);
  }
}
// float specialization of index_select_add (same contract as the generic
// template: output[add_indices[k]] += src[select_indices[k]]).
//
// When isFastPathIndexSelect() holds, the reduction is delegated to an
// optimized embedding kernel (fbgemm when USE_FBGEMM is defined, caffe2
// perfkernels otherwise), run in parallel over bags. That path reads bag
// boundaries from `offsets` and never touches `add_indices`, which is why the
// caller may pass an empty offset2bag on this path. Otherwise it falls back to
// the same strided axpy loop as the generic template.
template<>
void index_select_add<float>(const Tensor &select_indices,
                             const Tensor &add_indices,
                             const Tensor &src,
                             Tensor &output,
                             const Tensor& offsets,
                             bool include_last_offset) {
  int64_t ddim = src.size(1);
  auto* src_data = src.data_ptr<float>();
  auto* select_indices_data = select_indices.data_ptr<int64_t>();
  auto* output_data = output.data_ptr<float>();
  if (isFastPathIndexSelect(src, output)) {
    // The kernels consume output_size + 1 boundaries. With include_last_offset
    // the caller already supplied the final one; otherwise append
    // select_indices.numel() as the end of the last bag.
    int64_t output_size = 0;
    auto* offsets_data = offsets.data_ptr<int64_t>();
    std::vector<int64_t> offsets_include_last;
    if (include_last_offset) {
      output_size = offsets.numel() - 1;
    } else {
      output_size = offsets.numel();
      offsets_include_last.resize(offsets.numel() + 1);
      std::memcpy(
          offsets_include_last.data(),
          offsets_data,
          sizeof(int64_t) * offsets.numel());
      offsets_include_last[offsets.numel()] = select_indices.numel();
      offsets_data = offsets_include_last.data();
    }
#ifdef USE_FBGEMM
    auto kernel_fp32_i64 =
        fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
            /* block_size */ddim,
            /* has_weight */false,
            /* normalize_by_lengths */false,
            /* prefetch */16,
            /* is_weight_positional */false,
            /* use_offsets */true
        );
#endif
    at::parallel_for(
        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
#ifdef USE_FBGEMM
          // fbgemm reports invalid input (such as an index outside
          // [0, data_size)) through its bool return value instead of
          // throwing. Ignoring it would silently leave this chunk of `output`
          // wrong.
          bool success = kernel_fp32_i64(
              /* output_size */end_idx - start_idx,
              /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
              /* data_size */src.size(0),
              /* input */src_data,
              /* indices */select_indices_data + offsets_data[start_idx],
              /* offsets_or_lengths */offsets_data + start_idx,
              /* weights */nullptr,
              /* output */output_data + start_idx * ddim);
          TORCH_CHECK(
              success,
              "embedding_bag: fbgemm embedding kernel rejected its input for bags [",
              start_idx, ", ", end_idx, "); check that every index is in [0, ",
              src.size(0), ") and that offsets are valid");
#else
          caffe2::EmbeddingLookupIdx(
              /*block_size=*/ddim,
              /*output_size=*/end_idx - start_idx,
              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
              /*data_size=*/src.size(0),
              /*input=*/src_data,
              /*indices=*/select_indices_data + offsets_data[start_idx],
              /*offsets=*/offsets_data + start_idx,
              /*weights=*/nullptr,
              /*scale_bias=*/nullptr,
              /*normalize_by_lengths=*/false,
              /*out=*/output_data + start_idx * ddim);
#endif
        });
  } else {
    AT_ASSERT(select_indices.numel() == add_indices.numel());
    auto* add_indices_data = add_indices.data_ptr<int64_t>();
    auto src_stride0 = src.stride(0);
    auto src_stride1 = src.stride(1);
    auto output_stride0 = output.stride(0);
    auto output_stride1 = output.stride(1);
    auto numel = add_indices.numel();
    for (int64_t i = 0; i < numel; i++) {
      THBlas_axpy<float>(
          ddim,
          1,
          src_data + src_stride0 * select_indices_data[i],
          src_stride1,
          output_data + output_stride0 * add_indices_data[i],
          output_stride1);
    }
  }
}
// Fused gather + scale + scatter-add used when per_sample_weights is given:
//   output[add_indices[k], :] += src[select_indices[k], :] * scale[k]
// computed element by element, so neither the gathered rows nor their scaled
// copies are materialized. Honours arbitrary strides of `src`, `output` and
// `scale`. The offsets / include_last_offset parameters only mirror the float
// specialization's signature and are unused here.
template<typename T>
static void index_select_scale_add(const Tensor &select_indices,
                                   const Tensor &add_indices,
                                   const Tensor &scale,
                                   const Tensor &src,
                                   Tensor &output,
                                   const Tensor& /*offsets*/,
                                   bool /*include_last_offset*/) {
  AT_ASSERT(select_indices.numel() == add_indices.numel());
  const int64_t* dst_rows = add_indices.data_ptr<int64_t>();
  const int64_t* src_rows = select_indices.data_ptr<int64_t>();
  const T* src_base = src.data_ptr<T>();
  T* out_base = output.data_ptr<T>();
  const auto count = add_indices.numel();
  const int64_t width = src.size(1);
  const auto src_row_stride = src.stride(0);
  const auto src_col_stride = src.stride(1);
  const auto out_row_stride = output.stride(0);
  const auto out_col_stride = output.stride(1);
  const T* weights = scale.data_ptr<T>();
  const auto weight_stride = scale.stride(0);
  for (int64_t k = 0; k < count; ++k) {
    const T* from = src_base + src_row_stride * src_rows[k];
    T* to = out_base + out_row_stride * dst_rows[k];
    const T w = weights[k * weight_stride];
    for (int64_t col = 0; col < width; ++col) {
      to[col * out_col_stride] += from[col * src_col_stride] * w;
    }
  }
}
// float specialization of index_select_scale_add (same contract as the generic
// template: output[add_indices[k]] += src[select_indices[k]] * scale[k]).
//
// When isFastPathIndexSelectScale() holds, the weighted reduction is delegated
// to an optimized embedding kernel (fbgemm when USE_FBGEMM is defined, caffe2
// perfkernels otherwise), run in parallel over bags. That path takes bag
// boundaries from `offsets` and ignores `add_indices`. Otherwise it uses the
// strided element-wise loop of the generic template.
template<>
void index_select_scale_add<float>(const Tensor &select_indices,
                                   const Tensor &add_indices,
                                   const Tensor &scale,
                                   const Tensor &src,
                                   Tensor &output,
                                   const Tensor& offsets,
                                   bool include_last_offset) {
  int64_t ddim = src.size(1);
  auto* scale_data = scale.data_ptr<float>();
  auto* select_indices_data = select_indices.data_ptr<int64_t>();
  auto* src_data = src.data_ptr<float>();
  auto* output_data = output.data_ptr<float>();
  if (isFastPathIndexSelectScale(src, scale, output)) {
    // The kernels consume output_size + 1 boundaries. With include_last_offset
    // the caller already supplied the final one; otherwise append
    // select_indices.numel() as the end of the last bag.
    int64_t output_size = 0;
    auto* offsets_data = offsets.data_ptr<int64_t>();
    std::vector<int64_t> offsets_include_last;
    if (include_last_offset) {
      output_size = offsets.numel() - 1;
    } else {
      output_size = offsets.numel();
      offsets_include_last.resize(offsets.numel() + 1);
      std::memcpy(
          offsets_include_last.data(),
          offsets_data,
          sizeof(int64_t) * offsets.numel());
      offsets_include_last[offsets.numel()] = select_indices.numel();
      offsets_data = offsets_include_last.data();
    }
#ifdef USE_FBGEMM
    auto kernel_fp32_i64 =
        fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int64_t>(
            /* block_size */ddim,
            /* has_weight */true,
            /* normalize_by_lengths */false,
            /* prefetch */16,
            /* is_weight_positional */false,
            /* use_offsets */true
        );
#endif
    at::parallel_for(
        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
#ifdef USE_FBGEMM
          // fbgemm reports invalid input (such as an index outside
          // [0, data_size)) through its bool return value instead of
          // throwing. Ignoring it would silently leave this chunk of `output`
          // wrong.
          bool success = kernel_fp32_i64(
              /* output_size */end_idx - start_idx,
              /* index_size */offsets_data[end_idx] - offsets_data[start_idx],
              /* data_size */src.size(0),
              /* input */src_data,
              /* indices */select_indices_data + offsets_data[start_idx],
              /* offsets_or_lengths */offsets_data + start_idx,
              /* weights */scale_data + offsets_data[start_idx],
              /* output */output_data + start_idx * ddim);
          TORCH_CHECK(
              success,
              "embedding_bag: fbgemm embedding kernel rejected its input for bags [",
              start_idx, ", ", end_idx, "); check that every index is in [0, ",
              src.size(0), ") and that offsets are valid");
#else
          caffe2::EmbeddingLookupIdx(
              /*block_size=*/ddim,
              /*output_size=*/end_idx - start_idx,
              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
              /*data_size=*/src.size(0),
              /*input=*/src_data,
              /*indices=*/select_indices_data + offsets_data[start_idx],
              /*offsets=*/offsets_data + start_idx,
              /*weights=*/scale_data + offsets_data[start_idx],
              /*scale_bias=*/nullptr,
              /*normalize_by_lengths=*/false,
              /*out=*/output_data + start_idx * ddim);
#endif
        });
  } else {
    AT_ASSERT(select_indices.numel() == add_indices.numel());
    auto* add_indices_data = add_indices.data_ptr<int64_t>();
    auto src_stride0 = src.stride(0);
    auto src_stride1 = src.stride(1);
    auto output_stride0 = output.stride(0);
    auto output_stride1 = output.stride(1);
    auto scale_stride = scale.stride(0);
    auto numel = add_indices.numel();
    for (int64_t i = 0; i < numel; i++) {
      auto* src_base = src_data + src_stride0 * select_indices_data[i];
      auto* output_base = output_data + output_stride0 * add_indices_data[i];
      // Named `weight` (not `scale`) to avoid shadowing the `scale` parameter.
      auto weight = scale_data[i * scale_stride];
      for (int64_t j = 0; j < ddim; j++) {
        output_base[j * output_stride1] += src_base[j * src_stride1] * weight;
      }
    }
  }
}
} // namespace
// Builds the per-bag size tensor that the forward returns.
//
// MODE_MEAN / MODE_MAX (the latter for its backward): a tensor shaped like
// `offsets` holding offsets[b + 1] - offsets[b] for every entry but the last,
// and indices.size(0) - offsets[-1] for the last one.
// MODE_SUM: an uninitialized tensor when `requires_grad`, otherwise an
// undefined Tensor. Nothing in this file reads its values in MODE_SUM.
static at::Tensor make_bag_size(
    const Tensor& offsets,
    const Tensor& indices,
    const int64_t mode,
    const bool requires_grad) {
  const bool sizes_needed = (mode == MODE_MEAN || mode == MODE_MAX);
  if (!sizes_needed) {
    // MODE_SUM: only allocate when gradients will be needed.
    return requires_grad ? at::empty(offsets.sizes(), indices.options())
                         : at::Tensor();
  }
  at::Tensor bag_size = at::zeros(offsets.sizes(), indices.options());
  const int64_t num_offsets = bag_size.size(0);
  if (num_offsets != 1) {
    // Consecutive offset differences give every bag size except the last.
    bag_size.slice(0, 0, num_offsets - 1, 1) =
        offsets.slice(0, 1, offsets.size(0), 1) -
        offsets.slice(0, 0, offsets.size(0) - 1, 1);
  }
  // The last bag runs from offsets[-1] to the end of `indices`.
  bag_size[-1] = indices.size(0) - offsets[-1];
  return bag_size;
}
// In MODE_MEAN, divides every output row (in place) by the number of indices
// in its bag and returns `output`. In other modes returns `output` unchanged.
//
// Empty bags are divided by 1 rather than 0, so they stay all zeros.
// `bag_size` has one entry per element of `offsets`. With include_last_offset
// the output has one row fewer, and the trailing entry does not describe a
// bag; expanding the full tensor to `output`'s shape would then throw. Only
// the first output.size(0) entries are used, which is the whole tensor when
// include_last_offset is false.
static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices,
                             const int64_t mode, Tensor &output,
                             const Tensor &bag_size) {
  if (mode == MODE_MEAN) {
    if (offsets.size(0) == 1) {
      // A single offset means at most one bag, covering every index.
      auto bag_size_ = std::max(indices.size(0), static_cast<int64_t>(1));
      output /= bag_size_;
    } else {
      auto real_bag_sizes = bag_size.narrow(0, 0, output.size(0));
      auto divisor =
          at::max(real_bag_sizes,
                  at::ones_like(real_bag_sizes, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
              .to(output.options())
              .unsqueeze(1)
              .expand_as(output);
      output /= divisor;
    }
  }
  return output;
}
// Backward counterpart of the MODE_MEAN division, used by the sparse backward.
// `output` holds one gradient row per index. In MODE_MEAN each row is scaled
// in place by 1 / (size of its bag), with the bag looked up through
// `offset2bag`. Other modes return `output` untouched.
static Tensor apply_bag_size_backward(const Tensor &offsets,
                                      const Tensor &indices, const int64_t mode,
                                      Tensor &output, const Tensor &offset2bag,
                                      const Tensor &bag_size) {
  if (mode != MODE_MEAN) {
    return output;
  }
  if (offsets.size(0) == 1) {
    // Single bag: every index belongs to it.
    output /= indices.size(0);
    return output;
  }
  auto inverse_sizes = 1 / bag_size.to(output.options());
  output *= inverse_sizes.unsqueeze(1).index_select(0, offset2bag);
  return output;
}
// MODE_MAX forward. For each bag and feature column it writes into `output`
// the largest weight value among the bag's indices. It records in the
// returned `max_indices` which embedding row supplied that value.
//
// Walks positions in order and relies on all positions of a bag being
// adjacent in `offset2bag` (as make_offset2bag's cumulative sum produces). The
// first position of a bag initializes its row. Rows of empty bags are left as
// the caller initialized them. `output` is indexed as `row * stride(0) + col`,
// so its columns must be unit-stride.
// Returns (output, offset2bag, bag_size, max_indices).
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor, Tensor> embedding_bag_cpu_max(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offset2bag,
    const Tensor& output,
    const Tensor& bag_size,
    const Tensor& offsets) {
  auto max_indices = at::zeros({offsets.size(0), weight.size(1)}, indices.options());
  int64_t numel = indices.numel();
  int64_t dims = weight.size(1);
  auto* indices_data = indices.data_ptr<int64_t>();
  auto* offset2bag_data = offset2bag.data_ptr<int64_t>();
  auto* max_indices_data = max_indices.data_ptr<int64_t>();
  auto max_indices_stride = max_indices.stride(0);
  auto* weight_data = weight.data_ptr<scalar_t>();
  auto* output_data = output.data_ptr<scalar_t>();
  auto weight_stride0 = weight.stride(0);
  auto weight_stride1 = weight.stride(1);
  auto output_stride = output.stride(0);
  // int64_t counters: numel and dims are int64_t, and an `int` counter
  // overflows (UB) once they exceed INT_MAX.
  for (int64_t i = 0; i < numel; i++) {
    auto bag = offset2bag_data[i];
    auto word_idx = indices_data[i];
    // Depends only on i, so it is computed once per position rather than per column.
    const bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag;
    for (int64_t dim = 0; dim < dims; dim++) {
      auto& current_item = output_data[output_stride * bag + dim];
      auto weight_item = weight_data[weight_stride0 * word_idx + dim * weight_stride1];
      if (is_first_for_bag || weight_item > current_item) {
        current_item = weight_item;
        max_indices_data[max_indices_stride * bag + dim] = word_idx;
      }
    }
  }
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
}
// Shared CPU forward behind _embedding_bag_cpu and _embedding_bag_forward_only_cpu.
// Assumes all input tensors except for `weight` are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
//
// Returns (output, offset2bag, bag_size, max_indices). In sum/mean modes the
// fourth element is `bag_size` repeated. When the 'sum' fast path applies,
// offset2bag is an empty 0-element sentinel instead of a real mapping.
std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const Tensor& per_sample_weights,
    bool include_last_offset,
    bool requires_grad) {
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  auto weight_arg = TensorArg(weight, "weight", 1);
  checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});

  if (include_last_offset) {
    TORCH_CHECK(
        offsets.size(0) >= 1,
        "include_last_offset: number of offset should be at least 1");
  }

  // Validate the first and last offsets. They are only read when they exist,
  // because indexing the data of an empty `offsets` tensor is out of bounds.
  const int64_t num_offsets = offsets.size(0);
  if (num_offsets > 0) {
    const int64_t* offsets_data = offsets.data_ptr<int64_t>();
    const int64_t offset_0 = offsets_data[0];
    const int64_t offset_n = offsets_data[num_offsets - 1];
    TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence "
                "in the mini-batch has to start from position 0. "
                "However, got ", offset_0);
    TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not "
                "be greater than input's length ", indices.size(0), " but got offsets[-1] of ",
                offset_n);
  }

  if (per_sample_weights.defined()) {
    TORCH_CHECK(mode == MODE_SUM,
        "embedding_bag: per_sample_weights only supported with mode='sum'");
    auto per_input_weights_arg = TensorArg(
        per_sample_weights,"per_sample_weights", 1);
    checkSameType("embedding_bag", weight_arg, per_input_weights_arg);
    TORCH_CHECK(per_sample_weights.dim() == 1);
    TORCH_CHECK(per_sample_weights.numel() == indices.numel());
  }

  auto bag_size = make_bag_size(offsets, indices, mode, requires_grad);

  auto output = at::empty(
      {include_last_offset ? offsets.size(0) - 1 : offsets.size(0),
       weight.size(1)},
      weight.options());

  // To save compute, if we are going to go down the fast path case for the 'sum'
  // mode, we skip calculating offset2bag, since it is not going to be used.
  auto fast_path_sum = [&weight, &per_sample_weights, &output]() {
    if (per_sample_weights.defined()) {
      return isFastPathIndexSelectScale(weight, per_sample_weights, output);
    } else {
      return isFastPathIndexSelect(weight, output);
    }
  };

  // Use an empty 0-element tensor as a sentinel that we have skipped the
  // creation of offset2bag because autograd chokes when trying to use an
  // undefined tensor as an input to a backward op.
  Tensor offset2bag = at::empty({0}, offsets.options());
  if (mode == MODE_MEAN || mode == MODE_MAX || !fast_path_sum()) {
    // If the last entries are empty, the last offsets are irrelevant as they
    // won't change anything in the assignment of ID -> bag, but index_add would
    // throw out of bounds error. So to keep it simple we just add one more
    // entry to the end then get rid of it after make_offset2bag.
    offset2bag = at::zeros(
       {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]

    make_offset2bag(offsets, indices, offset2bag);

    offset2bag.resize_({indices.sizes()[0]});

    // only initialize output in slow path
    output.zero_();
  }

  if (mode == MODE_MEAN || mode == MODE_SUM) {
    AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() {
      if (per_sample_weights.defined()) {
        AT_ASSERT(mode == MODE_SUM);
        index_select_scale_add<scalar_t>(
          indices, offset2bag, per_sample_weights, weight, output, offsets, include_last_offset);
      } else {
        index_select_add<scalar_t>(indices, offset2bag, weight, output, offsets, include_last_offset);
      }
    });
    auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
    return std::tuple<Tensor, Tensor, Tensor, Tensor>(ret, offset2bag, bag_size, bag_size);
  } else { // MODE_MAX
    // per_sample_weights was already rejected above for any mode but 'sum'.
    return AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      weight.scalar_type(), "embedding_bag_cpu_max", [&]() {
        return embedding_bag_cpu_max<scalar_t>(
            weight, indices, offset2bag, output, bag_size, offsets);
      }
    );
  }
}
// Public embedding_bag entry point. Makes `indices` and `offsets` contiguous
// (but deliberately not `weight`), so the backward can skip its own
// `.contiguous()` calls. It then dispatches to the forward-only op when
// `weight` does not require grad, and to the autograd-tracked _embedding_bag
// otherwise.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
std::tuple<Tensor, Tensor, Tensor, Tensor>
embedding_bag(const Tensor &weight, const Tensor &indices,
              const Tensor &offsets, const bool scale_grad_by_freq,
              const int64_t mode, bool sparse,
              const Tensor &per_sample_weights,
              bool include_last_offset) {
  auto contiguous_indices = indices.contiguous();
  auto contiguous_offsets = offsets.contiguous();
  if (weight.requires_grad()) {
    return at::_embedding_bag(weight, contiguous_indices, contiguous_offsets,
                              scale_grad_by_freq, mode, sparse,
                              per_sample_weights, include_last_offset);
  }
  return at::_embedding_bag_forward_only(weight, contiguous_indices, contiguous_offsets,
                                         scale_grad_by_freq, mode, sparse,
                                         per_sample_weights, include_last_offset);
}
// CPU forward that does not prepare for a backward pass (bag_size is only
// allocated in 'sum' mode when gradients are required, which this passes as
// false). Presumably the CPU kernel behind at::_embedding_bag_forward_only.
// Assumes all input tensors except for `weight` are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
// `scale_grad_by_freq` and `sparse` are accepted for signature compatibility
// and are not used by the forward computation.
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_forward_only_cpu(const Tensor &weight, const Tensor &indices,
                                const Tensor &offsets, const bool scale_grad_by_freq,
                                const int64_t mode, bool sparse,
                                const Tensor &per_sample_weights, bool include_last_offset) {
  (void)scale_grad_by_freq;
  (void)sparse;
  constexpr bool kRequiresGrad = false;
  return _embedding_bag_cpu_impl(weight, indices, offsets, mode,
                                 per_sample_weights, include_last_offset,
                                 kRequiresGrad);
}
// CPU forward that also prepares the tensors a backward pass needs (e.g.
// bag_size is allocated even in 'sum' mode). Presumably the CPU kernel behind
// at::_embedding_bag.
// Assumes all input tensors except for `weight` are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
// `scale_grad_by_freq` and `sparse` are accepted for signature compatibility
// and are not used by the forward computation.
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
                   const Tensor &offsets, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse,
                   const Tensor &per_sample_weights, bool include_last_offset) {
  (void)scale_grad_by_freq;
  (void)sparse;
  constexpr bool kRequiresGrad = true;
  return _embedding_bag_cpu_impl(weight, indices, offsets, mode,
                                 per_sample_weights, include_last_offset,
                                 kRequiresGrad);
}
// Backward entry point for embedding_bag with respect to `weight`.
// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
//
// Validates indices/offsets (int64, contiguous). It rebuilds offset2bag when
// the forward left it as the empty sentinel while indices are non-empty, then
// forwards to the sparse or dense backward.
Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
                               const Tensor &offsets,
                               const Tensor &offset2bag,
                               const Tensor &bag_size_,
                               const Tensor &max_indices_,
                               int64_t num_weights,
                               bool scale_grad_by_freq, int64_t mode,
                               bool sparse,
                               const Tensor& per_sample_weights) {
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  checkContiguous("embedding_bag", indices_arg);
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarType("embedding_bag", offsets_arg, kLong);
  checkContiguous("embedding_bag", offsets_arg);

  const bool must_rebuild = indices.numel() != 0 && offset2bag.numel() == 0;
  Tensor bag_of_index;
  if (must_rebuild) {
    const auto num_indices = indices.sizes()[0];
    // One spare trailing slot lets index_add_ accept an offset equal to the
    // number of indices; it is trimmed right after.
    bag_of_index = at::zeros({num_indices + 1}, indices.options());
    make_offset2bag(offsets, indices, bag_of_index);
    bag_of_index.resize_({num_indices});
  } else {
    auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
    checkScalarType("embedding_bag", offset2bag_arg, kLong);
    checkContiguous("embedding_bag", offset2bag_arg);
    bag_of_index = offset2bag;
  }

  if (!sparse) {
    return at::_embedding_bag_dense_backward(
        grad, indices, offsets, bag_of_index, bag_size_, max_indices_, num_weights,
        scale_grad_by_freq, mode, per_sample_weights);
  }
  return at::_embedding_bag_sparse_backward(
      grad, indices, offsets, bag_of_index, bag_size_, num_weights,
      scale_grad_by_freq, mode, per_sample_weights);
}
// Dense MODE_MAX backward. For each feature column, every non-empty bag's
// gradient is added to the embedding row recorded in `max_indices` (the row
// that supplied that bag's maximum in the forward). Bags whose bag_size is
// zero are skipped. Returns a (num_weights, grad.size(1)) tensor.
static Tensor _embedding_bag_dense_backward_cpu_max(
    const Tensor& grad,
    const Tensor& bag_size,
    const Tensor& max_indices,
    int64_t num_weights) {
  AT_ASSERT(max_indices.defined());
  auto index_grad_weight =
      at::zeros({num_weights, grad.size(1)}, grad.options());
  // Find the non-empty bags once and reuse them for both gathers.
  auto nonempty_bags = bag_size.nonzero().view(-1);
  auto nonempty_max_indices = max_indices.index_select(0, nonempty_bags);
  auto nonempty_grad = grad.index_select(0, nonempty_bags);

  for (int64_t dim = 0; dim < grad.size(1); dim++) {
    index_grad_weight.select(1, dim).index_add_(
        0, nonempty_max_indices.select(1, dim), nonempty_grad.select(1, dim));
  }
  return index_grad_weight;
}
// Histogram of indices_data[0 .. indices_length): counts[w] is the number of
// occurrences of embedding row w. Every index must lie in [0, num_weights).
static std::vector<int64_t> compute_counts(
    int64_t num_weights,
    int64_t* indices_data,
    int64_t indices_length) {
  std::vector<int64_t> counts(num_weights, 0);
  // int64_t counter: indices_length is int64_t, and an `int` counter would
  // overflow (UB) for inputs longer than INT_MAX.
  for (int64_t i = 0; i < indices_length; i++) {
    counts[indices_data[i]]++;
  }
  return counts;
}
// For each run of equal values in the SORTED array indices_data, returns the
// position one past the end of that run, which is where the next distinct
// value starts. `counts` must be the histogram of the same indices, as built
// by compute_counts.
//
// Example:
//   indices:     [0, 0, 0, 1, 3, 3, 4]
//   counts:      [3, 1, 0, 2, 1, 0]
//   counts_uniq: [3, 4, 6, 7]
// so the distinct values start at positions 0, 3, 4 and 6.
static std::vector<int64_t> compute_counts_uniq(
    int64_t num_weights,
    int64_t* indices_data,
    int64_t indices_length,
    const std::vector<int64_t>& counts) {
  std::vector<int64_t> run_ends;
  run_ends.reserve(num_weights);
  int64_t pos = 0;
  while (pos < indices_length) {
    // Every element of the current run holds this value, so the whole run
    // can be skipped in one step; the new position is the run's end.
    pos += counts[indices_data[pos]];
    run_ends.push_back(pos);
  }
  return run_ends;
}
template <typename scalar_t>
void _embedding_bag_dense_backward_cpu_sum_mean(
const Tensor& grad,
const Tensor& indices_,
const Tensor& offsets_,
const Tensor& offset2bag__,
int64_t num_weights,
bool scale_grad_by_freq,
int64_t mode,
const Tensor& per_sample_weights_,
Tensor& index_grad_weight) {
Tensor &offset2bag_ = const_cast<Tensor &>(offset2bag__);
auto ind_sort_ = indices_.sort();
auto indices = std::get<0>(ind_sort_);
auto ind_sort = std::get<1>(ind_sort_);
auto offset2bag = offset2bag_.index_select(0, ind_sort);
optional<Tensor> per_sample_weights;
scalar_t* per_sample_weights_data;
optional<int64_t> per_sample_weights_stride;
if (per_sample_weights_.defined()) {
per_sample_weights = per_sample_weights_.index_select(0, ind_sort);
per_sample_weights_data = per_sample_weights->data_ptr<scalar_t>();
per_sample_weights_stride = per_sample_weights->stride(0);
}
auto* indices_data = indices.data_ptr<int64_t>();
auto* offsets_data = offsets_.data_ptr<int64_t>();
auto* offset2bag_data = offset2bag.data_ptr<int64_t>();
int64_t numel = indices.numel();
auto counts = compute_counts(num_weights, indices_data, numel);
auto next_unique_index_idx =
compute_counts_uniq(num_weights, indices_data, numel, counts);
auto loop = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
int64_t start = i == 0 ? 0 : next_unique_index_idx[i - 1];
int64_t index = indices_data[start];
for (int64_t j = start; j < next_unique_index_idx[i]; j++) {
int64_t source = offset2bag_data[j];
double scale = 1.0;
if (per_sample_weights) {
AT_ASSERT(mode == MODE_SUM);
scale = per_sample_weights_data[*per_sample_weights_stride * j];
}
if (scale_grad_by_freq) {
scale /= counts[indices_data[i]];
}
if (mode == 1) { // MODE_MEAN
if (offsets_.size(0) == 1) {
auto bag_size = indices.size(0);
scale /= bag_size;
} else {
if (source == offsets_.size(0) - 1) {
scale /= indices.size(0) - offsets_data[offsets_.size(0) - 1];
} else {
scale /= offsets_data[source + 1] - offsets_data[source];
}
}
}
int64_t ddim = grad.size(1);
auto igwd = index_grad_weight.data_ptr<scalar_t>();
auto gd = grad.data_ptr<scalar_t>();
THBlas_axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
}
}
};
if (numel > 1000) {
at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
} else {
loop(0, (int64_t)next_unique_index_idx.size());
}
}
// Dense backward with respect to `weight`; returns a
// (num_weights, grad.size(1)) tensor. MODE_MAX routes each bag's gradient
// through `max_indices_`. Sum/mean modes scatter the bag gradients back to
// every embedding row the bag used.
Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_,
                                  const Tensor &offsets_,
                                  const Tensor &offset2bag__,
                                  const Tensor &bag_size_,
                                  const Tensor& max_indices_, int64_t num_weights,
                                  bool scale_grad_by_freq, int64_t mode,
                                  const Tensor& per_sample_weights_) {
  // indices_, offsets_ and offset2bag__ are assumed to already have the right
  // dtypes and be contiguous, thanks to the checks in _embedding_bag_backward.
  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
  // for more details.
  auto grad = grad_.contiguous();
  auto grad_arg = TensorArg(grad, "grad_", 1);
  checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble});

  if (mode != MODE_MAX) {
    AT_ASSERT(mode == MODE_MEAN || mode == MODE_SUM);
    auto index_grad_weight =
        at::zeros({num_weights, grad.size(1)}, grad.options());
    AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "embedding_bag_backward", [&] {
      _embedding_bag_dense_backward_cpu_sum_mean<scalar_t>(
          grad, indices_, offsets_, offset2bag__, num_weights,
          scale_grad_by_freq, mode, per_sample_weights_, index_grad_weight);
    });
    return index_grad_weight;
  }
  return _embedding_bag_dense_backward_cpu_max(
      grad_, bag_size_, max_indices_, num_weights);
}
// Gradient of embedding_bag with respect to per_sample_weights ('sum' mode
// only). For every sample position s the result is
//   dot(grad[offset2bag[s], :], weight[indices[s], :])
// where `weight` is the embedding table. Returns a 1-D tensor with one entry
// per index. offset2bag is rebuilt from `offsets` when the forward left it as
// the empty sentinel while indices are non-empty.
template<typename scalar_t>
Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
    const Tensor& grad,
    const Tensor& weight,  // NB: embedding table, not per_sample_weights
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& offset2bag,
    int64_t mode) {
  TORCH_CHECK(
      mode == MODE_SUM,
      "embedding_bag_backward: per_sample_weights only supported for mode='sum'");

  AT_ASSERT(grad.dim() == 2);
  const auto feature_count = grad.size(1);
  AT_ASSERT(indices.dim() == 1);
  const auto sample_count = indices.size(0);
  AT_ASSERT(weight.dim() == 2);
  AT_ASSERT(weight.size(1) == feature_count);

  auto result = at::zeros({sample_count}, grad.options());

  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarType("embedding_bag", indices_arg, kLong);
  checkContiguous("embedding_bag", indices_arg);

  const bool must_rebuild = indices.numel() != 0 && offset2bag.numel() == 0;
  Tensor bag_of_sample;
  if (must_rebuild) {
    const auto num_indices = indices.sizes()[0];
    // One spare trailing slot for index_add_, trimmed right after.
    bag_of_sample = at::zeros({num_indices + 1}, indices.options());
    make_offset2bag(offsets, indices, bag_of_sample);
    bag_of_sample.resize_({num_indices});
  } else {
    auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
    checkScalarType("embedding_bag", offset2bag_arg, kLong);
    checkContiguous("embedding_bag", offset2bag_arg);
    bag_of_sample = offset2bag;
  }

  auto* grad_ptr = grad.data_ptr<scalar_t>();
  const auto grad_row_stride = grad.stride(0);
  const auto grad_col_stride = grad.stride(1);
  auto* table_ptr = weight.data_ptr<scalar_t>();
  const auto table_row_stride = weight.stride(0);
  const auto table_col_stride = weight.stride(1);
  auto* index_ptr = indices.data_ptr<int64_t>();

  // `result` and `bag_of_sample` are contiguous.
  auto* result_ptr = result.data_ptr<scalar_t>();
  auto* bag_ptr = bag_of_sample.data_ptr<int64_t>();

  // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
  constexpr int64_t kGrainSize = 64;
  parallel_for(0, sample_count, kGrainSize, [&](int64_t lo, int64_t hi) {
    for (int64_t s = lo; s < hi; ++s) {
      result_ptr[s] = dot_impl<scalar_t>(
          feature_count,
          grad_ptr + grad_row_stride * bag_ptr[s], grad_col_stride,
          table_ptr + table_row_stride * index_ptr[s], table_col_stride);
    }
  });
  return result;
}
// Dispatches the per_sample_weights gradient on grad's floating dtype
// (float or double) to _embedding_bag_per_sample_weights_backward_cpu_template.
Tensor _embedding_bag_per_sample_weights_backward_cpu(
    const Tensor& grad,
    const Tensor& weight,  // NB: embedding table, not per_sample_weights
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& offset2bag,
    int64_t mode) {
  const auto grad_dtype = grad.scalar_type();
  return AT_DISPATCH_FLOATING_TYPES(
      grad_dtype, "_embedding_bag_per_sample_weights_backward_cpu", [&]() -> Tensor {
        return _embedding_bag_per_sample_weights_backward_cpu_template<scalar_t>(
            grad, weight, indices, offsets, offset2bag, mode);
      });
}
// Sparse backward with respect to `weight`. It expands `grad_` to one row per
// index (grad_[offset2bag[i]]) and applies the MODE_MEAN division and, when
// given, the per-sample weights. It then hands the rows to embedding_backward
// with padding_idx = -1 and sparse = true.
Tensor _embedding_bag_sparse_backward(
    const Tensor &grad_, const Tensor &indices, const Tensor &offsets,
    const Tensor &offset2bag, const Tensor &bag_size_, int64_t num_weights,
    bool scale_grad_by_freq, int64_t mode, const Tensor& per_sample_weights) {
  // indices, offsets and offset2bag are assumed having correct dtypes and
  // contiguous here due to the checks in _embedding_bag_backward.
  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
  // for more details.

  Tensor index_grad = grad_.index_select(0, offset2bag);
  index_grad = apply_bag_size_backward(offsets, indices, mode, index_grad,
                                       offset2bag, bag_size_);

  if (per_sample_weights.defined()) {
    AT_ASSERT(mode == MODE_SUM);
    index_grad.mul_(per_sample_weights.unsqueeze(1));
  }
  return native::embedding_backward(index_grad, indices, num_weights,
                                    /*padding_idx=*/-1, scale_grad_by_freq,
                                    /*sparse=*/true);
}
}
} // namespace at::native