PyTorch 소프트맥스: 사용할 차원
사용할 차원은 상황에 따라 달라집니다.
- 다중 클래스 분류: 각 샘플에 대한 클래스 예측 확률을 얻기 위해 마지막 차원(feature 차원)에 소프트맥스를 적용합니다.
- 시퀀스 모델링: 각 시퀀스 단계에 대한 다음 토큰 예측 확률을 얻기 위해 두 번째 차원(시퀀스 길이 차원)에 소프트맥스를 적용합니다.
다음은 각 상황에 대한 예시입니다.
다중 클래스 분류:
import torch
import torch.nn.functional as F
# 3개의 클래스를 가진 샘플 데이터
logits = torch.tensor([1, 2, 3])
# 마지막 차원에 소프트맥스 적용
probs = F.softmax(logits, dim=0)
# 각 클래스 예측 확률 출력
print(probs)
시퀀스 모델링:
import torch
import torch.nn.functional as F
# 5개 단어 시퀀스 데이터
logits = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 두 번째 차원에 소프트맥스 적용
probs = F.softmax(logits, dim=1)
# 각 시퀀스 단계에 대한 다음 토큰 예측 확률 출력
print(probs)
import torch
import torch.nn.functional as F
# 3개의 레이블을 가진 샘플 데이터
logits = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 마지막 차원에 소프트맥스 적용
probs = F.softmax(logits, dim=1)
# 각 샘플에 대한 각 레이블 예측 확률 출력
print(probs)
참고:
dim
매개변수는 소프트맥스 함수를 적용할 차원을 지정합니다.dim
값은 0부터 시작하며, 음수 값은 입력 텐서의 마지막 차원부터 역순으로 차원을 지정합니다.
예제 코드
import torch
import torch.nn.functional as F
# 3개의 클래스를 가진 샘플 데이터
logits = torch.tensor([1, 2, 3])
# 마지막 차원에 소프트맥스 적용
probs = F.softmax(logits, dim=0)
# 각 클래스 예측 확률 출력
print(probs)
출력:
tensor([0.09003057, 0.24472847, 0.66524096])
import torch
import torch.nn.functional as F
# 5개 단어 시퀀스 데이터
logits = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 두 번째 차원에 소프트맥스 적용
probs = F.softmax(logits, dim=1)
# 각 시퀀스 단계에 대한 다음 토큰 예측 확률 출력
print(probs)
tensor([[0.09003057, 0.24472847, 0.66524096],
[0.09003057, 0.24472847, 0.66524096]])
import torch
import torch.nn.functional as F
# 3개의 레이블을 가진 샘플 데이터
logits = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 마지막 차원에 소프트맥스 적용
probs = F.softmax(logits, dim=1)
# 각 샘플에 대한 각 레이블 예측 확률 출력
print(probs)
tensor([[0.09003057, 0.24472847, 0.66524096],
[0.09003057, 0.24472847, 0.66524096]])
설명:
- 각 예제 코드는
torch.nn.functional.softmax
함수를 사용하여 소프트맥스 함수를 각 차원에 적용합니다. - 각 예제 코드는 각 차원에 대한 소프트맥스 함수 결과를 출력합니다.
- 이 예제 코드는 기본적인 예시이며, 실제 사용 환경에 맞게 수정해야 할 수도 있습니다.
- PyTorch 소프트맥스 문서를 참고하여 더 자세한 정보를 확인하세요.
PyTorch 소프트맥스 대체 방법
직접 구현:
소프트맥스 함수는 다음과 같은 수식으로 직접 구현할 수 있습니다.
def softmax(logits, dim):
exp_logits = torch.exp(logits - logits.max(dim=dim, keepdim=True)[0])
return exp_logits / exp_logits.sum(dim=dim, keepdim=True)
torch.nn.LogSoftmax 사용:
torch.nn.LogSoftmax
클래스는 소프트맥스 함수를 계산하고 로그 값을 출력합니다. 다음과 같이 사용할 수 있습니다.
import torch.nn as nn
# 3개의 클래스를 가진 샘플 데이터
logits = torch.tensor([1, 2, 3])
# 마지막 차원에 로그 소프트맥스 적용
log_probs = nn.LogSoftmax(dim=0)(logits)
# 각 클래스 로그 예측 확률 출력
print(log_probs)
torch.distributions.Categorical 사용:
torch.distributions.Categorical
클래스는 카테고리 분포를 나타냅니다. 다음과 같이 사용하여 소프트맥스 함수를 계산할 수 있습니다.
import torch.distributions as dist
# 3개의 클래스를 가진 샘플 데이터
logits = torch.tensor([1, 2, 3])
# 카테고리 분포 생성
probs = dist.Categorical(logits=logits)
# 각 클래스 예측 확률 출력
print(probs.probs)
jax.nn.softmax 사용:
Jax 라이브러리를 사용하면 jax.nn.softmax
함수를 사용하여 소프트맥스 함수를 계산할 수 있습니다.
import jax.nn as nn
# 3개의 클래스를 가진 샘플 데이터
logits = jnp.array([1, 2, 3])
# 마지막 차원에 소프트맥스 적용
probs = nn.softmax(logits, axis=0)
# 각 클래스 예측 확률 출력
print(probs)
대체 방법 선택:
- 직접 구현: 가장 유연하지만, 계산 속도가 느릴 수 있습니다.
torch.nn.LogSoftmax
: 로그 소프트맥스 값을 필요로 하는 경우 유용합니다.torch.distributions.Categorical
: 확률 분포 계산에 유용합니다.jax.nn.softmax
: Jax 라이브러리를 사용하는 경우 유용합니다.
- 각 방법의 장단점을 고려하여 상황에 맞는 대체 방법을 선택해야 합니다.
python pytorch