꺼내먹는지식 준
Compose __call__(): Albumentations Transforms 매개변수가 여러개일 때 본문
맨날 transform(img)
image tranform 하는 것만 보다가 갑작스럽게 image, bboxes, labels가 모두 transform의 매개변수로 들어가는 걸 보고 당황했다.
주변인의 도움을 받아 해결
참고 코드는 다음과 같다.
transforms = get_train_transform(resize_H, resize_W)
def get_train_transform(h, w):
return A.Compose([
A.Resize(height = h, width = w),
A.Flip(p=0.5),
ToTensorV2(p=1.0)
], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
sample = {
'image': image,
'bboxes': boxes,
'labels': labels
}
sample = transforms(**sample)
일단 torchvision transforms 에서는 image만을 매개변수로 받고, 더 추가하려면 custom compose 가 필요하다.
그러나 Albumentations는 compose 구현이 다르게 되어있다.
소스 코드를 참고하면 다음과 같다.
Compose __call__()
def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
if args:
raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
if self.is_check_args:
self._check_args(**data)
_check_args 를 따라가면,
def _check_args(self, **kwargs) -> None:
checked_single = ["image", "mask"]
checked_multi = ["masks"]
check_bbox_param = ["bboxes"]
# ["bboxes", "keypoints"] could be almost any type, no need to check them
for data_name, data in kwargs.items():
internal_data_name = self.additional_targets.get(data_name, data_name)
if internal_data_name in checked_single:
if not isinstance(data, np.ndarray):
raise TypeError("{} must be numpy array type".format(data_name))
if internal_data_name in checked_multi:
if data:
if not isinstance(data[0], np.ndarray):
raise TypeError("{} must be list of numpy arrays".format(data_name))
if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None:
raise ValueError("bbox_params must be specified for bbox transformations")
additional_targets: typing.Optional[typing.Dict[str, str]] = None,
타입힌트를 모르시는 분들을 위해
https://www.daleseo.com/python-typing/
[파이썬] typing 모듈로 타입 표시하기
Engineering Blog by Dale Seo
www.daleseo.com
additional_targets 는 dict type
#additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'}
additional_targets.get
dictionary get:
dictionary에서 'a'를 key 로 하는 value 를 찾아 c에 저장, 없으면 두번째 파라미터인 Nothing 저장
즉 여기서
additional_targets.get(data_name, data_name)는
먼저 Compose 를 생성할 때, 기본 키워드인 images, bboxes를 다른 키워드로 변경하고 dictionary 에 key value로 저장한 후,
Compose __call__ 시 들어오는 인자가 default 인자와 다를 때, get을 통해 불러와서 다른 작업 시 keyword 문제가 생기지 않도록 처리하는 기능이다.
굳이 이렇게 사용할 일이 있을까..? 싶지만 있나보다.
해당 키워드들에 따라 어떤 기능이 수행되는지는 다음에 이어서..
※kwargs.items() 란?
def greet_me(**kwargs):
print('kwargs.items() : ',kwargs.items())
#for key, value in kwargs.items():
# print( f'{ key } (key) = { value } (value)')
kwargs = {
'name' : '조재성',
'학교' : 'dsu',
'study': 'pan'
}
greet_me(**kwargs)
# kwargs.items() : dict_items([('name', '조재성'), ('학교', 'dsu'), ('study', 'pan')])
'AI > PyTorch' 카테고리의 다른 글
torch Softmax 차원에 따른 출력 값 차이 (0) | 2022.03.24 |
---|---|
Pytorch 차원 (dim, axis) (0) | 2022.03.24 |
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn 문제 해결 (0) | 2022.03.08 |
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same 문제 해결 (0) | 2022.03.08 |
Pytorchlightning TensorBoard 사용법 쉬운 정리 (0) | 2022.03.05 |