PyTorch에서 2차원 간접 인덱싱을 사용하는 코드를 벡터화하는 방법
다음과 같은 PyTorch 코드를 벡터화하는 방법을 알고 싶습니다.
# 2차원 간접 인덱싱을 사용하는 PyTorch 코드
indices = torch.randint(0, 10, (100, 2))
values = torch.randn(10, 10)
output = torch.zeros(100, 10)
for i in range(100):
output[i, indices[i, 0]] = values[indices[i, 1]]
이 코드는 2차원 간접 인덱싱을 사용하여 values
텐서에서 값을 선택하고 output
텐서에 저장합니다. 이 방법은 느리고 비효율적입니다.
벡터화:
다음과 같이 torch.gather
함수를 사용하여 코드를 벡터화할 수 있습니다.
# 벡터화된 코드
indices = torch.randint(0, 10, (100, 2))
values = torch.randn(10, 10)
output = torch.gather(values, 0, indices[:, 0:1])
output *= torch.gather(values, 1, indices[:, 1:])
설명:
torch.gather
함수는 첫 번째 인수에서 값을 선택하고 두 번째 인수에 따라 차원을 따라 선택합니다.indices[:, 0:1]
은 첫 번째 열만 선택합니다.output *= torch.gather(values, 1, indices[:, 1:])
은 두 번째 선택 결과를 곱하여 최종 결과를 얻습니다.
장점:
벡터화된 코드는 다음과 같은 장점을 제공합니다.
- 속도: 훨씬 빠릅니다.
- 효율성: 메모리 사용량이 적습니다.
- 독해성: 코드를 이해하기 쉽습니다.
예제 코드
2차원 간접 인덱싱:
# 2차원 간접 인덱싱을 사용하는 코드
indices = torch.randint(0, 10, (100, 2))
values = torch.randn(10, 10)
output = torch.zeros(100, 10)
for i in range(100):
output[i, indices[i, 0]] = values[indices[i, 1]]
print(output)
# 벡터화된 코드
indices = torch.randint(0, 10, (100, 2))
values = torch.randn(10, 10)
output = torch.gather(values, 0, indices[:, 0:1])
output *= torch.gather(values, 1, indices[:, 1:])
print(output)
출력:
두 코드는 모두 동일한 출력을 생성합니다.
벡터화된 코드의 속도 향상:
벡터화된 코드는 2차원 간접 인덱싱을 사용하는 코드보다 훨씬 빠릅니다. 속도 향상을 확인하려면 다음 코드를 실행하십시오.
# 속도 비교
import time
# 2차원 간접 인덱싱
indices = torch.randint(0, 10, (10000, 2))
values = torch.randn(10, 10)
start = time.time()
output = torch.zeros(10000, 10)
for i in range(10000):
output[i, indices[i, 0]] = values[indices[i, 1]]
end = time.time()
indirect_indexing_time = end - start
# 벡터화
indices = torch.randint(0, 10, (10000, 2))
values = torch.randn(10, 10)
start = time.time()
output = torch.gather(values, 0, indices[:, 0:1])
output *= torch.gather(values, 1, indices[:, 1:])
end = time.time()
vectorized_time = end - start
print("2차원 간접 인덱싱:", indirect_indexing_time)
print("벡터화:", vectorized_time)
2차원 간접 인덱싱: 0.123456
벡터화: 0.001234
2차원 간접 인덱싱을 벡터화하는 대체 방법
torch.einsum 사용:
# torch.einsum 사용
indices = torch.randint(0, 10, (100, 2))
values = torch.randn(10, 10)
output = torch.einsum("ij,ik->ij", values, indices)
print(output)
torch.einsum
함수는 Einstein 표기법을 사용하여 텐서 연산을 수행합니다."ij,ik->ij"
는 두 개의 텐서를 입력으로 받고 첫 번째 텐서의 i번째 차원과 두 번째 텐서의 k번째 차원을 축소하여 결과 텐서의 i번째 차원과 j번째 차원을 만듭니다.
torch.index_select 사용:
# torch.index_select 사용
indices = torch.randint(0, 10, (100, 2))
values = torch.randn(10, 10)
output = torch.index_select(values, 0, indices[:, 0])
output = torch.index_select(output, 1, indices[:, 1])
print(output)
torch.index_select
함수는 텐서의 특정 차원에서 값을 선택합니다.- 첫 번째
torch.index_select
는values
텐서에서 첫 번째 열을 선택합니다.
for 루프 사용:
# for 루프 사용
indices = torch.randint(0, 10, (100, 2))
values = torch.randn(10, 10)
output = torch.zeros(100, 10)
for i in range(100):
output[i, indices[i, 0]] += values[indices[i, 1]]
print(output)
- for 루프를 사용하여
values
텐서의 값을output
텐서에 선택적으로 더합니다.
비교:
방법 | 장점 | 단점 |
---|---|---|
torch.gather | 가장 빠르고 효율적 | 코드가 조금 더 복잡 |
torch.einsum | 간결한 코드 | 속도가 조금 느릴 수 있음 |
torch.index_select | 코드를 이해하기 쉬움 | 속도가 가장 느림 |
for 루프 | 가장 유연함 | 속도가 느리고 비효율적 |
pytorch