README

Status: 📁 Published · Created: November 2, 2020, 11:53 AM
This project essentially consists of two Python files, run_websocket_server(0.2.3).py and run_websocket_client(0.2.3).py. One must be downloaded onto each Raspberry Pi and the other onto the computer, and together they set up and run the server. For a detailed description of their roles and usage, please refer to 'this page'. This document only explains the code in the two Python files briefly, for future maintenance.
✔️ run_websocket_server(0.2.3).py runs on the Raspberry Pi and starts a server that receives commands from the central server.
✔️ Based on the preprocessed dataset, it carries out the prescribed operations according to the commands it receives.
1. def : start_proc
```python
def start_proc(participant, kwargs):  # pragma: no cover
    """helper function for spinning up a websocket participant"""

    def target():
        server = participant(**kwargs)
        server.start()

    p = Process(target=target)
    p.start()
    return p
```
- The participant parameter is typically given the WebsocketServerWorker class from the syft library.
- The kwargs parameter holds the argument values passed when the file is run.
- The server is started from these two parameters: the target is spawned in a child process using the Process class of the multiprocessing library. (The imports and hook these snippets assume are sketched below.)
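A minimal sketch of what the top of the server script needs for the snippets in this section to run; the exact import paths follow the syft 0.2.x websocket examples and may differ slightly between releases.

```python
# Assumed imports for the server-side snippets (a sketch, not copied from the original file).
import argparse
from multiprocessing import Process

import torch
import syft as sy
from syft.workers.websocket_server import WebsocketServerWorker

hook = sy.TorchHook(torch)  # hooks PyTorch so tensors and models can move between workers
```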
2. part : parser
```python
parser = argparse.ArgumentParser(description="Run websocket server worker.")

parser.add_argument(
    "--port", "-p", type=int, help="port number of the websocket server worker, e.g. --port 8777"
)
parser.add_argument("--host", type=str, default="localhost", help="host for the connection")
parser.add_argument(
    "--id", type=str, help="name (id) of the websocket server worker, e.g. --id alice"
)
parser.add_argument(
    "--verbose",
    "-v",
    action="store_true",
    help="if set, websocket server worker will be started in verbose mode",
)

args = parser.parse_args()
```
- Defines the options (arguments) that can be passed when the file is run.
- There are four arguments: port, host (IP address), id, and verbose. A small parsing example follows below.
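For illustration, parsing an explicit argument list (hypothetical values) shows what ends up in args:

```python
# Hypothetical values; the parser defined above turns them into a Namespace.
args = parser.parse_args(["--port", "8777", "--id", "alice", "--host", "0.0.0.0"])
print(args.port, args.host, args.id, args.verbose)
# 8777 0.0.0.0 alice False
```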
3. part : main
```python
# `hook` refers to the TorchHook instance created earlier in the script.
kwargs = {
    "id": args.id,
    "host": args.host,
    "port": args.port,
    "hook": hook,
    "verbose": args.verbose,
}

server = start_proc(WebsocketServerWorker, kwargs)
```
- Runs start_proc with the arguments received at launch; a launch sketch follows below.
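A minimal sketch of how the pieces fit together on one Raspberry Pi, with hypothetical argument values:

```python
# Hypothetical launch:
#   python "run_websocket_server(0.2.3).py" --host 0.0.0.0 --port 8777 --id alice
# Inside the script this boils down to roughly the following (a sketch, not the exact file):
hook = sy.TorchHook(torch)
kwargs = {"id": "alice", "host": "0.0.0.0", "port": 8777, "hook": hook, "verbose": False}
server_process = start_proc(WebsocketServerWorker, kwargs)
server_process.join()  # keep the parent alive while the worker process serves requests
```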
✔️ run_websocket_client(0.2.3).py lives on the central device (desktop or laptop) and, when run, connects to the servers the Raspberry Pis are running.
✔️ It sends the commands for FL (federated learning) to each Raspberry Pi.
✔️ It aggregates and processes the models received from the Raspberry Pis, and produces the updated model.
1. class : Net
```python
import torch.nn as nn
import torch.nn.functional as f


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = f.relu(self.conv1(x))
        x = f.max_pool2d(x, 2, 2)
        x = f.relu(self.conv2(x))
        x = f.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = f.relu(self.fc1(x))
        x = self.fc2(x)
        return f.log_softmax(x, dim=1)
```
- Defines the machine-learning model network.
- It inherits from the abstract Module class of the nn library (torch.nn).
- Layer presets such as relu and max pooling can be used from nn's functional module. A quick shape check follows below.
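A quick shape check (a sketch) confirms the dimensions implied by the layer definitions above:

```python
# One fake MNIST batch through the network: 28x28 -> conv/pool twice -> 4x4x50 -> 10 classes.
import torch

net = Net()
dummy = torch.zeros(8, 1, 28, 28)  # batch of 8 single-channel 28x28 images
out = net(dummy)
print(out.shape)  # torch.Size([8, 10]) -- log-probabilities over the 10 digits
```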
2. def : train_on_batches
```python
import torch.optim as optim


def train_on_batches(worker, batches, model_in, device, lr):
    """Train the model on the worker on the provided batches

    Args:
        worker(syft.workers.BaseWorker): worker on which the
            training will be executed
        batches: batches of data of this worker
        model_in: machine learning model, training will be done on a copy
        device (torch.device): where to run the training
        lr: learning rate of the training steps

    Returns:
        model, loss: obtained model and loss after training
    """
    model = model_in.copy()
    optimizer = optim.SGD(model.parameters(), lr=lr)  # TODO momentum is not supported at the moment

    model.train()
    model.send(worker)
    loss_local = False
```
- An SGD optimizer is created over the copied model's parameters with the given learning rate.
- model.train() switches the model to training mode, and model.send(worker) then sends the copy to the remote worker.
```python
    # (continuation of train_on_batches; LOG_INTERVAL is a constant defined elsewhere in the script)
    for batch_idx, (data, target) in enumerate(batches):
        loss_local = False
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = f.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            loss = loss.get()  # <-- NEW: get the loss back
            loss_local = True
            logger.debug(
                "Train Worker {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    worker.id,
                    batch_idx,
                    len(batches),
                    100.0 * batch_idx / len(batches),
                    loss.item(),
                )
            )

    if not loss_local:
        loss = loss.get()  # <-- NEW: get the loss back
    model.get()  # <-- NEW: get the model back
    return model, loss
```
- Each received batch is moved to the device; a batch consists of the data and its target. The loss, and at the end the model, are brought back from the worker with .get(); see the pointer sketch below.
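The send/get calls are the heart of this function: once the model is sent to a worker, the local variables only hold pointers, and .get() brings the results back. A minimal sketch of that pattern with a virtual worker, assuming the syft 0.2.x API:

```python
# Sketch of the send/get pointer pattern used in train_on_batches (syft 0.2.x API assumed).
import torch
import syft as sy

hook = sy.TorchHook(torch)
worker = sy.VirtualWorker(hook, id="sketch_worker")

x = torch.tensor([1.0, 2.0, 3.0]).send(worker)  # x is now a pointer to a remote tensor
y = (x * 2).get()                               # computed on the worker, then fetched back
print(y)  # tensor([2., 4., 6.])
```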
3. def : get_next_batches
```python
def get_next_batches(fdataloader: sy.FederatedDataLoader, nr_batches: int):
    """retrieve next nr_batches of the federated data loader and group
    the batches by worker

    Args:
        fdataloader (sy.FederatedDataLoader): federated data loader
            over which the function will iterate
        nr_batches (int): number of batches (per worker) to retrieve

    Returns:
        Dict[syft.workers.BaseWorker, List[batches]]
    """
    batches = {}
    for worker_id in fdataloader.workers:
        worker = fdataloader.federated_dataset.datasets[worker_id].location
        batches[worker] = []
    try:
        for i in range(nr_batches):
            next_batches = next(fdataloader)
            for worker in next_batches:
                batches[worker].append(next_batches[worker])
    except StopIteration:
        pass
    return batches
```
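get_next_batches returns a dictionary that maps each worker to a list of up to nr_batches (data, target) batches. A usage sketch, assuming the federated_train_loader and workers created in main:

```python
# Usage sketch: pre-fetch 10 batches per worker, grouped by worker.
iter(federated_train_loader)  # initialize the loader's internal iterators
batches = get_next_batches(federated_train_loader, nr_batches=10)
for worker, worker_batches in batches.items():
    print(worker.id, len(worker_batches))  # e.g. alice 10 / bob 10 / charlie 10
```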
4. def : train
```python
def train(
    model, device, federated_train_loader, lr, federate_after_n_batches, abort_after_one=False
):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    iter(federated_train_loader)  # initialize iterators
    batches = get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        logger.debug(
            "Starting training round, batches [{}, {}]".format(counter, counter + nr_batches)
        )
        data_for_all_workers = True
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = train_on_batches(
                    worker, curr_batches, model, device, lr
                )
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            logger.debug("At least one worker ran out of data, stopping.")
            break

        model = utils.federated_avg(models)
        batches = get_next_batches(federated_train_loader, nr_batches)
        if abort_after_one:
            break
    return model
```
- This is the function that actually trains the model; it is separate from the train() method called as model.train(), which only toggles training mode.
- The batches variable is filled by get_next_batches with the configured number of batches per worker (federate_after_n_batches). After each round, utils.federated_avg merges the locally trained models; a sketch of that idea follows below.
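utils.federated_avg(models) is what merges the per-worker results into a single global model. Conceptually it is a parameter-wise mean; a naive sketch of the idea (an illustration, not syft's actual implementation):

```python
# Naive federated averaging sketch: average the parameters of identically shaped models.
from typing import Dict
import torch
import torch.nn as nn


def naive_federated_avg(models: Dict[object, nn.Module]) -> nn.Module:
    workers = list(models)
    avg_model = models[workers[0]]
    avg_state = avg_model.state_dict()
    for key in avg_state:
        avg_state[key] = torch.stack(
            [models[w].state_dict()[key] for w in workers]
        ).mean(dim=0)
    avg_model.load_state_dict(avg_state)
    return avg_model
```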
5. def : test
```python
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += f.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    logger.debug("\n")
    accuracy = 100.0 * correct / len(test_loader.dataset)
    logger.info(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), accuracy
        )
    )
```
6. def : define_and_get_arguments
```python
def define_and_get_arguments(args=sys.argv[1:]):
    parser = argparse.ArgumentParser(
        description="Run federated learning using websocket client workers."
    )
    parser.add_argument("--batch_size", type=int, default=64, help="batch size of the training")
    parser.add_argument(
        "--test_batch_size", type=int, default=1000, help="batch size used for the test data"
    )
    parser.add_argument("--epochs", type=int, default=2, help="number of epochs to train")
    parser.add_argument(
        "--federate_after_n_batches",
        type=int,
        default=50,
        help="number of training steps performed on each remote worker before averaging",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--cuda", action="store_true", help="use cuda")
    parser.add_argument("--seed", type=int, default=1, help="seed used for randomization")
    parser.add_argument("--save_model", action="store_true", help="if set, model will be saved")
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="if set, websocket client workers will be started in verbose mode",
    )
    parser.add_argument(
        "--use_virtual", action="store_true", help="if set, virtual workers will be used"
    )

    args = parser.parse_args(args=args)
    return args
```
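Calling it with an explicit argument list (hypothetical values) is a quick way to see the resulting configuration and the defaults:

```python
# Hypothetical values; everything not specified falls back to the defaults above.
args = define_and_get_arguments(["--epochs", "5", "--lr", "0.05", "--use_virtual"])
print(args.epochs, args.lr, args.use_virtual, args.batch_size, args.federate_after_n_batches)
# 5 0.05 True 64 50
```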
7. def : main
No parameters.

```python
def main():
    args = define_and_get_arguments()

    hook = sy.TorchHook(torch)

    # Branch taken when virtual (simulated) workers are used
    if args.use_virtual:
        alice = VirtualWorker(id="alice", hook=hook, verbose=args.verbose)
        bob = VirtualWorker(id="bob", hook=hook, verbose=args.verbose)
        charlie = VirtualWorker(id="charlie", hook=hook, verbose=args.verbose)
    # Branch taken when websocket workers are used
    else:
        a_kwargs_websocket = {"host": "192.168.0.52", "hook": hook}
        b_kwargs_websocket = {"host": "192.168.0.53", "hook": hook}
        c_kwargs_websocket = {"host": "192.168.0.54", "hook": hook}

        baseport = 10002
        alice = WebsocketClientWorker(id="alice", port=baseport, **a_kwargs_websocket)
        bob = WebsocketClientWorker(id="bob", port=baseport, **b_kwargs_websocket)
        charlie = WebsocketClientWorker(id="charlie", port=baseport, **c_kwargs_websocket)

    # Group the worker objects into a list
    workers = [alice, bob, charlie]

    # Whether to use CUDA
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # Set the random seed
    torch.manual_seed(args.seed)
```
- The run options are obtained with define_and_get_arguments().
- If the use_virtual option is set, the run is simulated with virtual workers instead of websockets. Before actually connecting to the Raspberry Pis, this virtual-worker simulation can be used to shorten testing time.
- If use_virtual is not set, the script works over websockets. In that case each kwargs_websocket holds a Raspberry Pi's IP address and the hook; WebsocketClientWorker then creates a worker object for each id.
- There is not much to modify in this part.
```python
    # (continuation of main)
    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ).federate(tuple(workers)),
        batch_size=args.batch_size,
        shuffle=True,
        iter_per_worker=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    model = Net().to(device)

    for epoch in range(1, args.epochs + 1):
        # output: 2020-11-05 15:07:04,953 INFO run_websocket_client(0.2.3).py(l:268) - Starting epoch 1/2
        logger.info("Starting epoch %s/%s", epoch, args.epochs)
        model = train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
        test(model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
```
- federated_train_loader turns the loaded dataset into an object that federated learning can work with.
- FederatedDataLoader is the object that bundles the operations needed to run federated learning; it is used by converting it into an iterator.
- Before that, the data has to be distributed to the workers, which is done with .federate(tuple(workers)) on the torchvision dataset; a short sketch follows below.
- Training then repeats for the number of epochs given by args.epochs. This epoch count is the number of aggregation rounds on the central server; the default is 2.
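A minimal sketch of the .federate() step in isolation, using the virtual workers from the use_virtual branch (transforms shortened for brevity):

```python
# .federate() splits the torchvision dataset across the given workers and returns a
# FederatedDataset, which sy.FederatedDataLoader then iterates worker by worker.
fed_dataset = datasets.MNIST(
    "../data", train=True, download=True, transform=transforms.ToTensor()
).federate((alice, bob, charlie))
fed_loader = sy.FederatedDataLoader(
    fed_dataset, batch_size=64, shuffle=True, iter_per_worker=True
)
```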
8. part : main
```python
if __name__ == "__main__":
    FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s"
    LOG_LEVEL = logging.DEBUG
    logging.basicConfig(format=FORMAT, level=LOG_LEVEL)

    websockets_logger = logging.getLogger("websockets")
    websockets_logger.setLevel(logging.DEBUG)
    websockets_logger.addHandler(logging.StreamHandler())

    main()
```
- Nothing special here: it splits into two parts, configuring logging and entering the main function.
- getLogger creates (or fetches) a logger named websockets. setLevel makes it emit every record at DEBUG level and above. (There are five standard logger levels: DEBUG, INFO, WARNING, ERROR, CRITICAL.)
- addHandler attaches a StreamHandler so the logs are printed to the console; handlers can also be set up to write to files, databases, sockets, queues, and so on.
- A blog post worth consulting about logging ⬇️
- After that, the main function is entered; a self-contained logging sketch follows below.
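A self-contained sketch of the same logging pattern, runnable on its own:

```python
# Root logger format + a named "websockets" logger with an extra console handler.
import logging

FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s"
logging.basicConfig(format=FORMAT, level=logging.DEBUG)

websockets_logger = logging.getLogger("websockets")
websockets_logger.setLevel(logging.DEBUG)              # emit DEBUG and everything above it
websockets_logger.addHandler(logging.StreamHandler())  # print records to the console
websockets_logger.info("hello from the websockets logger")
```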
FederatedDataLoader (syft source, for reference)
```python
class FederatedDataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides single or several
    iterators over the dataset.

    Arguments:
        federated_dataset (FederatedDataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled at every
            epoch (default: ``False``).
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and the
            size of dataset is not divisible by the batch size, then the last batch will
            be smaller. (default: ``False``)
        num_iterators (int): number of workers from which to retrieve data in parallel.
            num_iterators <= len(federated_dataset.workers) - 1
            the effect is to retrieve num_iterators epochs of data but at each step data
            from num_iterators distinct workers is returned.
        iter_per_worker (bool): if set to true, __next__() will return a dictionary
            containing one batch per worker
    """

    __initialized = False

    def __init__(
        self,
        federated_dataset,
        batch_size=8,
        shuffle=False,
        num_iterators=1,
        drop_last=False,
        collate_fn=default_collate,
        iter_per_worker=False,
        **kwargs,
    ):
        if len(kwargs) > 0:
            options = ", ".join([f"{k}: {v}" for k, v in kwargs.items()])
            logging.warning(f"The following options are not supported: {options}")

        try:
            self.workers = federated_dataset.workers
        except AttributeError:
            raise Exception(
                "Your dataset is not a FederatedDataset, please use "
                "torch.utils.data.DataLoader instead."
            )

        self.federated_dataset = federated_dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.iter_class = _DataLoaderOneWorkerIter if iter_per_worker else _DataLoaderIter

        # Build a batch sampler per worker
        self.batch_samplers = {}
        for worker in self.workers:
            data_range = range(len(federated_dataset[worker]))
            if shuffle:
                sampler = RandomSampler(data_range)
            else:
                sampler = SequentialSampler(data_range)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
            self.batch_samplers[worker] = batch_sampler

        if iter_per_worker:
            self.num_iterators = len(self.workers)
        else:
            # You can't have more iterators than n - 1 workers, because you always
            # need a worker idle in the worker switch process made by iterators
            if len(self.workers) == 1:
                self.num_iterators = 1
            else:
                self.num_iterators = min(num_iterators, len(self.workers) - 1)

    def __iter__(self):
        self.iterators = []
        for idx in range(self.num_iterators):
            self.iterators.append(self.iter_class(self, worker_idx=idx))
        return self

    def __next__(self):
        if self.num_iterators > 1:
            batches = {}
            for iterator in self.iterators:
                data, target = next(iterator)
                batches[data.location] = (data, target)
            return batches
        else:
            iterator = self.iterators[0]
            data, target = next(iterator)
            return data, target

    def __len__(self):
        length = len(self.federated_dataset) / self.batch_size
        if self.drop_last:
            return int(length)
        else:
            return math.ceil(length)
```
- It takes as input a dataset that has been federated with the .federate() function.
torch.device
CUDA Tensors: tensors can be moved onto any device with the .to method.

```python
# This code only runs when CUDA is available.
# A ``torch.device`` object is used to move tensors onto the GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")          # a CUDA device object
    y = torch.ones_like(x, device=device)  # create a tensor directly on the GPU, or
    x = x.to(device)                       # simply use ``.to("cuda")``
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))       # ``.to`` can change the dtype at the same time!
```
- Training with the data received by the workers rather than by the central device