diff --git a/README.md b/README.md index 0e231e5ac..31a7e1ecb 100644 --- a/README.md +++ b/README.md @@ -298,7 +298,7 @@ torch.permute "dims": "perm" } }, - "unsupport_args": {}, + "unsupport_args": [], "paddle_default_kwargs": {} } ``` diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 837d6c7fb..bdb1356b5 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -2747,6 +2747,20 @@ "input": "x" } }, + "torch.nn.functional.soft_margin_loss": { + "Matcher": "SizeAverageMatcher", + "paddle_api": "paddle.nn.functional.soft_margin_loss", + "args_list": [ + "input", + "target", + "size_average", + "reduce", + "reduction" + ], + "kwargs_change": { + "target": "label" + } + }, "torch.vander": {}, "torch.cross": { "Matcher": "GenericMatcher", @@ -6069,33 +6083,27 @@ "paddle_api": "paddle.device.get_device" }, "torch.Tensor.gt": { - "Matcher": "GenericMatcher", + "Matcher": "Num2TensorBinaryMatcher", "paddle_api": "paddle.Tensor.greater_than", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.greater": { - "Matcher": "GenericMatcher", + "Matcher": "Num2TensorBinaryMatcher", "paddle_api": "paddle.Tensor.greater_than", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.heaviside": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.heaviside", "args_list": [ - "values" + "data" ], "kwargs_change": { - "other": "y" + "data": "y" } }, "torch.Tensor.index_select": { @@ -6182,24 +6190,18 @@ } }, "torch.Tensor.le": { - "Matcher": "GenericMatcher", + "Matcher": "Num2TensorBinaryMatcher", "paddle_api": "paddle.Tensor.less_equal", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.less_equal": { - "Matcher": "GenericMatcher", + "Matcher": "Num2TensorBinaryMatcher", "paddle_api": "paddle.Tensor.less_equal", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.lerp": { "Matcher": "GenericMatcher", @@ -6254,38 +6256,29 @@ } }, "torch.Tensor.logical_and": { - "Matcher": "GenericMatcher", + "Matcher": "TensorLogicalMatcher", "paddle_api": "paddle.Tensor.logical_and", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.logical_not": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.logical_not" }, "torch.Tensor.logical_or": { - "Matcher": "GenericMatcher", + "Matcher": "TensorLogicalMatcher", "paddle_api": "paddle.Tensor.logical_or", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.logical_xor": { - "Matcher": "GenericMatcher", + "Matcher": "TensorLogicalMatcher", "paddle_api": "paddle.Tensor.logical_xor", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.logit": { "Matcher": "GenericMatcher", @@ -6295,24 +6288,18 @@ ] }, "torch.Tensor.lt": { - "Matcher": "GenericMatcher", + "Matcher": "Num2TensorBinaryMatcher", "paddle_api": "paddle.Tensor.less_than", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.less": { - "Matcher": "GenericMatcher", + "Matcher": "Num2TensorBinaryMatcher", "paddle_api": "paddle.Tensor.less_than", "args_list": [ "other" - ], - "kwargs_change": { - "other": "y" - } + ] }, "torch.Tensor.lu": { "Matcher": "GenericMatcher", @@ -8397,7 +8384,7 @@ ] }, "torch.nn.LayerNorm": { - "Matcher": "LayerNormMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.LayerNorm", "args_list": [ "normalized_shape", @@ -8405,10 +8392,18 @@ 
"elementwise_affine", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "", + "elementwise_affine": [ + "weight_attr", + "bias_attr" + ] + } }, "torch.nn.GroupNorm": { - "Matcher": "GroupNormMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.GroupNorm", "args_list": [ "num_groups", @@ -8417,7 +8412,15 @@ "affine", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "", + "affine": [ + "weight_attr", + "bias_attr" + ] + } }, "torch.nn.BatchNorm1d": { "Matcher": "BatchNormMatcher", @@ -8430,7 +8433,11 @@ "track_running_stats", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "" + } }, "torch.nn.BatchNorm2d": { "Matcher": "BatchNormMatcher", @@ -8443,7 +8450,11 @@ "track_running_stats", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "" + } }, "torch.nn.BatchNorm3d": { "Matcher": "BatchNormMatcher", @@ -8456,7 +8467,11 @@ "track_running_stats", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "" + } }, "torch.nn.SyncBatchNorm": { "Matcher": "BatchNormMatcher", @@ -8470,7 +8485,11 @@ "process_group", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "" + } }, "torch.nn.SyncBatchNorm.convert_sync_batchnorm": { "Matcher": "GenericMatcher", @@ -8487,7 +8506,7 @@ ] }, "torch.nn.RNNCell": { - "Matcher": "RNNCellMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.SimpleRNNCell", "args_list": [ "input_size", @@ -8496,10 +8515,19 @@ "nonlinearity", "device", "dtype" - ] + ], + "kwargs_change": { + "dtype": "", + "device": "", + "nonlinearity": "activation", + "bias": [ + "bias_ih_attr", + "bias_hh_attr" + ] + } }, "torch.nn.LSTMCell": { - "Matcher": "RNNCellMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.LSTMCell", "args_list": [ "input_size", @@ -8507,10 +8535,18 @@ "bias", "device", "dtype" - ] + ], + "kwargs_change": { + "dtype": "", + "device": "", + "bias": [ + "bias_ih_attr", + "bias_hh_attr" + ] + } }, "torch.nn.GRUCell": { - "Matcher": "RNNCellMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.GRUCell", "args_list": [ "input_size", @@ -8518,7 +8554,16 @@ "bias", "device", "dtype" - ] + ], + "kwargs_change": { + "dtype": "", + "device": "", + "nonlinearity": "activation", + "bias": [ + "bias_ih_attr", + "bias_hh_attr" + ] + } }, "torch.nn.RNN": { "Matcher": "RNNMatcher", @@ -8532,6 +8577,16 @@ "batch_first", "dropout", "bidirectional" + ], + "kwargs_change": { + "nonlinearity": "activation", + "bias": [ + "bias_ih_attr", + "bias_hh_attr" + ] + }, + "unsupport_args": [ + "proj_size" ] }, "torch.nn.LSTM": { @@ -8546,6 +8601,16 @@ "dropout", "bidirectional", "proj_size" + ], + "kwargs_change": { + "nonlinearity": "activation", + "bias": [ + "bias_ih_attr", + "bias_hh_attr" + ] + }, + "unsupport_args": [ + "proj_size" ] }, "torch.nn.GRU": { @@ -8559,6 +8624,15 @@ "batch_first", "dropout", "bidirectional" + ], + "kwargs_change": { + "bias": [ + "bias_ih_attr", + "bias_hh_attr" + ] + }, + "unsupport_args": [ + "proj_size" ] }, "torch.nn.MultiheadAttention": { @@ -8858,7 +8932,15 @@ "track_running_stats", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "", + "affine": [ + "weight_attr", + "bias_attr" + ] + } }, "torch.nn.InstanceNorm2d": { "Matcher": "InstanceNormMatcher", @@ -8871,7 +8953,15 @@ "track_running_stats", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "", + "affine": [ + "weight_attr", + 
"bias_attr" + ] + } }, "torch.nn.InstanceNorm3d": { "Matcher": "InstanceNormMatcher", @@ -8884,7 +8974,15 @@ "track_running_stats", "device", "dtype" - ] + ], + "kwargs_change": { + "eps": "epsilon", + "dtype": "", + "affine": [ + "weight_attr", + "bias_attr" + ] + } }, "torch.nn.BCEWithLogitsLoss": { "Matcher": "SizeAverageMatcher", @@ -8897,6 +8995,23 @@ "pos_weight" ] }, + "torch.nn.functional.multi_margin_loss": { + "Matcher": "SizeAverageMatcher", + "paddle_api": "paddle.nn.functional.multi_margin_loss", + "args_list": [ + "input", + "target", + "p", + "margin", + "weight", + "size_average", + "reduce", + "reduction" + ], + "kwargs_change": { + "target": "label" + } + }, "torch.utils.data.BatchSampler": { "Matcher": "TorchUtilDataBatchSampler", "args_list": [ @@ -8953,6 +9068,19 @@ "eps": "epsilon" } }, + "torch.lu": { + "Matcher": "LuMatcher", + "paddle_api": "paddle.linalg.lu", + "args_list": [ + "A", + "pivot", + "get_infos", + "out" + ], + "kwargs_change": { + "A": "x" + } + }, "torch.autograd.grad": { "Matcher": "GenericMatcher", "paddle_api": "paddle.grad", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index af596a8d0..7442382cd 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -33,12 +33,15 @@ def generate_code(self, kwargs): if "kwargs_change" in self.api_mapping: kwargs_change = self.api_mapping["kwargs_change"] new_kwargs = {} - dtype_v = None for k in list(kwargs.keys()): if k in kwargs_change: if kwargs_change[k]: # rename/copy in new_kwargs - new_kwargs[kwargs_change[k]] = kwargs.pop(k) + if isinstance(kwargs_change[k], list): + for v in kwargs_change[k]: + new_kwargs[v] = kwargs[k] + else: + new_kwargs[kwargs_change[k]] = kwargs[k] else: # remove in new_kwargs kwargs.pop(k) @@ -54,14 +57,17 @@ def generate_code(self, kwargs): "inplace", "generator", "non_blocking", + "async", ]: new_kwargs.pop(k) continue - if k == "dtype": - dtype_v = new_kwargs.pop("dtype") new_kwargs = self.set_paddle_default_kwargs(new_kwargs) + dtype_v = None + if "dtype" in new_kwargs: + dtype_v = new_kwargs.pop("dtype") + pin_memory_v = False if "pin_memory" in new_kwargs: pin_memory_v = eval(new_kwargs.pop("pin_memory")) @@ -70,7 +76,9 @@ def generate_code(self, kwargs): if "requires_grad" in new_kwargs: stop_gradient_v = "not " + new_kwargs.pop("requires_grad").strip("()") - out_v = new_kwargs.pop("out") if "out" in new_kwargs else None + out_v = None + if "out" in new_kwargs: + out_v = new_kwargs.pop("out") res = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs)) @@ -892,114 +900,33 @@ def generate_code(self, kwargs): return code.strip("\n") -class LayerNormMatcher(BaseMatcher): - def generate_code(self, kwargs): - if "eps" not in kwargs: - epsilon = 1e-5 - else: - epsilon = kwargs["eps"] - - if "elementwise_affine" in kwargs and "False" in kwargs["elementwise_affine"]: - API_TEMPLATE = textwrap.dedent( - """ - paddle.nn.LayerNorm(normalized_shape={}, - epsilon={}, - weight_attr=False, - bias_attr=False) - """ - ) - else: - API_TEMPLATE = textwrap.dedent( - """ - paddle.nn.LayerNorm(normalized_shape={}, - epsilon={}, - weight_attr=None, - bias_attr=None) - """ - ) - code = API_TEMPLATE.format(kwargs["normalized_shape"], epsilon) - return code - - -class GroupNormMatcher(BaseMatcher): - def generate_code(self, kwargs): - if "eps" not in kwargs: - epsilon = 1e-5 - else: - epsilon = kwargs["eps"] - - if "affine" in kwargs and "False" in kwargs["affine"]: - API_TEMPLATE = textwrap.dedent( - """ - paddle.nn.GroupNorm(num_groups={}, - 
num_channels={}, - epsilon={}, - weight_attr=False, - bias_attr=False) - """ - ) - else: - API_TEMPLATE = textwrap.dedent( - """ - paddle.nn.GroupNorm(num_groups={}, - num_channels={}, - epsilon={}, - weight_attr=None, - bias_attr=None) - """ - ) - code = API_TEMPLATE.format( - kwargs["num_groups"], kwargs["num_channels"], epsilon - ) - return code - - class BatchNormMatcher(BaseMatcher): def generate_code(self, kwargs): - if "eps" not in kwargs: - epsilon = 1e-5 - else: - epsilon = kwargs["eps"] - + if "dtype" in kwargs: + kwargs.pop("dtype") if "track_running_stats" in kwargs: track_running_stats = kwargs["track_running_stats"] + kwargs.pop("track_running_stats") else: track_running_stats = True - + kwargs["use_global_stats"] = track_running_stats if "momentum" in kwargs: - momentum = kwargs["momentum"] + momentum = f"1 - {kwargs['momentum']}" else: momentum = 0.1 + if "affine" not in kwargs: + kwargs["weight_attr"] = None + kwargs["bias_attr"] = None + else: + kwargs[ + "weight_attr" + ] = f"None if ({kwargs['affine']} is None or {kwargs['affine']}) else False" + kwargs[ + "bias_attr" + ] = f"None if ({kwargs['affine']} is None or {kwargs['affine']}) else False" + kwargs.pop("affine") - if "affine" in kwargs and "False" in kwargs["affine"]: - API_TEMPLATE = textwrap.dedent( - """ - {}(num_features={}, - momentum=1-{}, - epsilon={}, - weight_attr=False, - bias_attr=False, - use_global_stats={}) - """ - ) - else: - API_TEMPLATE = textwrap.dedent( - """ - {}(num_features={}, - momentum=1-{}, - epsilon={}, - weight_attr=None, - bias_attr=None, - use_global_stats={}) - """ - ) - code = API_TEMPLATE.format( - self.get_paddle_api(), - kwargs["num_features"], - momentum, - epsilon, - track_running_stats, - ) + code = GenericMatcher.generate_code(self, kwargs) return code @@ -1734,41 +1661,12 @@ def generate_code(self, kwargs): class InstanceNormMatcher(BaseMatcher): def generate_code(self, kwargs): - if "eps" not in kwargs: - epsilon = 1e-5 - else: - epsilon = kwargs["eps"] - if "momentum" in kwargs: momentum = kwargs["momentum"] else: momentum = 0.1 - - if "affine" in kwargs and "True" in kwargs["affine"]: - API_TEMPLATE = textwrap.dedent( - """ - {}(num_features={}, - momentum=1-{}, - epsilon={}, - weight_attr=None, - bias_attr=None) - """ - ) - else: - API_TEMPLATE = textwrap.dedent( - """ - {}(num_features={}, - momentum=1-{}, - epsilon={}, - weight_attr=False, - bias_attr=False) - """ - ) - - code = API_TEMPLATE.format( - self.get_paddle_api(), kwargs["num_features"], momentum, epsilon - ) - return code + kwargs["momentum"] = f"1-{momentum}" + return GenericMatcher.generate_code(self, kwargs) class GeneratorMatcher(BaseMatcher): @@ -3259,77 +3157,21 @@ def generate_code(self, kwargs): return code -class RNNCellMatcher(BaseMatcher): - def generate_code(self, kwargs): - if "dtype" in kwargs: - return None - - if "nonlinearity" in kwargs: - kwargs["activation"] = kwargs.pop("nonlinearity") - - if "device" in kwargs: - kwargs.pop("device") - - if "bias" in kwargs and "False" in kwargs["bias"]: - API_TEMPLATE = textwrap.dedent( - """ - {}({}, - bias_ih_attr=False, - bias_hh_attr=False) - """ - ) - else: - API_TEMPLATE = textwrap.dedent( - """ - {}({}) - """ - ) - if "bias" in kwargs: - kwargs.pop("bias") - code = API_TEMPLATE.format(self.get_paddle_api(), self.kwargs_to_str(kwargs)) - return code - - class RNNMatcher(BaseMatcher): def generate_code(self, kwargs): - if "proj_size" in kwargs: - return None - if "batch_first" in kwargs: batch_first = kwargs.pop("batch_first") else: batch_first = 
False + kwargs["time_major"] = f"not {batch_first}" - if "nonlinearity" in kwargs: - kwargs["activation"] = kwargs.pop("nonlinearity") - - direction = "'forward'" + kwargs["direction"] = "'forward'" if "bidirectional" in kwargs: if "True" in kwargs["bidirectional"]: direction = "'bidirect'" kwargs.pop("bidirectional") - if "bias" in kwargs and "False" in kwargs["bias"]: - API_TEMPLATE = textwrap.dedent( - """ - {}({}, direction={}, time_major= not {}, - bias_ih_attr=False, - bias_hh_attr=False) - """ - ) - else: - API_TEMPLATE = textwrap.dedent( - """ - {}({}, direction={}, time_major= not {}) - """ - ) - - if "bias" in kwargs: - kwargs.pop("bias") - code = API_TEMPLATE.format( - self.get_paddle_api(), self.kwargs_to_str(kwargs), direction, batch_first - ) - return code + return GenericMatcher.generate_code(self, kwargs) class DiffMatcher(BaseMatcher): @@ -3369,52 +3211,30 @@ def get_paddle_nodes(self, args, kwargs): class Modules_BatchNormBaseMatcher(BaseMatcher): def generate_code(self, kwargs): - if "eps" not in kwargs: - epsilon = 1e-5 - else: - epsilon = kwargs["eps"] - if "track_running_stats" in kwargs: track_running_stats = kwargs["track_running_stats"] else: track_running_stats = True + kwargs["use_global_stats"] = track_running_stats if "momentum" in kwargs: momentum = kwargs["momentum"] else: momentum = 0.1 + kwargs["momentum"] = f"1-{momentum}" - if "affine" in kwargs and "False" in kwargs["affine"]: - API_TEMPLATE = textwrap.dedent( - """ - {}(num_features={}, - momentum=1-{}, - epsilon={}, - weight_attr=False, - bias_attr=False, - use_global_stats={}) - """ - ) + if "affine" not in kwargs: + kwargs["weight_attr"] = None + kwargs["bias_attr"] = None else: - API_TEMPLATE = textwrap.dedent( - """ - {}(num_features={}, - momentum=1-{}, - epsilon={}, - weight_attr=None, - bias_attr=None, - use_global_stats={}) - """ - ) - code = API_TEMPLATE.format( - self.get_paddle_api(), - kwargs["num_features"], - momentum, - epsilon, - track_running_stats, - ) + kwargs[ + "weight_attr" + ] = f"None if ({kwargs['affine']} is None or {kwargs['affine']}) else False" + kwargs[ + "bias_attr" + ] = f"None if ({kwargs['affine']} is None or {kwargs['affine']}) else False" - return code + return GenericMatcher.generate_code(self, kwargs) class TensorTakeMatcher(BaseMatcher): @@ -3707,6 +3527,47 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) +class LuMatcher(BaseMatcher): + def generate_code(self, kwargs): + out_v = kwargs.pop("out") if "out" in kwargs else None + + if out_v: + + out_3_var = "get_infos" in kwargs and kwargs["get_infos"] != "(False)" + new_kwargs = {} + new_kwargs["x"] = kwargs.pop("A") + new_kwargs.update(kwargs) + + if out_3_var: + API_TEMPLATE = textwrap.dedent( + """ + tmp_lu, tmp_p, tmp_info = {}({}) + paddle.assign(tmp_lu, {}[0]), paddle.assign(tmp_p, {}[1]), paddle.assign(tmp_info, {}[2]) + """ + ) + code = API_TEMPLATE.format( + self.get_paddle_api(), + self.kwargs_to_str(new_kwargs), + out_v, + out_v, + out_v, + ) + else: + API_TEMPLATE = textwrap.dedent( + """ + tmp_lu, tmp_p = {}({}) + paddle.assign(tmp_lu, {}[0]), paddle.assign(tmp_p, {}[1]) + """ + ) + code = API_TEMPLATE.format( + self.get_paddle_api(), self.kwargs_to_str(new_kwargs), out_v, out_v + ) + + return code + + return GenericMatcher.generate_code(self, kwargs) + + class RandomSplitMatcher(BaseMatcher): def generate_code(self, kwargs): API_TEMPLATE = textwrap.dedent( @@ -3761,6 +3622,16 @@ def generate_code(self, kwargs): return code +class TensorLogicalMatcher(BaseMatcher): 
+ def generate_code(self, kwargs): + + code = "{}(y=({}).astype(({}).dtype))".format( + self.get_paddle_api(), kwargs["other"], self.paddleClass + ) + + return code + + class TensorDatasetMatcher(BaseMatcher): def get_paddle_nodes(self, args, kwargs): new_args = self.parse_args(args) diff --git a/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm1D.py b/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm1D.py index bc6f4e8bd..663cb7ab8 100644 --- a/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm1D.py +++ b/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm1D.py @@ -1,10 +1,11 @@ import paddle print('#########################case1#########################') -paddle.nn.BatchNorm1D(num_features=3, momentum=1 - 0.01, epsilon=0.001, - weight_attr=None, bias_attr=None, use_global_stats=True) +paddle.nn.BatchNorm1D(num_features=3, epsilon=0.001, momentum=0.01, + use_global_stats=True, weight_attr=None, bias_attr=None) print('#########################case2#########################') -bn = paddle.nn.BatchNorm1D(num_features=27, momentum=1 - 0.1, epsilon=1e-05, - weight_attr=None, bias_attr=None, use_global_stats=True) +bn = paddle.nn.BatchNorm1D(num_features=27, use_global_stats=True, + weight_attr=None, bias_attr=None) print('#########################case3#########################') -paddle.nn.BatchNorm1D(num_features=10, momentum=1 - 0.1, epsilon=1e-05, - weight_attr=False, bias_attr=False, use_global_stats=True) +paddle.nn.BatchNorm1D(num_features=10, epsilon=1e-05, use_global_stats=True, + weight_attr=None if False is None or False else False, bias_attr=None if + False is None or False else False) diff --git a/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm2D.py b/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm2D.py index acacec9f1..3fc24110e 100644 --- a/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm2D.py +++ b/tests/code_library/code_case/paddle_code/api_paddle_nn_BatchNorm2D.py @@ -1,10 +1,11 @@ import paddle print('#########################case1#########################') -bn = paddle.nn.BatchNorm2D(num_features=5, momentum=1 - 0.1, epsilon=1e-05, - weight_attr=None, bias_attr=None, use_global_stats=True) +bn = paddle.nn.BatchNorm2D(num_features=5, use_global_stats=True, + weight_attr=None, bias_attr=None) print('#########################case2#########################') -bn = paddle.nn.BatchNorm2D(num_features=27, momentum=1 - 0.1, epsilon=1e-05, - weight_attr=None, bias_attr=None, use_global_stats=True) +bn = paddle.nn.BatchNorm2D(num_features=27, use_global_stats=True, + weight_attr=None, bias_attr=None) print('#########################case3#########################') -paddle.nn.BatchNorm2D(num_features=10, momentum=1 - 0.1, epsilon=1e-05, - weight_attr=False, bias_attr=False, use_global_stats=True) +paddle.nn.BatchNorm2D(num_features=10, epsilon=1e-05, use_global_stats=True, + weight_attr=None if False is None or False else False, bias_attr=None if + False is None or False else False) diff --git a/tests/code_library/code_case/paddle_code/api_paddle_nn_InstanceNorm3D.py b/tests/code_library/code_case/paddle_code/api_paddle_nn_InstanceNorm3D.py index 9a8b75e02..1a09fcdc1 100644 --- a/tests/code_library/code_case/paddle_code/api_paddle_nn_InstanceNorm3D.py +++ b/tests/code_library/code_case/paddle_code/api_paddle_nn_InstanceNorm3D.py @@ -1,26 +1,25 @@ import paddle print('#########################case1#########################') -m = paddle.nn.InstanceNorm3D(num_features=100, 
momentum=1 - 0.1, epsilon= - 1e-05, weight_attr=False, bias_attr=False) +m = paddle.nn.InstanceNorm3D(num_features=100, momentum=1 - 0.1) input = paddle.randn(shape=[20, 100, 35, 45, 10]) output = m(input) print('#########################case2#########################') -m = paddle.nn.InstanceNorm3D(num_features=100, momentum=1 - 0.1, epsilon= - 1e-05, weight_attr=None, bias_attr=None) +m = paddle.nn.InstanceNorm3D(num_features=100, weight_attr=True, bias_attr= + True, momentum=1 - 0.1) input = paddle.randn(shape=[20, 100, 35, 45, 10]) output = m(input) print('#########################case3#########################') -m = paddle.nn.InstanceNorm3D(num_features=100, momentum=1 - 0.1, epsilon= - 1e-05, weight_attr=False, bias_attr=False) +m = paddle.nn.InstanceNorm3D(num_features=100, weight_attr=False, bias_attr + =False, momentum=1 - 0.1) input = paddle.randn(shape=[20, 100, 35, 45, 10]) output = m(input) print('#########################case4#########################') -m = paddle.nn.InstanceNorm3D(num_features=100, momentum=1 - 0.1, epsilon= - 1e-05, weight_attr=None, bias_attr=None) +m = paddle.nn.InstanceNorm3D(num_features=100, weight_attr=True, bias_attr= + True, momentum=1 - 0.1) input = paddle.randn(shape=[20, 100, 35, 45, 10]) output = m(input) print('#########################case5#########################') -m = paddle.nn.InstanceNorm3D(num_features=100, momentum=1 - 0.1, epsilon= - 1e-05, weight_attr=False, bias_attr=False) +m = paddle.nn.InstanceNorm3D(num_features=100, weight_attr=False, bias_attr + =False, momentum=1 - 0.1) input = paddle.randn(shape=[20, 100, 35, 45, 10]) output = m(input) diff --git a/tests/test_Tensor_ger.py b/tests/test_Tensor_ger.py new file mode 100644 index 000000000..d6aeb1d72 --- /dev/null +++ b/tests/test_Tensor_ger.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
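The expected-output files above illustrate what the rewritten norm mappings now emit: `eps` becomes `epsilon`, `affine` fans out to `weight_attr`/`bias_attr`, and `track_running_stats` feeds `use_global_stats`. For reference, a minimal before/after sketch of one case; the PyTorch line is an assumed input (not taken from this diff), while the Paddle call mirrors case 3 of `api_paddle_nn_BatchNorm2D.py`:

```python
# Hypothetical source line (assumed input for illustration):
#   torch.nn.BatchNorm2d(10, eps=1e-05, affine=False)
# Under the updated mapping it is expected to convert to the call below.
import paddle

bn = paddle.nn.BatchNorm2D(
    num_features=10,
    epsilon=1e-05,
    use_global_stats=True,  # from track_running_stats, which defaults to True
    weight_attr=None if False is None or False else False,  # affine=False -> no scale
    bias_attr=None if False is None or False else False,    # affine=False -> no shift
)
```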
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.ger") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2, 3]) + y = torch.tensor([1., 2, 3, 4]) + result = x.ger(y) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2, 3]) + y = torch.tensor([1., 2, 3, 4]) + result = x.ger(vec2=y) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1., 2, 3]).ger(torch.tensor([1., 2, 3, 4])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# The paddle input does not support integer type +def _test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 2, 3]) + y = torch.tensor([1, 2, 3, 4]) + result = x.ger(y) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# The paddle other does not support integer type +def _test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2., 3.]) + y = torch.tensor([1, 2, 3, 4]) + result = x.ger(y) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_greater.py b/tests/test_Tensor_greater.py new file mode 100644 index 000000000..faf12f57e --- /dev/null +++ b/tests/test_Tensor_greater.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.greater") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).greater(torch.tensor([[1, 1], [4, 4]])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 1], [4, 4]]) + result = input.greater(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 2], [3, 4]]) + result = input.greater(other=other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([1, 2]) + result = input.greater(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).greater(2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_gt.py b/tests/test_Tensor_gt.py new file mode 100644 index 000000000..a5d7ff791 --- /dev/null +++ b/tests/test_Tensor_gt.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.gt") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).gt(torch.tensor([[1, 1], [4, 4]])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 1], [4, 4]]) + result = input.gt(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 2], [3, 4]]) + result = input.gt(other=other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([1, 2]) + result = input.gt(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).gt(2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_heaviside.py b/tests/test_Tensor_heaviside.py new file mode 100644 index 000000000..f98cb08ab --- /dev/null +++ b/tests/test_Tensor_heaviside.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.heaviside") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-1.5, 0, 2.0]) + values = torch.tensor([0.5]) + result = input.heaviside(values) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([-1.5, 0, 2.0]).heaviside(torch.tensor([0.5])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-1.5, 0, 2.0]) + result = input.heaviside(torch.tensor([0.5])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-1.5, 0, 2.0]) + result = input.heaviside(torch.tensor([0.5, 1.7, 0.8])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-1.5, 0, 2.0]) + result = input.heaviside(torch.tensor(data=[0.5, 1.7, 0.8])) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_index_select.py b/tests/test_Tensor_index_select.py new file mode 100644 index 000000000..1495b3dae --- /dev/null +++ b/tests/test_Tensor_index_select.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.index_select") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.eye(2, 4) + indices = torch.tensor([0, 1]) + result = x.index_select(0, indices) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + indices = torch.tensor([0, 1]) + result = torch.eye(3, 4).index_select(1, indices) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + indices = torch.tensor([0, 1]) + dim = 0 + result = torch.eye(3, 4).index_select(dim, indices) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + indices = torch.tensor([0, 3]) + dim = 0 + result = torch.eye(5, 4).index_select(dim=dim, index=indices) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_inner.py b/tests/test_Tensor_inner.py new file mode 100644 index 000000000..4b8a951dc --- /dev/null +++ b/tests/test_Tensor_inner.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.inner") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1., 2, 3]).inner(torch.tensor([0., 2, 1])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.8173, 1.0874, 1.1784], [0.3279, 0.1234, 2.7894]]) + y = torch.tensor([[[-0.4682, -0.7159, 0.1506], + [ 0.4034, -0.3657, 1.0387], + [ 0.9892, -0.6684, 0.1774], + [ 0.9482, 1.3261, 0.3917]], + [[ 0.4537, 0.7493, 1.1724], + [ 0.2291, 0.5749, -0.2267], + [-0.7920, 0.3607, -0.3701], + [ 1.3666, -0.5850, -1.7242]]]) + result = x.inner(y) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0.8173, 1.0874, 1.1784], [0.3279, 0.1234, 2.7894]]) + result = x.inner(torch.tensor(2.)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1., 2, 3]).inner(other=torch.tensor([0., 2, 1])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# The paddle input does not support integer type +def _test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1, 2, 3]).inner(torch.tensor([0, 2, 1])) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_inverse.py b/tests/test_Tensor_inverse.py new file mode 100644 index 000000000..97c18b89f --- /dev/null +++ b/tests/test_Tensor_inverse.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.inverse") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516], + [-0.1383, 1.5706, 0.4724, 0.4141], + [ 0.1193, 0.2829, 0.9037, 0.3957], + [-0.8202, -0.6474, -0.1631, -0.6543]]) + result = x.inverse() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[[[-0.1533, 2.3020, -0.1771, 0.5928], + [ 0.4338, -0.6537, 0.2296, 0.5946], + [-0.4932, 1.8386, -0.1039, 1.0440], + [ 0.1735, -0.8303, -0.3821, -0.4384]], + [[-0.1533, 2.3020, -0.1771, 0.5928], + [ 0.4338, -0.6537, 0.2296, 0.5946], + [-0.4932, 1.8386, -0.1039, 1.0440], + [ 0.1735, -0.8303, -0.3821, -0.4384]]]]) + result = x.inverse() + """ + ) + obj.run(pytorch_code, ["result"]) + + +# The paddle input does not support complex type +def _test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[-3.0832+3.0494j, -0.1751+0.1449j, 1.2197+1.8188j, 4.0353-0.7416j], + [-2.9842-2.8928j, 0.2123+0.6190j, -2.6104+0.7303j, 1.9740+3.3802j], + [ 0.4939-2.4271j, 0.5006-0.6895j, -1.3655-0.2352j, -1.6636+1.6514j], + [-4.1212+0.1513j, 0.7119-0.0603j, -1.7803+2.8278j, 3.4966+1.2988j]]) + result = x.inverse() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_is_complex.py b/tests/test_Tensor_is_complex.py new file mode 100644 index 000000000..06a1733b6 --- /dev/null +++ b/tests/test_Tensor_is_complex.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.is_complex") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[4, 9], [23, 2]]) + result = a.is_complex() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[4, 9], [23, 2]], dtype=torch.complex64).is_complex() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[4, 9], [23, 2]], dtype=torch.complex128) + result = a.is_complex() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_is_floating_point.py b/tests/test_Tensor_is_floating_point.py new file mode 100644 index 000000000..eb894164d --- /dev/null +++ b/tests/test_Tensor_is_floating_point.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.is_floating_point") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[4, 9], [23, 2]], dtype=torch.int64) + result = a.is_floating_point() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[4, 9], [23, 2]], dtype=torch.float64) + result = a.is_floating_point() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[4, 9], [23, 2]], dtype=torch.float32).is_floating_point() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[4, 9], [23, 2]], dtype=torch.float16).is_floating_point() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[4, 9], [23, 2]], dtype=torch.bfloat16).is_floating_point() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_isclose.py b/tests/test_Tensor_isclose.py new file mode 100644 index 000000000..ca3939d50 --- /dev/null +++ b/tests/test_Tensor_isclose.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.isclose") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([10000., 1e-07]).isclose(torch.tensor([10000.1, 1e-08])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([10000., 1e-08]).isclose(torch.tensor([10000.1, 1e-09])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1.0, float('nan')]).isclose(torch.tensor([1.0, float('nan')])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1.0, float('inf')]).isclose(torch.tensor([1.0, float('inf')]), equal_nan=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([10000., 1e-07]).isclose(torch.tensor([10000.1, 1e-08]), atol=2.) 
+ """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_isfinite.py b/tests/test_Tensor_isfinite.py new file mode 100644 index 000000000..7c4b24bf4 --- /dev/null +++ b/tests/test_Tensor_isfinite.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.isfinite") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]) + result = input.isfinite() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, 6.9, 2]) + result = input.isfinite() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_isinf.py b/tests/test_Tensor_isinf.py new file mode 100644 index 000000000..9ab5fe9a8 --- /dev/null +++ b/tests/test_Tensor_isinf.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.isinf") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]) + result = input.isinf() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, 6.9, 2]) + result = input.isinf() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_isnan.py b/tests/test_Tensor_isnan.py new file mode 100644 index 000000000..b95f9a771 --- /dev/null +++ b/tests/test_Tensor_isnan.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.isnan") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]) + result = input.isnan() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, 6.9, 2]) + result = input.isnan() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_kthvalue.py b/tests/test_Tensor_kthvalue.py new file mode 100644 index 000000000..338fde9a5 --- /dev/null +++ b/tests/test_Tensor_kthvalue.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.kthvalue") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1., 2., 3., 4., 5.]) + result = x.kthvalue(4) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1., 2., 3.], [ 4., 5., 6.]]) + result = x.kthvalue(2, 0, True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1., 2., 3.], [ 4., 5., 6.]]) + result = x.kthvalue(k=2, dim=0, keepdim=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1., 2., 3.], [ 4., 5., 6.]]) + result = x.kthvalue(k=2, dim=0, keepdim=True) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_lcm.py b/tests/test_Tensor_lcm.py new file mode 100644 index 000000000..b205198ae --- /dev/null +++ b/tests/test_Tensor_lcm.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.lcm") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([5, 10, 15]) + b = torch.tensor([3, 4, 5]) + result = a.lcm(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([5, 10, 15]) + b = torch.tensor([3, 4, 5]) + result = a.lcm(other=b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([5, 10, 15]).lcm(torch.tensor([3, 4, 5])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([5, 10, 15]).lcm(other=torch.tensor([3, 4, 5])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([5, 10, 15]) + b = torch.tensor([3]) + result = a.lcm(other=b) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_le.py b/tests/test_Tensor_le.py new file mode 100644 index 000000000..d83e0d95a --- /dev/null +++ b/tests/test_Tensor_le.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.le") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).le(torch.tensor([[1, 1], [4, 4]])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 1], [4, 4]]) + result = input.le(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 2], [3, 4]]) + result = input.le(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([1, 2]) + result = input.le(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).le(2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_lerp.py b/tests/test_Tensor_lerp.py new file mode 100644 index 000000000..dc210ffc0 --- /dev/null +++ b/tests/test_Tensor_lerp.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.lerp") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + start = torch.tensor([1., 2., 3., 4.]) + end = torch.tensor([10., 10., 10., 10.]) + weight = torch.tensor([0.5, 1, 0.3, 0.6]) + result = start.lerp(end, weight) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + weight = torch.tensor([0.5, 1, 0.3, 0.6]) + result = torch.tensor([1., 2., 3., 4.]).lerp(torch.tensor([10., 10., 10., 10.]), weight) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + start = torch.tensor([1., 2., 3., 4.]) + end = torch.tensor([10., 10., 10., 10.]) + weight = torch.tensor([0.5, 1, 0.3, 0.6]) + result = start.lerp(end=end, weight=weight) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + start = torch.tensor([1., 2., 3., 4.]) + end = torch.tensor([10., 10., 10., 10.]) + result = start.lerp(end=end, weight=0.5) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_lerp_.py b/tests/test_Tensor_lerp_.py new file mode 100644 index 000000000..aa3c2fa11 --- /dev/null +++ b/tests/test_Tensor_lerp_.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.lerp_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + start = torch.tensor([1., 2., 3., 4.]) + end = torch.tensor([10., 10., 10., 10.]) + weight = torch.tensor([0.5, 1, 0.3, 0.6]) + result = start.lerp_(end, weight) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + weight = torch.tensor([0.5, 1, 0.3, 0.6]) + result = torch.tensor([1., 2., 3., 4.]).lerp_(torch.tensor([10., 10., 10., 10.]), weight) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + start = torch.tensor([1., 2., 3., 4.]) + end = torch.tensor([10., 10., 10., 10.]) + weight = torch.tensor([0.5, 1, 0.3, 0.6]) + result = start.lerp_(end=end, weight=weight) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + start = torch.tensor([1., 2., 3., 4.]) + end = torch.tensor([10., 10., 10., 10.]) + result = start.lerp_(end=end, weight=0.5) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_less.py b/tests/test_Tensor_less.py new file mode 100644 index 000000000..ca869de7a --- /dev/null +++ b/tests/test_Tensor_less.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.less") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).less(torch.tensor([[1, 1], [4, 4]])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 1], [4, 4]]) + result = input.less(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 2], [3, 4]]) + result = input.less(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([1, 2]) + result = input.less(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).less(2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_less_equal.py b/tests/test_Tensor_less_equal.py new file mode 100644 index 000000000..473cc4aae --- /dev/null +++ b/tests/test_Tensor_less_equal.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.less_equal") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).less_equal(torch.tensor([[1, 1], [4, 4]])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 1], [4, 4]]) + result = input.less_equal(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 2], [3, 4]]) + result = input.less_equal(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([1, 2]) + result = input.less_equal(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).less_equal(2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_lgamma.py b/tests/test_Tensor_lgamma.py new file mode 100644 index 000000000..6a8206c93 --- /dev/null +++ b/tests/test_Tensor_lgamma.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.lgamma") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.34, 1.5, 0.73]) + result = input.lgamma() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([0.34, 1.5, 0.73]).lgamma() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.34, 1.5, 0.73]) + result = input.lgamma() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_log.py b/tests/test_Tensor_log.py new file mode 100644 index 000000000..73d4e9ce3 --- /dev/null +++ b/tests/test_Tensor_log.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.log") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + result = input.log() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]).log() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([4, 10, 7, 9]).log() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([4, 10, 7, 9]) + result = x.log() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_log10.py b/tests/test_Tensor_log10.py new file mode 100644 index 000000000..5bcf7e7bf --- /dev/null +++ b/tests/test_Tensor_log10.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.log10") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + result = input.log10() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]).log10() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([4, 10, 7, 9]).log10() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_log2.py b/tests/test_Tensor_log2.py new file mode 100644 index 000000000..400707ed7 --- /dev/null +++ b/tests/test_Tensor_log2.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.log2") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]) + result = input.log2() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([4.7767, 4.3234, 1.2156, 0.2411, 4.5739]).log2() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([4, 10, 7, 9]).log2() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([4, 10, 7, 9]) + result = x.log2() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_logical_and.py b/tests/test_Tensor_logical_and.py new file mode 100644 index 000000000..d4d65018f --- /dev/null +++ b/tests/test_Tensor_logical_and.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.logical_and") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([True, False, True]).logical_and(torch.tensor([True, False, False])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + result = a.logical_and(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.float32) + result = a.logical_and(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.float32) + result = torch.tensor([0, 1, 10., 0.]).logical_and(other=torch.tensor([4, 0, 10., 0.])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + result = a.logical_and(b) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_logical_not.py b/tests/test_Tensor_logical_not.py new file mode 100644 index 000000000..ce4d11b05 --- /dev/null +++ b/tests/test_Tensor_logical_not.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.logical_not") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([True, False, True]).logical_not() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + result = a.logical_not() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + result = a.logical_not() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + result = torch.tensor([0, 1, 10., 0.]).logical_not() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_logical_or.py b/tests/test_Tensor_logical_or.py new file mode 100644 index 000000000..deeebc94f --- /dev/null +++ b/tests/test_Tensor_logical_or.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.logical_or") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([True, False, True]).logical_or(torch.tensor([True, False, False])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + result = a.logical_or(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.float32) + result = a.logical_or(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.float32) + result = torch.tensor([0, 1, 10., 0.]).logical_or(other=torch.tensor([4, 0, 10., 0.])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + result = a.logical_or(b) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_logical_xor.py b/tests/test_Tensor_logical_xor.py new file mode 100644 index 000000000..01dddbbec --- /dev/null +++ b/tests/test_Tensor_logical_xor.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.logical_xor") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([True, False, True]).logical_xor(torch.tensor([True, False, False])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + result = a.logical_xor(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.float32) + result = a.logical_xor(b) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.float32) + result = torch.tensor([0, 1, 10., 0.]).logical_xor(other=torch.tensor([4, 0, 10., 0.])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([0, 1, 10, 0], dtype=torch.float32) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + result = a.logical_xor(b) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_logit.py b/tests/test_Tensor_logit.py new file mode 100644 index 000000000..d717bc10e --- /dev/null +++ b/tests/test_Tensor_logit.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.logit") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + result = input.logit(eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) + eps = 1e-6 + result = input.logit(eps) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]).logit(eps=1e-6) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_logsumexp.py b/tests/test_Tensor_logsumexp.py new file mode 100644 index 000000000..529dba85a --- /dev/null +++ b/tests/test_Tensor_logsumexp.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.logsumexp") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1.4907, 1.0593, 1.5696]) + result = input.logsumexp(0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]]) + result = input.logsumexp(1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]]) + result = input.logsumexp(1, keepdim=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]]) + result = input.logsumexp(dim=1, keepdim=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# paddle does not support integer input for logsumexp +def _test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([1, 4, 6]) + result = input.logsumexp(0) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_lt.py b/tests/test_Tensor_lt.py new file mode 100644 index 000000000..7454112e9 --- /dev/null +++ b/tests/test_Tensor_lt.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.lt") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).lt(torch.tensor([[1, 1], [4, 4]])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 1], [4, 4]]) + result = input.lt(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([[1, 2], [3, 4]]) + result = input.lt(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[1, 2], [3, 4]]) + other = torch.tensor([1, 2]) + result = input.lt(other) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([[1, 2], [3, 4]]).lt(2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_batchnorm1d.py b/tests/test_batchnorm1d.py new file mode 100644 index 000000000..f22e11834 --- /dev/null +++ b/tests/test_batchnorm1d.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.BatchNorm1d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm1d(5, affine=False) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm1d(5, affine=False, eps=1e-5) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm1d(5, 1e-5, 0.2, affine=False) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm1d(5, 1e-5, 0.2, affine=True) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm1d(5, 1e-5, 0.2, affine=True, track_running_stats=True) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = False + m = torch.nn.BatchNorm1d(5, affine=a) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = True + m = torch.nn.BatchNorm1d(5, 1e-5, 0.2, affine=a, track_running_stats=True) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = True + m = torch.nn.BatchNorm1d(5, 1e-5, 0.2, affine=a, track_running_stats=True, dtype=torch.float32) + input = torch.zeros(2, 5) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_batchnorm2d.py b/tests/test_batchnorm2d.py new file mode 100644 index 000000000..71d8fd48b --- /dev/null +++ b/tests/test_batchnorm2d.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.BatchNorm2d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm2d(5, affine=False) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm2d(5, affine=False, eps=1e-5) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm2d(5, 1e-5, 0.2, affine=False) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm2d(5, 1e-5, 0.2, affine=True) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm2d(5, 1e-5, 0.2, affine=True, track_running_stats=True) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm2d(5) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = False + m = torch.nn.BatchNorm2d(5, affine=a) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = True + m = torch.nn.BatchNorm2d(5, 1e-5, 0.2, affine=a, track_running_stats=True) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = True + m = torch.nn.BatchNorm2d(5, 1e-5, 0.2, affine=a, track_running_stats=True, dtype=torch.float32) + input = torch.zeros(2, 5, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_batchnorm3d.py b/tests/test_batchnorm3d.py new file mode 100644 index 000000000..b44e7f869 --- /dev/null +++ b/tests/test_batchnorm3d.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.BatchNorm3d") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm3d(5, affine=False) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm3d(5, affine=False, eps=1e-5) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm3d(5, 1e-5, 0.2, affine=False) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm3d(5, 1e-5, 0.2, affine=True) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm3d(5, 1e-5, 0.2, affine=True, track_running_stats=True) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = torch.nn.BatchNorm3d(5) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = False + m = torch.nn.BatchNorm3d(5, affine=a) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = True + m = torch.nn.BatchNorm3d(5, 1e-5, 0.2, affine=a, track_running_stats=True) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = True + m = torch.nn.BatchNorm3d(5, 1e-5, 0.2, affine=a, track_running_stats=True, dtype=torch.float32) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + a = False + m = torch.nn.BatchNorm3d(5, affine=a, dtype=torch.float32) + input = torch.zeros(2, 5, 6, 4, 4) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_is_complex.py b/tests/test_is_complex.py index 90db19379..b2e390495 100644 --- a/tests/test_is_complex.py +++ b/tests/test_is_complex.py @@ -16,7 +16,7 @@ from apibase import APIBase -obj = APIBase("torch.is_tensor") +obj = APIBase("torch.is_complex") def test_case_1(): diff --git a/tests/test_lu.py b/tests/test_lu.py new file mode 100644 index 000000000..f0ccbcb2d --- /dev/null +++ b/tests/test_lu.py @@ -0,0 +1,176 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.lu") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [ + [0.3591, -0.0479, -0.2174], + [-0.6957, -1.4667, 1.4384], + [0.0735, 0.1147, 0.0513], + ], + [ + [-1.2565, -2.1263, 0.8075], + [-0.3665, -3.3540, -0.9417], + [-0.1299, -0.0689, -0.6207], + ], + ] + ) + A_LU, pivots = torch.lu(A) + """ + ) + obj.run(pytorch_code, ["A_LU", "pivots"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [ + [0.3591, -0.0479, -0.2174], + [-0.6957, -1.4667, 1.4384], + [0.0735, 0.1147, 0.0513], + ], + [ + [-1.2565, -2.1263, 0.8075], + [-0.3665, -3.3540, -0.9417], + [-0.1299, -0.0689, -0.6207], + ], + ] + ) + A_LU, pivots, info = torch.lu(A, get_infos=True) + """ + ) + obj.run(pytorch_code, ["A_LU", "pivots", "info"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [ + [0.3591, -0.0479, -0.2174], + [-0.6957, -1.4667, 1.4384], + [0.0735, 0.1147, 0.0513], + ], + [ + [-1.2565, -2.1263, 0.8075], + [-0.3665, -3.3540, -0.9417], + [-0.1299, -0.0689, -0.6207], + ], + ] + ) + A_LU, pivots, info = torch.lu(A, pivot=True, get_infos=True) + """ + ) + obj.run(pytorch_code, ["A_LU", "pivots", "info"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [ + [0.3591, -0.0479, -0.2174], + [-0.6957, -1.4667, 1.4384], + [0.0735, 0.1147, 0.0513], + ], + [ + [-1.2565, -2.1263, 0.8075], + [-0.3665, -3.3540, -0.9417], + [-0.1299, -0.0689, -0.6207], + ], + ] + ) + A_LU = torch.empty_like(A) + pivots = torch.empty((2, 3), dtype=torch.int32) + info = torch.empty((2, ), dtype=torch.int32) + torch.lu(A, pivot=True, get_infos=True, out=(A_LU, pivots, info)) + """ + ) + obj.run(pytorch_code, ["A_LU", "pivots", "info"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [ + [0.3591, -0.0479, -0.2174], + [-0.6957, -1.4667, 1.4384], + [0.0735, 0.1147, 0.0513], + ], + [ + [-1.2565, -2.1263, 0.8075], + [-0.3665, -3.3540, -0.9417], + [-0.1299, -0.0689, -0.6207], + ], + ] + ) + A_LU = torch.empty_like(A) + pivots = torch.empty((2, 3), dtype=torch.int32) + torch.lu(A, pivot=True, get_infos=False, out=(A_LU, pivots)) + """ + ) + obj.run(pytorch_code, ["A_LU", "pivots"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + + A = torch.tensor( + [ + [ + [0.3591, -0.0479, -0.2174], + [-0.6957, -1.4667, 1.4384], + [0.0735, 0.1147, 0.0513], + ], + [ + [-1.2565, -2.1263, 0.8075], + [-0.3665, -3.3540, -0.9417], + [-0.1299, -0.0689, -0.6207], + ], + ] + ) + A_LU = torch.empty_like(A) + pivots = torch.empty((2, 3), dtype=torch.int32) + LU = torch.lu(A, pivot=True, get_infos=False, out=(A_LU, pivots)) + """ + ) + obj.run(pytorch_code, ["A_LU", "pivots", "LU"]) diff --git a/tests/test_nn_GRUCell.py b/tests/test_nn_GRUCell.py index 0e077bf70..7833921fe 100644 --- a/tests/test_nn_GRUCell.py +++ b/tests/test_nn_GRUCell.py @@ -64,7 +64,7 @@ def test_case_3(): h0 = h0[1] """ 
) - obj.run(pytorch_code, ["h0"], unsupport=True, reason="unsupported dtype parameter") + obj.run(pytorch_code, ["h0"]) def test_case_4(): @@ -80,7 +80,7 @@ def test_case_4(): h0 = h0[1] """ ) - obj.run(pytorch_code, ["h0"], unsupport=True, reason="unsupported dtype parameter") + obj.run(pytorch_code, ["h0"]) def test_case_5(): diff --git a/tests/test_nn_GroupNorm.py b/tests/test_nn_GroupNorm.py index 845f289ad..2df38b292 100644 --- a/tests/test_nn_GroupNorm.py +++ b/tests/test_nn_GroupNorm.py @@ -89,3 +89,15 @@ def test_alias_case_3(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_alias_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([[[[2.,3.], [3., 5.]], [[5.,3.], [9., 5.]]]]) + m = torch.nn.modules.GroupNorm(2, 2, eps=1e-05, affine=True, dtype=torch.float32) + result = m(a) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_InstanceNorm1d.py b/tests/test_nn_InstanceNorm1d.py index 4693a7dfa..f054e6053 100644 --- a/tests/test_nn_InstanceNorm1d.py +++ b/tests/test_nn_InstanceNorm1d.py @@ -33,6 +33,7 @@ def test_case_1(): [ 0.9385, 0.4565, 0.7702], [ 0.4135, -0.2587, 0.0482]]]) result = m(input) + result.requires_grad = False """ ) obj.run(pytorch_code, ["result"]) @@ -112,3 +113,22 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = nn.InstanceNorm1d(3, affine=False, momentum=0.1, dtype=torch.float32) + input = torch.tensor([[[ 1.1524, 0.4714, 0.2857], + [-1.2533, -0.9829, -1.0981], + [ 0.1507, -1.1431, -2.0361]], + + [[ 0.1024, -0.4482, 0.4137], + [ 0.9385, 0.4565, 0.7702], + [ 0.4135, -0.2587, 0.0482]]]) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_InstanceNorm2d.py b/tests/test_nn_InstanceNorm2d.py index 063aee9e1..d20514052 100644 --- a/tests/test_nn_InstanceNorm2d.py +++ b/tests/test_nn_InstanceNorm2d.py @@ -50,6 +50,7 @@ def test_case_1(): [0.5224, 0.9840, 0.0497], [0.8938, 0.5135, 0.5939]]]]) result = m(input) + result.requires_grad = False """ ) obj.run(pytorch_code, ["result"]) @@ -197,3 +198,39 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch.nn as nn + import torch + m = nn.InstanceNorm2d(3, affine=False, momentum=0.1, dtype=torch.float32) + input = torch.tensor([[[[0.9436, 0.7335, 0.9228], + [0.5443, 0.3380, 0.0676], + [0.2152, 0.2725, 0.2988]], + + [[0.3839, 0.7517, 0.8147], + [0.7681, 0.0924, 0.3781], + [0.6991, 0.2401, 0.4732]], + + [[0.3631, 0.5113, 0.4535], + [0.9779, 0.4084, 0.5979], + [0.6865, 0.5924, 0.9122]]], + + + [[[0.1519, 0.2828, 0.0797], + [0.5871, 0.1052, 0.2343], + [0.0323, 0.0754, 0.6707]], + + [[0.6969, 0.4170, 0.0762], + [0.2514, 0.5124, 0.3972], + [0.1007, 0.7754, 0.4779]], + + [[0.1753, 0.2245, 0.0369], + [0.5224, 0.9840, 0.0497], + [0.8938, 0.5135, 0.5939]]]]) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_InstanceNorm3d.py b/tests/test_nn_InstanceNorm3d.py index 92829fe7e..6197d2370 100644 --- a/tests/test_nn_InstanceNorm3d.py +++ b/tests/test_nn_InstanceNorm3d.py @@ -27,6 +27,7 @@ def test_case_1(): m = nn.InstanceNorm3d(100) input = torch.ones(20, 100, 35, 45, 10) result = m(input) + result.requires_grad = False """ ) obj.run(pytorch_code, ["result"]) @@ -82,3 +83,16 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + 
""" + import torch.nn as nn + import torch + m = nn.InstanceNorm3d(100, affine=False, momentum=0.1, dtype=torch.float32) + input = torch.ones(20, 100, 35, 45, 10) + result = m(input) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_LSTMCell.py b/tests/test_nn_LSTMCell.py index a837ee4f9..a37bd0001 100644 --- a/tests/test_nn_LSTMCell.py +++ b/tests/test_nn_LSTMCell.py @@ -67,9 +67,7 @@ def test_case_3(): result = result2[0] """ ) - obj.run( - pytorch_code, ["result"], unsupport=True, reason="unsupported dtype parameter" - ) + obj.run(pytorch_code, ["result"]) def test_case_4(): @@ -86,9 +84,7 @@ def test_case_4(): result = result2[0] """ ) - obj.run( - pytorch_code, ["result"], unsupport=True, reason="unsupported dtype parameter" - ) + obj.run(pytorch_code, ["result"]) def test_case_5(): @@ -105,6 +101,4 @@ def test_case_5(): result = result2[0] """ ) - obj.run( - pytorch_code, ["result"], unsupport=True, reason="unsupported dtype parameter" - ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_RNNCell.py b/tests/test_nn_RNNCell.py index 6d8016bbd..130a842ac 100644 --- a/tests/test_nn_RNNCell.py +++ b/tests/test_nn_RNNCell.py @@ -64,7 +64,7 @@ def test_case_3(): h0 = h0[1] """ ) - obj.run(pytorch_code, ["h0"], unsupport=True, reason="unsupported dtype parameter") + obj.run(pytorch_code, ["h0"]) def test_case_4(): @@ -80,7 +80,7 @@ def test_case_4(): h0 = h0[1] """ ) - obj.run(pytorch_code, ["h0"], unsupport=True, reason="unsupported dtype parameter") + obj.run(pytorch_code, ["h0"]) def test_case_5(): @@ -113,3 +113,19 @@ def test_case_6(): """ ) obj.run(pytorch_code, ["h0"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + from torch import nn + inp = torch.tensor([[0.,0.],[0.,0.]]) + rnn = torch.nn.RNNCell(2, 2) + h0 = torch.tensor([[0.,0.],[0.,0.]]) + h0 = rnn(inp, h0) + if isinstance(h0, tuple): + h0 = h0[1] + """ + ) + obj.run(pytorch_code, ["h0"], check_value=False) diff --git a/tests/test_nn_functional_multi_margin_loss.py b/tests/test_nn_functional_multi_margin_loss.py new file mode 100644 index 000000000..300fa134a --- /dev/null +++ b/tests/test_nn_functional_multi_margin_loss.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.multi_margin_loss") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + result = torch.nn.functional.multi_margin_loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + result = torch.nn.functional.multi_margin_loss(input, target, reduction='sum') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + result = torch.nn.functional.multi_margin_loss(input, target, reduction='none') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + weight = torch.tensor([0.2, 0.3, 0.5]) + result = torch.nn.functional.multi_margin_loss(input, target, weight=weight) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + result = torch.nn.functional.multi_margin_loss(input, target, size_average=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + result = torch.nn.functional.multi_margin_loss(input, target, reduce=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + result = torch.nn.functional.multi_margin_loss(input, target, margin=2) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([1, 0]) + result = torch.nn.functional.multi_margin_loss(input, target, p=2) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_soft_margin_loss.py b/tests/test_nn_functional_soft_margin_loss.py new file mode 100644 index 000000000..0c0b57c65 --- /dev/null +++ b/tests/test_nn_functional_soft_margin_loss.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.soft_margin_loss") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 0., 1.],[0., 1., 1.]]) + result = torch.nn.functional.soft_margin_loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 0., 1.],[0., 1., 1.]]) + result = torch.nn.functional.soft_margin_loss(input, target, reduction='sum') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 0., 1.],[0., 1., 1.]]) + result = torch.nn.functional.soft_margin_loss(input, target, reduction='none') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 0., 1.],[0., 1., 1.]]) + result = torch.nn.functional.soft_margin_loss(input, target, size_average=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 0., 1.],[0., 1., 1.]]) + result = torch.nn.functional.soft_margin_loss(input, target, reduce=False) + """ + ) + obj.run(pytorch_code, ["result"])
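+
+
+# Hypothetical extra case (a minimal sketch, not part of the original change
+# set above): it reuses the same APIBase harness defined in this file and
+# additionally exercises passing `target` and `reduction` as keyword
+# arguments, which the converter is expected to handle the same way as the
+# positional forms covered by the cases above.
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        input = torch.tensor([[-1.2837, -0.0297, 0.0355],
+                [ 0.9112, -1.7526, -0.4061]])
+        target = torch.tensor([[1., 0., 1.],[0., 1., 1.]])
+        result = torch.nn.functional.soft_margin_loss(input, target=target, reduction='mean')
+        """
+    )
+    obj.run(pytorch_code, ["result"])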