From 1314f823bdfa15fb4216f4212a9856e0c05ebbb3 Mon Sep 17 00:00:00 2001 From: co63oc Date: Wed, 5 Jul 2023 18:18:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=98=A0=E5=B0=84=E6=96=87=E6=A1=A3=20No.120?= =?UTF-8?q?=20(#5966)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add api_difference doc * Fix --- .../api_difference/cuda/torch.cuda.device.md | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md new file mode 100644 index 00000000000..1ba1e7da761 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md @@ -0,0 +1,40 @@ +## [参数不一致]torch.cuda.device + +### [torch.cuda.device](https://pytorch.org/docs/1.13/generated/torch.cuda.device.html#torch.cuda.device) + +```python +torch.cuda.device(device) +``` + +### [paddle.CUDAPlace](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/CUDAPlace_cn.html) + +```python +paddle.CUDAPlace(id) +``` + +其中 Pytorch 与 Paddle 的参数支持类型不一致,具体如下: + +### 参数映射 + +| PyTorch | PaddlePaddle | 备注 | +| ------- | ------------ | -------------------------------------------------------------------------------- | +| device | id | GPU 的设备 ID, Pytorch 支持 torch.device 和 int,Paddle 支持 int,需要进行转写。 | + +### 转写示例 + +#### device: 获取 device 参数,对其取 device.index 值 + +```python +# Pytorch 写法 +torch.cuda.device(torch.device('cuda')) + +# Paddle 写法 +paddle.CUDAPlace(0) + +# 增加 index +# Pytorch 写法 +torch.cuda.device(torch.device('cuda', index=index)) + +# Paddle 写法 +paddle.CUDAPlace(index) +```