코딩걸음마

[딥러닝] Pytorch 행렬/텐서의 곱셈(Matrix/tensor Multiplication) 본문

딥러닝_Pytorch

[딥러닝] Pytorch 행렬/텐서의 곱셈(Matrix/tensor Multiplication)

코딩걸음마 2022. 6. 24. 22:15
728x90

1. 행렬의 곱셈

행렬의 곱셈은 내적(inner product, Dot product)이라고한다. 딥러닝 과정에서 중요한 연산 중 하나이다.

행렬의 곱셈의 조건이 있다. 다음과 같은 두개의 행렬이 있다고 가정해보자

x = torch.size([ a , b ])

y  = torch.size([ c , d ]) 라고 할 때

b = c 조건이 성립되어야하며, 내적 결과는  torch.size([ a , d ])가 나온다.

import torch
x = torch.FloatTensor([[1, 2],
                       [3, 4],
                       [5, 6]])
y = torch.FloatTensor([[1, 2],
                       [1, 2]])

print(x.size(), y.size())
torch.Size([3, 2]) torch.Size([2, 2])

두 행렬의 크기를 파악한 결과, b와 c가 같으므로 내적연산이 가능하며, (3, 2)의 행렬이 결과로 나올 것이다.

행렬의 내적연산은 다음과 같다.

 torch.matmul(x, y)

z = torch.matmul(x, y)
z
tensor([[ 3.,  6.],
        [ 7., 14.],
        [11., 22.]])

 

 

2. 텐서의 곱셈

물론 tensor간의 곱셈도 있다. 이 또한 딥러닝 과정에서 중요한 연산 중 하나이다.

텐서의 곱셈에서도 조건이 있다. 다음과 같은 두개의 tensor가 있다고 가정해보자

x = torch.size([ N , n, h ])

y  = torch.size([ N, h , m ]) 라고 할 때

N= N 조건과 h = h 조건이 동시에 성립되어야하며,  결과 tensor는  torch.size([ N , n, m ])가 나온다.

x = torch.FloatTensor([[[1, 2],
                        [3, 4],
                        [5, 6]],
                       [[7, 8],
                        [9, 10],
                        [11, 12]],
                       [[13, 14],
                        [15, 16],
                        [17, 18]]])
y = torch.FloatTensor([[[1, 2, 2],
                        [1, 2, 2]],
                       [[1, 3, 3],
                        [1, 3, 3]],
                       [[1, 4, 4],
                        [1, 4, 4]]])

print(x.size(), y.size())
torch.Size([3, 3, 2]) torch.Size([3, 2, 3])

두 행렬의 크기를 파악한 결과, N이 3으로 같고 h가 2로 같으므로 곱셈연산이 가능하며, (3, 3, 3)의 텐서이 결과로 나올 것이다.

텐서의 곱셈연산은 다음과 같다.

 torch.bmm(x, y)

z = torch.bmm(x, y)
print(z.size())
z
tensor([[[  3.,   6.,   6.],
         [  7.,  14.,  14.],
         [ 11.,  22.,  22.]],

        [[ 15.,  45.,  45.],
         [ 19.,  57.,  57.],
         [ 23.,  69.,  69.]],

        [[ 27., 108., 108.],
         [ 31., 124., 124.],
         [ 35., 140., 140.]]])
728x90
Comments