How to get the device type of a PyTorch module conveniently?

Hello everyone. Today we are going to look at how to conveniently get the device type of a PyTorch module in Python. The possible methods are explained below.

Let's get started.

Method 1

This question has been asked many times (1, 2). Quoting the reply from a PyTorch developer:

That’s not possible. Modules can hold parameters of different types on different devices, and so it’s not always possible to unambiguously determine the device.

The recommended workflow (as described on PyTorch blog) is to create the device object separately and use that everywhere. Copy-pasting the example from the blog here:

# at beginning of the script
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)

Do note that there is nothing stopping you from adding a .device property to the models.
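As a rough illustration of that idea, here is a minimal sketch of a model with its own .device property. TinyNet and its fc layer are hypothetical names made up for this example, and the property assumes the whole model lives on a single device, so it simply reports where one known parameter is stored.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    @property
    def device(self):
        # Assumption: every parameter is on the same device, so the
        # weight of one layer is representative of the whole model.
        return self.fc.weight.device

    def forward(self, x):
        # Move the input to wherever the model currently lives.
        return self.fc(x.to(self.device))


model = TinyNet()
print(model.device)
out = model(torch.randn(3, 4))
```

If the model is later split across devices, a property like this would silently report only the device of that one layer, which is the ambiguity the PyTorch developer's reply points out.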

As mentioned by Kani (in the comments), if all the parameters in the model are on the same device, one could use next(model.parameters()).device.
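Since that shortcut only holds when everything is on one device, it can help to check the assumption explicitly. The sketch below is one possible way to do it; module_devices and single_device are hypothetical helper names, not part of PyTorch.

```python
import torch
import torch.nn as nn


def module_devices(module):
    """Return the set of devices used by a module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device for t in tensors}


def single_device(module):
    """Return the module's device, or raise if it is missing or ambiguous."""
    devices = module_devices(module)
    if len(devices) != 1:
        raise ValueError(f"expected exactly one device, found {devices}")
    return devices.pop()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(next(model.parameters()).device)  # the shortcut from the comment
print(single_device(model))             # same answer, but checked
```

Note that next(model.parameters()) raises StopIteration for a module that has no parameters at all (for example a bare nn.ReLU()), while the helper above raises a ValueError with a clearer message in that case.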

Method 2

My solution, which works in 99% of cases:

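import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # empty parameter that moves together with the module
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        # the dummy parameter's device is the module's device
        device = self.dummy_param.device
        ...  # etc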

After that, dummy_param will always be on the same device as the module Net, so you can read it any time you want, e.g.:

net = Net()
net.dummy_param.device

device(type='cpu')

net = net.to('cuda')
net.dummy_param.device

device(type='cuda', index=0)

Summary

That is all for this topic. Hopefully one of these methods helped you. Leave your thoughts and questions in the comments below, and let us know which method worked for you. Thank you.
