forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BinaryOps.cpp
902 lines (758 loc) · 35.4 KB
/
BinaryOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
#include <ATen/native/BinaryOps.h>
#include <type_traits>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <torch/library.h>
namespace at {
namespace native {
// Dispatch stubs for the binary operators implemented in this file.
// Each stub is invoked below as `stub(iter.device_type(), iter, ...)`, so the
// kernel that runs is selected by the device type of the TensorIterator.
// NOTE(review): the matching declarations and per-backend kernel
// registrations presumably live in ATen/native/BinaryOps.h and the backend
// kernel files -- confirm there before adding, removing or renaming a stub.
DEFINE_DISPATCH(add_stub);
DEFINE_DISPATCH(add_clamp_stub);
DEFINE_DISPATCH(sub_stub);
DEFINE_DISPATCH(mul_stub);
DEFINE_DISPATCH(div_stub);
DEFINE_DISPATCH(remainder_stub);
DEFINE_DISPATCH(atan2_stub);
DEFINE_DISPATCH(bitwise_and_stub);
DEFINE_DISPATCH(bitwise_or_stub);
DEFINE_DISPATCH(bitwise_xor_stub);
DEFINE_DISPATCH(lshift_stub);
DEFINE_DISPATCH(rshift_stub);
DEFINE_DISPATCH(logical_and_stub);
DEFINE_DISPATCH(logical_or_stub);
DEFINE_DISPATCH(logical_xor_stub);
DEFINE_DISPATCH(lt_stub);
DEFINE_DISPATCH(le_stub);
DEFINE_DISPATCH(gt_stub);
DEFINE_DISPATCH(ge_stub);
DEFINE_DISPATCH(eq_stub);
DEFINE_DISPATCH(ne_stub);
DEFINE_DISPATCH(sigmoid_backward_stub);
DEFINE_DISPATCH(logit_backward_stub);
DEFINE_DISPATCH(tanh_backward_stub);
DEFINE_DISPATCH(max_elementwise_stub);
DEFINE_DISPATCH(min_elementwise_stub);
DEFINE_DISPATCH(fmod_stub);
DEFINE_DISPATCH(fmod_scalar_stub);
DEFINE_DISPATCH(logaddexp_stub);
DEFINE_DISPATCH(logaddexp2_stub);
// add.out: runs add_stub with `alpha` and writes into the caller-provided
// `result`. The memory-overlap check is requested because `result` comes from
// the caller (add_ below even passes `self` as `result`).
Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  constexpr bool kCheckMemOverlap = true;
  auto it = TensorIterator::binary_op(result, self, other, kCheckMemOverlap);
  alpha_check(it.dtype(), alpha);  // alpha is validated against the iterator's dtype
  add_stub(it.device_type(), it, alpha);
  // Sanity check: the iterator's output dtype matches the dtype of `result`.
  TORCH_INTERNAL_ASSERT(result.scalar_type() == it.output().dtype());
  return result;
}

// Functional add: the output starts undefined and is taken from the iterator.
Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  alpha_check(it.dtype(), alpha);
  add_stub(it.device_type(), it, alpha);
  return it.output();
}

// In-place add: `self` is both the first operand and the output.
Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
  return native::add_out(self, self, other, alpha);
}
// Shared implementation of add_relu, add_relu_ and add_relu_out.
//
// Chooses the clamp range [0, numeric max of self's dtype] and passes it,
// together with `alpha`, to add_clamp_stub. That kernel presumably clamps the
// sum into the range, i.e. add followed by ReLU in a single kernel call.
// Supported dtypes of `self`: int8, int16, int32, int64, float and double.
// Any other dtype raises an error.
//
// `result` may be undefined (functional add_relu). It is re-bound to the
// iterator's output so the caller receives the tensor the iterator produced.
Tensor& add_relu_impl(
    Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  // Resolve the clamp bounds (and reject unsupported dtypes) before the
  // iterator is built, so an error cannot occur after `result` was touched.
  Scalar min_val;
  Scalar max_val;
  if (self.dtype() == at::kInt) {
    min_val = 0;
    max_val = std::numeric_limits<int32_t>::max();
  } else if (self.dtype() == at::kLong) {
    min_val = 0;
    max_val = std::numeric_limits<int64_t>::max();
  } else if (self.dtype() == at::kShort) {
    min_val = 0;
    max_val = std::numeric_limits<int16_t>::max();
  } else if (self.dtype() == at::kChar) {
    min_val = 0;
    max_val = std::numeric_limits<int8_t>::max();
  } else if (self.dtype() == at::kFloat) {
    min_val = 0.0;
    max_val = std::numeric_limits<float>::max();
  } else if (self.dtype() == at::kDouble) {
    min_val = 0.0;
    max_val = std::numeric_limits<double>::max();
  } else {
    // TORCH_CHECK(false, ...) always throws. Writing
    // TORCH_INTERNAL_ASSERT("msg", ...) would make the string literal the
    // (always-true) condition, and the error would never fire.
    TORCH_CHECK(false, "Unsupported datatype for add_relu: ", self.dtype().name());
  }
  auto iter = TensorIterator::binary_op(result, self, other,
      /*check_mem_overlap=*/true);
  result = iter.output();
  add_clamp_stub(iter.device_type(), iter, alpha, min_val, max_val);
  return result;
}
// add_relu.out: fills the caller-provided `result`.
Tensor& add_relu_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  return add_relu_impl(result, self, other, alpha);
}

// Functional add_relu: `out` starts undefined; add_relu_impl re-binds it to
// the iterator's output before it is returned.
Tensor add_relu(const Tensor& self, const Tensor& other, Scalar alpha) {
  Tensor out;
  return add_relu_impl(out, self, other, alpha);
}

// In-place add_relu: `self` doubles as the output.
Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) {
  return add_relu_impl(self, self, other, alpha);
}
// div.out. Refused when the *output* dtype is integral (bool included):
// integer division through div / `/` is no longer supported.
Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!isIntegralType(result.scalar_type(), /*includeBool=*/ true),
      "Integer division of tensors using div or / is no longer supported, ",
      "and in a future release div will perform true division as in Python 3. ",
      "Use true_divide or floor_divide (// in Python) instead.");
  auto it = TensorIterator::binary_op(result, self, other,
      /*check_mem_overlap=*/true);
  div_stub(it.device_type(), it);
  return result;
}

// Functional div. Refused only when *both* inputs are integral (bool included).
Tensor div(const Tensor& self, const Tensor& other) {
  const bool both_integral =
      isIntegralType(self.scalar_type(), /*includeBool=*/ true) &&
      isIntegralType(other.scalar_type(), /*includeBool=*/ true);
  TORCH_CHECK(!both_integral,
      "Integer division of tensors using div or / is no longer supported, ",
      "and in a future release div will perform true division as in Python 3. ",
      "Use true_divide or floor_divide (// in Python) instead.");
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  div_stub(it.device_type(), it);
  return it.output();
}

// In-place div: goes through div_out, so the output-dtype check applies to `self`.
Tensor& div_(Tensor& self, const Tensor& other) {
  return native::div_out(self, self, other);
}
// remainder.out: runs remainder_stub into the caller-provided `result`.
Tensor& remainder_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  remainder_stub(it.device_type(), it);
  return result;
}

// Functional remainder: the output is produced by the iterator.
Tensor remainder(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  remainder_stub(it.device_type(), it);
  return it.output();
}

// In-place remainder.
Tensor& remainder_(Tensor& self, const Tensor& other) {
  return native::remainder_out(self, self, other);
}
// true_divide.out. When both inputs are integral (or bool), they are first
// copied to the default floating dtype (c10::get_default_dtype()), so the
// division through div_stub runs on floating-point operands.
Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) {
  Tensor numerator = self;
  Tensor denominator = divisor;
  if (isIntegralType(self.scalar_type(), /*includeBool=*/ true) &&
      isIntegralType(divisor.scalar_type(), /*includeBool=*/ true)) {
    const auto float_type = typeMetaToScalarType(c10::get_default_dtype());
    numerator = self.to(float_type);
    denominator = divisor.to(float_type);
  }
  auto it = TensorIterator::binary_op(result, numerator, denominator,
      /*check_mem_overlap=*/ true);
  div_stub(it.device_type(), it);
  return result;
}

// Functional true_divide.
//  * If at least one input is non-integral, it takes part in ordinary
//    binary-op type promotion and the iterator creates the output.
//  * If both inputs are integral (or bool), they are copied to the default
//    floating dtype, and the output is preallocated with that dtype.
Tensor true_divide(const Tensor& self, const Tensor& divisor) {
  const bool self_integral = isIntegralType(self.scalar_type(), /*includeBool=*/ true);
  const bool divisor_integral = isIntegralType(divisor.scalar_type(), /*includeBool=*/ true);
  if (!(self_integral && divisor_integral)) {
    Tensor out;
    auto it = TensorIterator::binary_op(out, self, divisor);
    div_stub(it.device_type(), it);
    return it.output();
  }
  const auto float_type = typeMetaToScalarType(c10::get_default_dtype());
  Tensor out = at::empty({0}, self.options().dtype(float_type));
  auto it = TensorIterator::binary_op(out,
                                      self.to(float_type),
                                      divisor.to(float_type));
  div_stub(it.device_type(), it);
  return out;
}

// In-place true_divide.
Tensor& true_divide_(Tensor& self, const Tensor& divisor) {
  return native::true_divide_out(self, self, divisor);
}
// floor_divide.out: divides through div_stub. For floating-point outputs, the
// quotient is then truncated in place with trunc_() (rounded toward zero).
Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  div_stub(it.device_type(), it);
  if (result.is_floating_point()) result.trunc_();
  return result;
}

// Functional floor_divide: same as the out variant, but the iterator produces the output.
Tensor floor_divide(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  div_stub(it.device_type(), it);
  Tensor quotient = it.output();
  if (quotient.is_floating_point()) quotient.trunc_();
  return quotient;
}

// In-place floor_divide.
Tensor& floor_divide_(Tensor& self, const Tensor& other) {
  return native::floor_divide_out(self, self, other);
}
// mul.out: runs mul_stub into the caller-provided `result`.
Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  mul_stub(it.device_type(), it);
  return result;
}

// Functional mul: the output is produced by the iterator.
Tensor mul(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  mul_stub(it.device_type(), it);
  return it.output();
}

// In-place mul.
Tensor& mul_(Tensor& self, const Tensor& other) {
  return native::mul_out(self, self, other);
}
// sub.out: sub_check validates the operands first; then sub_stub runs with
// `alpha` into the caller-provided `result`.
Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
  sub_check(self, other);
  constexpr bool kCheckMemOverlap = true;
  auto it = TensorIterator::binary_op(result, self, other, kCheckMemOverlap);
  alpha_check(it.dtype(), alpha);
  sub_stub(it.device_type(), it, alpha);
  // Sanity check: the iterator's output dtype matches the dtype of `result`.
  TORCH_INTERNAL_ASSERT(result.scalar_type() == it.output().dtype());
  return result;
}

// Functional sub: the output is produced by the iterator.
Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) {
  sub_check(self, other);
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  alpha_check(it.dtype(), alpha);
  sub_stub(it.device_type(), it, alpha);
  return it.output();
}

// In-place sub.
Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) {
  return native::sub_out(self, self, other, alpha);
}
// Gradient of sigmoid, computed by sigmoid_backward_stub from grad_output and
// the forward `output`. Note: no memory-overlap check is requested here.
Tensor& sigmoid_backward_out(Tensor& result, const Tensor& grad_output, const Tensor& output) {
  auto it = TensorIterator::binary_op(result, grad_output, output);
  sigmoid_backward_stub(it.device_type(), it);
  return result;
}

// Functional variant: the gradient tensor is produced by the iterator.
Tensor sigmoid_backward(const Tensor& grad_output, const Tensor& output) {
  Tensor grad_input;
  auto it = TensorIterator::binary_op(grad_input, grad_output, output);
  sigmoid_backward_stub(it.device_type(), it);
  return it.output();
}
// Gradient of logit. `eps` is forwarded to the kernel as a Scalar; an absent
// eps is encoded as -1.0 (presumably the kernel treats a negative value as
// "no eps" -- confirm in the logit_backward kernel).
Tensor& logit_backward_out(
    Tensor& result,
    const Tensor& grad_output,
    const Tensor& input,
    c10::optional<double> eps) {
  const double eps_arg = eps.value_or(-1.0);
  auto it = TensorIterator::binary_op(result, grad_output, input);
  logit_backward_stub(it.device_type(), it, Scalar(eps_arg));
  return result;
}

// Functional variant: the gradient tensor is produced by the iterator.
Tensor logit_backward(
    const Tensor& grad_output,
    const Tensor& input,
    c10::optional<double> eps) {
  const double eps_arg = eps.value_or(-1.0);
  Tensor grad_input;
  auto it = TensorIterator::binary_op(grad_input, grad_output, input);
  logit_backward_stub(it.device_type(), it, Scalar(eps_arg));
  return it.output();
}
// Gradient of tanh, computed by tanh_backward_stub from grad_output and the
// forward `output`. Note: no memory-overlap check is requested here.
Tensor& tanh_backward_out(Tensor& result, const Tensor& grad_output, const Tensor& output) {
  auto it = TensorIterator::binary_op(result, grad_output, output);
  tanh_backward_stub(it.device_type(), it);
  return result;
}

// Functional variant: the gradient tensor is produced by the iterator.
Tensor tanh_backward(const Tensor& grad_output, const Tensor& output) {
  Tensor grad_input;
  auto it = TensorIterator::binary_op(grad_input, grad_output, output);
  tanh_backward_stub(it.device_type(), it);
  return it.output();
}
// Reversed subtraction: forwards to sub with the two tensor operands swapped.
Tensor rsub(const Tensor& self, const Tensor& other, Scalar alpha) {
  const Tensor& lhs = other;
  const Tensor& rhs = self;
  return native::sub(lhs, rhs, alpha);
}
// atan2.out: runs atan2_stub into the caller-provided `result`.
// Note: no memory-overlap check is requested here.
Tensor& atan2_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(result, self, other);
  atan2_stub(it.device_type(), it);
  return result;
}

// Functional atan2: preallocates an empty tensor with self's options
// (including self's dtype) and delegates to atan2_out.
Tensor atan2(const Tensor& self, const Tensor& other) {
  Tensor out = at::empty({0}, self.options());
  return native::atan2_out(out, self, other);
}

// In-place atan2.
Tensor& atan2_(Tensor& self, const Tensor& other) {
  return native::atan2_out(self, self, other);
}
// Converts a Scalar into a tensor marked as a "wrapped number" (presumably so
// that type promotion treats it like a plain Python number -- confirm).
// The Scalar overloads in this file rely on it because C++ has no conversion
// from number types (int, float, ...) to Tensor, only to Scalar. Those
// overloads are not exposed to Python.
static Tensor wrapped_scalar_tensor(Scalar scalar) {
  Tensor wrapped = scalar_to_tensor(scalar);
  wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
  return wrapped;
}
// Validates that `scalar` can be converted to `scalarType` without overflow.
// It dispatches on `scalarType` (all types plus complex, Bool, BFloat16 and
// Half) and attempts Scalar::to<scalar_t>(). The converted value is discarded;
// only the conversion's error path matters.
static void check_convert(Scalar scalar, ScalarType scalarType) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
      scalarType, "check_convert", [&] {
        (void)scalar.to<scalar_t>();
      });
}
// Wraps `scalar` as a tensor after checking (check_convert) that it converts
// to the dtype of `tensor` without overflow. `tensor` is read only for its
// dtype, so it is taken by const reference rather than by value; this avoids
// an unnecessary refcount increment and decrement per call.
static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, const Tensor& tensor) {
  check_convert(scalar, tensor.scalar_type());
  return wrapped_scalar_tensor(scalar);
}
// Scalar overloads of add / add_: the number is wrapped into a tensor, and the
// Tensor overloads do the work.
Tensor add(const Tensor& self, Scalar other, Scalar alpha) {
  Tensor other_tensor = wrapped_scalar_tensor(other);
  return native::add(self, other_tensor, alpha);
}

Tensor& add_(Tensor& self, Scalar other, Scalar alpha) {
  Tensor other_tensor = wrapped_scalar_tensor(other);
  return native::add_(self, other_tensor, alpha);
}
// Scalar overloads of div / div_. Unlike most Scalar overloads here, these
// call the Tensor *method* (self.div / self.div_), so the call is redispatched
// on `self` (presumably so a sparse `self` reaches its own kernel).
//
// WARNING: no test appears to exercise div(Tensor, Scalar) with a sparse `self`.
Tensor div(const Tensor& self, Scalar other) {
  Tensor divisor = wrapped_scalar_tensor(other);
  return self.div(divisor);  // redispatch!
}

// WARNING: with a sparse `self`, div_(Tensor, Scalar) is currently exercised
// only by DistributedDataParallelTest.test_sparse_gradients. It must be
// exercised from C++, because Python never uses this overload.
Tensor& div_(Tensor& self, Scalar other) {
  Tensor divisor = wrapped_scalar_tensor(other);
  return self.div_(divisor);  // redispatch!
}
// Scalar overloads of remainder.
// FIXME: `other` is cast to self's dtype to stay backward compatible with TH.
// In the future these should use normal type promotion instead, as numpy does.
Tensor remainder(const Tensor& self, Scalar other) {
  const Tensor divisor = wrapped_scalar_tensor(other).toType(self.scalar_type());
  return native::remainder(self, divisor);
}

Tensor& remainder_(Tensor& self, Scalar other) {
  const Tensor divisor = wrapped_scalar_tensor(other).toType(self.scalar_type());
  return native::remainder_(self, divisor);
}

Tensor& remainder_out(Tensor& result, const Tensor& self, Scalar other) {
  const Tensor divisor = wrapped_scalar_tensor(other).toType(self.scalar_type());
  return native::remainder_out(result, self, divisor);
}
// Scalar overloads of mul / sub / rsub: wrap the number, then reuse the
// Tensor overloads.
Tensor mul(const Tensor& self, Scalar other) {
  Tensor factor = wrapped_scalar_tensor(other);
  return native::mul(self, factor);
}

Tensor& mul_(Tensor& self, Scalar other) {
  Tensor factor = wrapped_scalar_tensor(other);
  return native::mul_(self, factor);
}

Tensor sub(const Tensor& self, Scalar other, Scalar alpha) {
  Tensor operand = wrapped_scalar_tensor(other);
  return native::sub(self, operand, alpha);
}

Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) {
  Tensor operand = wrapped_scalar_tensor(other);
  return native::sub_(self, operand, alpha);
}

Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
  Tensor operand = wrapped_scalar_tensor(other);
  return native::rsub(self, operand, alpha);
}
// ---- bitwise_and ----------------------------------------------------------
Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  bitwise_and_stub(it.device_type(), it);
  return result;
}

// Functional form: preallocated with self's options, filled by bitwise_and_out.
Tensor bitwise_and(const Tensor& self, const Tensor& other) {
  Tensor out = at::empty({0}, self.options());
  at::bitwise_and_out(out, self, other);
  return out;
}

Tensor& bitwise_and_(Tensor& self, const Tensor& other) {
  return at::bitwise_and_out(self, self, other);
}

// Scalar overloads: wrap the number and route through the Tensor out variant.
Tensor& bitwise_and_out(Tensor& result, const Tensor& self, Scalar other) {
  return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_and(const Tensor& self, Scalar other) {
  Tensor out = at::empty({0}, self.options());
  return at::bitwise_and_out(out, self, other);
}

Tensor& bitwise_and_(Tensor& self, Scalar other) {
  return at::bitwise_and_out(self, self, other);
}

// Legacy operator spellings (__and__, __iand__), aliased to bitwise_and*.
Tensor __and__(const Tensor& self, const Tensor& other) {
  return at::bitwise_and(self, other);
}

Tensor __and__(const Tensor& self, Scalar other) {
  return at::bitwise_and(self, other);
}

Tensor& __iand__(Tensor& self, const Tensor& other) {
  return self.bitwise_and_(other);
}

Tensor& __iand__(Tensor& self, Scalar other) {
  return self.bitwise_and_(other);
}
// ---- bitwise_or -----------------------------------------------------------
Tensor& bitwise_or_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  bitwise_or_stub(it.device_type(), it);
  return result;
}

// Functional form: preallocated with self's options, filled by bitwise_or_out.
Tensor bitwise_or(const Tensor& self, const Tensor& other) {
  Tensor out = at::empty({0}, self.options());
  at::bitwise_or_out(out, self, other);
  return out;
}

Tensor& bitwise_or_(Tensor& self, const Tensor& other) {
  return at::bitwise_or_out(self, self, other);
}

// Scalar overloads: wrap the number and route through the Tensor out variant.
Tensor& bitwise_or_out(Tensor& result, const Tensor& self, Scalar other) {
  return at::bitwise_or_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_or(const Tensor& self, Scalar other) {
  Tensor out = at::empty({0}, self.options());
  return at::bitwise_or_out(out, self, other);
}

Tensor& bitwise_or_(Tensor& self, Scalar other) {
  return at::bitwise_or_out(self, self, other);
}

// Legacy operator spellings (__or__, __ior__), aliased to bitwise_or*.
Tensor __or__(const Tensor& self, const Tensor& other) {
  return at::bitwise_or(self, other);
}

Tensor __or__(const Tensor& self, Scalar other) {
  return at::bitwise_or(self, other);
}

Tensor& __ior__(Tensor& self, const Tensor& other) {
  return self.bitwise_or_(other);
}

Tensor& __ior__(Tensor& self, Scalar other) {
  return self.bitwise_or_(other);
}
// ---- bitwise_xor ----------------------------------------------------------
Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  bitwise_xor_stub(it.device_type(), it);
  return result;
}

// Functional form: preallocated with self's options, filled by bitwise_xor_out.
Tensor bitwise_xor(const Tensor& self, const Tensor& other) {
  Tensor out = at::empty({0}, self.options());
  at::bitwise_xor_out(out, self, other);
  return out;
}

Tensor& bitwise_xor_(Tensor& self, const Tensor& other) {
  return at::bitwise_xor_out(self, self, other);
}

// Scalar overloads: wrap the number and route through the Tensor out variant.
Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, Scalar other) {
  return at::bitwise_xor_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_xor(const Tensor& self, Scalar other) {
  Tensor out = at::empty({0}, self.options());
  return at::bitwise_xor_out(out, self, other);
}

Tensor& bitwise_xor_(Tensor& self, Scalar other) {
  return at::bitwise_xor_out(self, self, other);
}

// Legacy operator spellings (__xor__, __ixor__), aliased to bitwise_xor*.
Tensor __xor__(const Tensor& self, const Tensor& other) {
  return at::bitwise_xor(self, other);
}

Tensor __xor__(const Tensor& self, Scalar other) {
  return at::bitwise_xor(self, other);
}

Tensor& __ixor__(Tensor& self, const Tensor& other) {
  return self.bitwise_xor_(other);
}

Tensor& __ixor__(Tensor& self, Scalar other) {
  return self.bitwise_xor_(other);
}
// Left shift. For Scalar shift amounts, the wrapped number is cast to self's
// dtype first, so both operands reach the iterator with self's dtype.
Tensor __lshift__(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  lshift_stub(it.device_type(), it);
  return it.output();
}

Tensor __lshift__(const Tensor& self, Scalar other) {
  const Tensor shift = wrapped_scalar_tensor(other).toType(self.scalar_type());
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, shift);
  lshift_stub(it.device_type(), it);
  return it.output();
}

// In-place variants: `self` is both the first operand and the output.
// Note: no memory-overlap check is requested here.
Tensor& __ilshift__(Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(self, self, other);
  lshift_stub(it.device_type(), it);
  return self;
}

Tensor& __ilshift__(Tensor& self, Scalar other) {
  const Tensor shift = wrapped_scalar_tensor(other).toType(self.scalar_type());
  auto it = TensorIterator::binary_op(self, self, shift);
  lshift_stub(it.device_type(), it);
  return self;
}
// Right shift. For Scalar shift amounts, the wrapped number is cast to self's
// dtype first, so both operands reach the iterator with self's dtype.
Tensor __rshift__(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, other);
  rshift_stub(it.device_type(), it);
  return it.output();
}

Tensor __rshift__(const Tensor& self, Scalar other) {
  const Tensor shift = wrapped_scalar_tensor(other).toType(self.scalar_type());
  Tensor out;
  auto it = TensorIterator::binary_op(out, self, shift);
  rshift_stub(it.device_type(), it);
  return it.output();
}

// In-place variants: `self` is both the first operand and the output.
// Note: no memory-overlap check is requested here.
Tensor& __irshift__(Tensor& self, const Tensor& other) {
  auto it = TensorIterator::binary_op(self, self, other);
  rshift_stub(it.device_type(), it);
  return self;
}

Tensor& __irshift__(Tensor& self, Scalar other) {
  const Tensor shift = wrapped_scalar_tensor(other).toType(self.scalar_type());
  auto it = TensorIterator::binary_op(self, self, shift);
  rshift_stub(it.device_type(), it);
  return self;
}
// Core out-variant shared by the comparison and logical ops.
// When the operand dtypes differ and exactly one operand is zero-dim, the
// zero-dim operand's value must convert to the other operand's dtype without
// overflow (check_convert). Only then is the kernel run.
template <typename Stub>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
  if (self.scalar_type() != other.scalar_type()) {
    const bool self_zero_dim = self.dim() == 0;
    const bool other_zero_dim = other.dim() == 0;
    if (other_zero_dim && !self_zero_dim) {
      check_convert(other.item(), self.scalar_type());
    } else if (self_zero_dim && !other_zero_dim) {
      check_convert(self.item(), other.scalar_type());
    }
  }
  auto it = TensorIterator::comparison_op(result, self, other, /*check_mem_overlap=*/true);
  stub(it.device_type(), it);
  return result;
}

// Functional form: the output is preallocated as an empty kBool tensor.
template <typename OutImpl>
Tensor comparison_op(const Tensor& self, const Tensor& other, OutImpl& out_impl) {
  Tensor out = at::empty({0}, self.options().dtype(kBool));
  return out_impl(out, self, other);
}

// In-place form. The result is written back into `self`, so mismatched dtypes
// are rejected up front to avoid overflow during type promotion.
template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
  TORCH_CHECK(self.dtype() == other.dtype(),
              "Expected object of scalar type ", self.dtype(), " but got scalar type ",
              other.dtype(), " for argument 'other'");
  return out_impl(self, self, other);
}

// Scalar overloads: `other` is wrapped into a tensor only after checking that
// it converts to self's dtype without overflow. This check exists only for
// comparison ops; arithmetic ops skip it. That inconsistency may be revisited,
// possibly by adding the same check to the arithmetic ops.
template <typename OutImpl>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, Scalar other, OutImpl& out_impl) {
  Tensor rhs = wrapped_scalar_tensor_and_check_convert(other, self);
  return out_impl(result, self, rhs);
}

template <typename OutImpl>
Tensor comparison_op(const Tensor& self, Scalar other, OutImpl& out_impl) {
  Tensor rhs = wrapped_scalar_tensor_and_check_convert(other, self);
  return comparison_op(self, rhs, out_impl);
}

template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, Scalar other, OutImpl& out_impl) {
  Tensor rhs = wrapped_scalar_tensor_and_check_convert(other, self);
  return out_impl(self, self, rhs);
}
// Every *_out function below exists in two overloads (Tensor and Scalar
// `other`), so just naming at::xx_out is ambiguous. Casting to this
// reference-to-function type selects the Tensor overload. (Adding const to a
// reference type has no effect, so this is the plain reference type.)
using OutFunc = Tensor& (&)(Tensor&, const Tensor&, const Tensor&);

// ---- lt -------------------------------------------------------------------
Tensor& lt_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, lt_stub);
}
Tensor lt(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::lt_out));
}
Tensor& lt_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::lt_out));
}
Tensor& lt_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::lt_out));
}
Tensor lt(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::lt_out));
}
Tensor& lt_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::lt_out));
}

// ---- le -------------------------------------------------------------------
Tensor& le_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, le_stub);
}
Tensor le(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::le_out));
}
Tensor& le_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::le_out));
}
Tensor& le_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::le_out));
}
Tensor le(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::le_out));
}
Tensor& le_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::le_out));
}

// ---- gt -------------------------------------------------------------------
Tensor& gt_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, gt_stub);
}
Tensor gt(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::gt_out));
}
Tensor& gt_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::gt_out));
}
Tensor& gt_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::gt_out));
}
Tensor gt(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::gt_out));
}
Tensor& gt_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::gt_out));
}

// ---- ge -------------------------------------------------------------------
Tensor& ge_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, ge_stub);
}
Tensor ge(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::ge_out));
}
Tensor& ge_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::ge_out));
}
Tensor& ge_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::ge_out));
}
Tensor ge(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::ge_out));
}
Tensor& ge_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::ge_out));
}

// ---- eq -------------------------------------------------------------------
Tensor& eq_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, eq_stub);
}
Tensor eq(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::eq_out));
}
Tensor& eq_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::eq_out));
}
Tensor& eq_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::eq_out));
}
Tensor eq(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::eq_out));
}
Tensor& eq_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::eq_out));
}

// ---- ne -------------------------------------------------------------------
Tensor& ne_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, ne_stub);
}
Tensor ne(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::ne_out));
}
Tensor& ne_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::ne_out));
}
Tensor& ne_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::ne_out));
}
Tensor ne(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::ne_out));
}
Tensor& ne_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::ne_out));
}

// ---- logical_and ----------------------------------------------------------
Tensor& logical_and_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, logical_and_stub);
}
Tensor logical_and(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out));
}
Tensor& logical_and_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out));
}
Tensor& logical_and_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_and_out));
}
Tensor logical_and(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out));
}
Tensor& logical_and_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out));
}

// ---- logical_or -----------------------------------------------------------
Tensor& logical_or_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, logical_or_stub);
}
Tensor logical_or(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out));
}
Tensor& logical_or_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out));
}
Tensor& logical_or_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_or_out));
}
Tensor logical_or(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out));
}
Tensor& logical_or_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out));
}

// ---- logical_xor ----------------------------------------------------------
Tensor& logical_xor_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return comparison_op_out(result, self, other, logical_xor_stub);
}
Tensor logical_xor(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out));
}
Tensor& logical_xor_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out));
}
Tensor& logical_xor_out(Tensor& result, const Tensor& self, Scalar other) {
  return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_xor_out));
}
Tensor logical_xor(const Tensor& self, Scalar other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out));
}
Tensor& logical_xor_(Tensor& self, Scalar other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out));
}
// Elementwise maximum of `self` and `other`, written into `result`.
// Complex inputs are rejected, and both inputs must have the same dtype.
// All validation happens before the TensorIterator is built. Building the
// iterator may allocate or resize the caller's `result`, so a dtype mismatch
// must not be detected only after that has happened.
Tensor& max_out(Tensor& result, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
  TORCH_CHECK(!other.is_complex(), "max is not yet implemented for complex tensors.");
  TORCH_CHECK(self.dtype() == other.dtype(),
              "Expected object of scalar type ", self.dtype(), " but got scalar type ",
              other.dtype(), " for argument 'other'");
  auto iter = TensorIterator::binary_op(result, self, other,
                                        /*check_mem_overlap=*/true);
  max_elementwise_stub(iter.device_type(), iter);
  return result;
}

// Functional form. The complex checks are repeated here so they fail before
// the output is allocated.
Tensor max(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex(), "max is not yet implemented for complex tensors.");
  TORCH_CHECK(!other.is_complex(), "max is not yet implemented for complex tensors.");
  Tensor result = at::empty({0}, self.options());
  return at::max_out(result, self, other);
}

// In-place form: `self` is both an input and the output.
Tensor& max_(Tensor& self, const Tensor& other) { return at::max_out(self, self, other); }
// Elementwise minimum of two non-complex tensors that share a dtype.
// The result is written into `result`; the TensorIterator sizes it as needed
// (min() below relies on this by passing an empty tensor).
Tensor& min_out(Tensor& result, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
  TORCH_CHECK(!other.is_complex(), "min is not yet implemented for complex tensors.");
  // Validate dtypes *before* building the iterator. Building it sizes
  // `result`, so checking afterwards would leave the caller's out= tensor
  // modified even though the call throws.
  TORCH_CHECK(self.dtype() == other.dtype(),
              "Expected object of scalar type ", self.dtype(), " but got scalar type ",
              other.dtype(), " for argument 'other'");
  auto iter = TensorIterator::binary_op(result, self, other,
                                        /*check_mem_overlap=*/true);
  min_elementwise_stub(iter.device_type(), iter);
  return result;
}
// Functional variant: allocates an empty output and lets min_out size it.
Tensor min(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex(), "min is not yet implemented for complex tensors.");
  TORCH_CHECK(!other.is_complex(), "min is not yet implemented for complex tensors.");
  Tensor result = at::empty(0, self.options());
  return at::min_out(result, self, other);
}
// In-place variant: `self` is both an input and the output.
Tensor& min_(Tensor& self, const Tensor& other) { return at::min_out(self, self, other); }
// Scalar overloads of floor_divide. The scalar is wrapped into a tensor and
// forwarded to the Tensor/Tensor entry points.
Tensor floor_divide(const Tensor& self, Scalar other) {
  Tensor other_tensor = wrapped_scalar_tensor(other);
  return at::floor_divide(self, other_tensor);
}
Tensor& floor_divide_(Tensor& self, Scalar other) {
  // In-place: `self` doubles as the out= tensor.
  Tensor other_tensor = wrapped_scalar_tensor(other);
  return at::floor_divide_out(self, self, other_tensor);
}
// Native fmod kernels. Both out= variants reject non-CPU iterators.
// The functional and in-place variants below forward to the out= variants.

// Tensor divisor: binary TensorIterator over (self, other) -> result.
Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other,
                                        /*check_mem_overlap=*/true);
  const auto device = iter.device_type();
  TORCH_CHECK(device == at::kCPU, "Native fmod only supports CPU");
  fmod_stub(device, iter);
  return result;
}
// Scalar divisor: unary TensorIterator over self -> result, scalar passed to the stub.
Tensor& fmod_out(Tensor & result, const Tensor& self, Scalar other) {
  auto iter = TensorIterator::unary_op(result, self,
                                       /*check_mem_overlap=*/true);
  const auto device = iter.device_type();
  TORCH_CHECK(device == at::kCPU, "Native fmod only supports CPU");
  fmod_scalar_stub(device, iter, other);
  return result;
}
Tensor fmod(const Tensor& self, const Tensor & other) {
  auto result = at::empty({0}, self.options());
  return at::fmod_out(result, self, other);
}
Tensor fmod(const Tensor& self, Scalar other) {
  auto result = at::empty({0}, self.options());
  return at::fmod_out(result, self, other);
}
Tensor& fmod_(Tensor& self, const Tensor& other) {
  // `self` is both input and output.
  return at::fmod_out(self, self, other);
}
Tensor& fmod_(Tensor& self, Scalar other) {
  // `self` is both input and output.
  return at::fmod_out(self, self, other);
}
// logaddexp / logaddexp2. Each out= variant builds a binary TensorIterator
// and dispatches to its stub. Each functional variant allocates an empty
// output that the out= variant sizes.
Tensor& logaddexp_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  const auto device = iter.device_type();
  logaddexp_stub(device, iter);
  return result;
}
Tensor logaddexp(const Tensor& self, const Tensor& other) {
  auto result = at::empty({0}, self.options());
  return at::logaddexp_out(result, self, other);
}
Tensor& logaddexp2_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(
      result, self, other, /*check_mem_overlap=*/true);
  const auto device = iter.device_type();
  logaddexp2_stub(device, iter);
  return result;
}
Tensor logaddexp2(const Tensor& self, const Tensor& other) {
  auto result = at::empty({0}, self.options());
  return at::logaddexp2_out(result, self, other);
}
// Scalar overloads of true_divide. The divisor is wrapped into a tensor and
// the call is redispatched through the Tensor method (not called directly).
Tensor true_divide(const Tensor& self, Scalar divisor) {
  const Tensor wrapped = wrapped_scalar_tensor(divisor);
  return self.true_divide(wrapped); // redispatch!
}
Tensor& true_divide_(Tensor& self, Scalar divisor) {
  const Tensor wrapped = wrapped_scalar_tensor(divisor);
  return self.true_divide_(wrapped); // redispatch!
}
// Note: this function is only for testing.
// It is undocumented and should not be used outside of tests.
// Computes self - (other * alpha).
Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scalar alpha) {
  Tensor scaled_other = other * alpha;
  return self - scaled_other;
}
// Shape-only implementation for binary ops on meta tensors.
// Returns an empty meta tensor (at::empty_meta) whose sizes are the
// right-aligned broadcast of `self` and `other`'s sizes.
// TODO: Deduplicate this with the TensorIterator logic. This would
// also fix the TODOs below.
Tensor binary_op_meta(const Tensor& self, const Tensor& other) {
  // TODO: Doesn't do type promotion correctly
  // TODO: Doesn't do strides correctly
  const int64_t dim = std::max(self.dim(), other.dim());
  std::vector<int64_t> sizes(dim);
  for (int64_t i = 0; i < dim; i++) {
    // j is the negative index of the i-th dimension counted from the right.
    const int64_t j = -1 - i;
    // A tensor with fewer dims behaves as if left-padded with size-1 dims.
    // Never index a dim the tensor does not have.
    const int64_t self_size = i < self.dim() ? self.size(j) : 1;
    const int64_t other_size = i < other.dim() ? other.size(j) : 1;
    if (self_size == 1) {
      sizes[dim + j] = other_size;
    } else if (other_size == 1) {
      sizes[dim + j] = self_size;
    } else {
      TORCH_CHECK(
        self_size == other_size,
        "Expected self.size(", j, ") == other.size(", j, "), but got ", self_size, " != ", other_size
      );
      sizes[dim + j] = self_size;
    }
  }
  return at::empty_meta(sizes, self.options());
}
// Meta-kernel adapter for binary ops that take an extra trailing Scalar.
// For "add.Tensor" that Scalar is presumably `alpha`; confirm against the
// op schema. The output shape comes only from binary_op_meta(self, other),
// so the scalar is intentionally unused and left unnamed.
Tensor binary_op_with_scalar_meta(const Tensor& self, const Tensor& other, Scalar /*x*/) {
  return binary_op_meta(self, other);
}
// Register kernels for the Meta dispatch key of the `aten` library.
// Only "add.Tensor" is registered here. It is routed to
// binary_op_with_scalar_meta, which computes the output shape from the two
// tensor arguments alone.
TORCH_LIBRARY_IMPL(aten, Meta, m) {
  m.impl("add.Tensor", binary_op_with_scalar_meta);
}
} // namespace native
} // namespace at