꺼내먹는지식 준

Torch Indexing 본문

AI/PyTorch

Torch Indexing

알 수 없는 사용자 2022. 1. 25. 15:01

torch.index_select(input, dim, index, *, out=None) → Tensor

 

Torch Indexing 

1) numpy indexing 기법 

2) torch.index_select 

A = torch.Tensor([[1, 2],
                  [3, 4]])
                  
A[:,0] #1,3 
A[:][0] # 1,2 == A[0]
torch.index_select(A, 1, indices) #[[1.], [3.]]
output.squeeze(1) # [1., 3.]

torch.index_select 엄청 번거롭다. 

 

3) torch_gather 

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

output = torch.gather(A, 1, torch.tensor([[0],[1]]))
#A의 1 차원으로 indexing. 1차원 첫번째 index 중 [0] 번째 데이터, 두번째 index 중 [1] 

print(output.shape)
#[[1.], [4.]]
#size([2 ,1])
output = output.squeeze(1)
print(output)
#[1., 4.]

torch_gather로 원하는 곳 indexing이 가능하다. 

 

3 D gather 

import torch

A = torch.Tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])

print(A.shape)
#torch.Size([2, 2, 2])

output = torch.gather(A, 1, torch.tensor([[[0,1]], [[0,1]]]))
print(output)
#tensor([[[1., 4.]],

        [[5., 8.]]])

print(output.shape)
torch.Size([2, 1, 2])

output = output.squeeze(1)
print(output)
#tensor([[1., 4.],
        [5., 8.]])

굉장히 까다롭다. [  [[0,1]] , [[0,1]]   ]

 

4) Torch.tensor.expand

Tensor.expand(*sizes) → Tensor
x = torch.tensor([[1], [2], [3]])
x.size()
#torch.Size([3, 1])
x.expand(3, 4)
#tensor([[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3]])
x.expand(-1, 4)   # -1 means not changing the size of that dimension
#tensor([[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3]])

크기를 확장시켜준다. 

 

 

'AI > PyTorch' 카테고리의 다른 글

Torch Math Operations  (0) 2022.01.25
Torch Tensors  (0) 2022.01.25
Pytorch Dataset  (0) 2022.01.25
Pytorch Backpropagation(AutoGrad, Optimizer)  (0) 2022.01.25
Pytorch 프로젝트 생성, 배포, 유지보수  (0) 2022.01.24
Comments