꺼내먹는지식 준

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same 문제 해결 본문

AI/PyTorch

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same 문제 해결

알 수 없는 사용자 2022. 3. 8. 22:40
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

얼핏 보면, input type 은 cuda 에 올렸는데 weight type이라고 하니까 감이 잘 안온다. 

model을 올리지 않았다고 하는건가 싶다. 

나도 동일한 에러가 났었는데 다는 model type 이라고 분명하게 명시 되어있었다. 이때는 모델을 cuda 에 올리니까 바로 해결이 되었다. 

 

그럼 이 에러는 어떤 상황에서 발생하였는가? 

 

신규범_T3116

고수 캠퍼의 말을 빌려보자.

일반 list말고 ModuleList 써보세요 일반 List는 attribute로 외부에 안 나타나기 때문에 cuda로 보낼때 타겟으로 못읽는듯 합니다. ModuleList에 넣으면 잘 돌아가요. 

 

즉 module을 관리하기 위하여 list 에 모듈을 다음과 같이 추가한 예시를 보자.

list = [nn.Conv2d(in, out, kernel=1, stride=1), nn.BatchNorm2d(otu), nn.ReLU] 

이 경우 해당 모듈들이 cuda 에 올려지지 않고 모델 구조를 출력해보면 다음과 같이 나타난다. 

print(model)

이때 해결 방법은 두가지이다. 

1) 애초에 List가 아닌 ModuleList에 module을 추가하거나, 

 

2) List에 추가한 모듈을 nn.Sequential( *List ) 다음과 같이 unpacking하여 보관하는 것이다. 

 

-해결-

 

Comments