pytorh .to(device) 和.cuda()
to CUDA device
2023-09-11 14:22:51 时间
一、.to(device) 可以指定CPU 或者GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 单GPU或者CPU model.to(device) #如果是多GPU if torch.cuda.device_count() > 1: model = nn.DataParallel(model,device_ids=[0,1,2]) model.to(device)
mytensor = my_tensor.to(device)
这行代码的意思是将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。
这句话需要写的次数等于需要保存GPU上的tensor变量的个数;一般情况下这些tensor变量都是最开始读数据时的tensor变量,后面衍生的变量自然也都在GPU上
二、.cuda() 只能指定GPU
#指定某个GPU os.environ['CUDA_VISIBLE_DEVICE']='1' model.cuda() #如果是多GPU os.environment['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' device_ids = [0,1,2,3] net = torch.nn.Dataparallel(net, device_ids =device_ids) net = torch.nn.Dataparallel(net) # 默认使用所有的device_ids net = net.cuda()
class DataParallel(Module): def __init__(self, module, device_ids=None, output_device=None, dim=0): super(DataParallel, self).__init__() if not torch.cuda.is_available(): self.module = module self.device_ids = [] return if device_ids is None: device_ids = list(range(torch.cuda.device_count())) if output_device is None: output_device = device_ids[0]
相关文章
- Unable to locate package错误解决办法
- 我的2016_To Code or Not to Code: No Question
- `java.time.LocalDateTime` from String “2020-11-19“: Failed to deserialize java.time.LocalDateTime
- angular ng-bind-html异常Attempting to use an unsafe value in a safe context处理
- 【jmeter】+Not able to find Java executable or version. Please check your Java installation
- 解决:org.springframework.web.multipart.MultipartException: Failed to parse multipart servlet request;
- SAP 创建启用了ARM功能的采购订单,报错 -Shipping processing is not selected to supplier 100057 in purchase org. 0002-
- iOS 苹果开发证书失效的解决方案(Failed to locate or generate matching signing assets)
- WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable(spark加载hadoop本地库的时候出现不能加载的情况要怎么解决呢?)
- POJ 2891 Strange Way to Express Integers 中国剩余定理解法
- Unity 报错之 Unable to convert classes into dex format.
- How to kill a process on a port on linux 怎么杀死 关掉一个端口
- [LeetCode] 1312. Minimum Insertion Steps to Make a String Palindrome 让字符串成为回文串的最少插入次数
- SLF4J: Failed to load class “org.slf4j.impl.StaticLoggerBinder“.的解决方法
- 1033 To Fill or Not to Fill
- frp错误处理:login to server failed: authorization failed