Simple Self-Attention from Scratch

This vignette describes how to implement the attention mechanism, which forms the basis of transformers, in the R language.

We begin by generating encoder representations of four different words.

# encoder representations of four different words
word_1 = matrix(c(1,0,0), nrow=1)
word_2 = matrix(c(0,1,0), nrow=1)
word_3 = matrix(c(1,1,0), nrow=1)
word_4 = matrix(c(0,0,1), nrow=1)

Next, we stack the word embeddings into a single array (in this case a matrix), which we call words.

# stacking the word embeddings into a single array
words = rbind(word_1,
              word_2,
              word_3,
              word_4)

Let’s see what this looks like.

print(words)
#>      [,1] [,2] [,3]
#> [1,]    1    0    0
#> [2,]    0    1    0
#> [3,]    1    1    0
#> [4,]    0    0    1
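
Each row of this matrix is the embedding of one word. As a quick sanity check (an addition to the original code), we can confirm that we have 4 words with 3 features each.

# sanity check: 4 words (rows), each with a 3-dimensional embedding
dim(words)
#> [1] 4 3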

Next, we initialize the weight matrices with random integer values. Since runif() draws from the half-open interval [0, 3) and floor() truncates, each entry is 0, 1, or 2.

# initializing the weight matrices (with random values)
set.seed(0)
W_Q = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_K = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
W_V = matrix(floor(runif(9, min=0, max=3)),nrow=3,ncol=3)
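
To see that floor(runif()) on [0, 3) only ever produces the values 0, 1, and 2, we can tabulate a larger sample (an illustrative check, separate from the seeded matrices above):

# illustrative check (not part of the workflow): only the values 0, 1 and 2 appear
table(floor(runif(1000, min=0, max=3)))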

Next, we generate the Queries (Q), Keys (K), and Values (V). The %*% operator performs the matrix multiplication. You can view the R help page using help('%*%') (or consult the online manual An Introduction to R).

# generating the queries, keys and values
Q = words %*% W_Q
K = words %*% W_K
V = words %*% W_V
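
Because each row of words is one word embedding, each row of Q is simply that word's embedding multiplied by W_Q. As an illustrative check (an addition, not part of the original vignette):

# the query for word 1 equals its embedding times the query weight matrix
all.equal(Q[1, ], drop(word_1 %*% W_Q))
#> [1] TRUE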

Following this, we score the Queries (Q) against the Key (K) vectors, which are transposed for the multiplication using t() (see help('t') for more info).

# scoring the query vectors against all key vectors
scores = Q %*% t(K)
print(scores)
#>      [,1] [,2] [,3] [,4]
#> [1,]    6    4   10    5
#> [2,]    4    6   10    6
#> [3,]   10   10   20   11
#> [4,]    3    1    4    2
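
Each entry of scores is the dot product of one query vector with one key vector. For example, scores[1, 3] (the largest score in the first row) is the dot product of the first query and the third key; the check below is an added illustration:

# score of query 1 against key 3 is their dot product
sum(Q[1, ] * K[3, ])
#> [1] 10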

We now compute the weights matrix from these scores, using the ComputeWeights() function from the attention package.

weights = attention::ComputeWeights(scores)

Let’s have a look at the weights matrix.

print(weights)
#>            [,1]       [,2]     [,3]       [,4]
#> [1,] -0.2986355 -2.6877197 4.479533 -1.4931776
#> [2,] -3.1208558 -0.6241712 4.369198 -0.6241712
#> [3,] -1.7790165 -1.7790165 4.690134 -1.1321014
#> [4,]  1.2167336 -3.6502008 3.650201 -1.2167336
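
These weights are not softmax probabilities (they can be negative and do not lie in (0, 1)), so ComputeWeights() evidently uses a different normalization. For comparison, the scaled dot-product attention of the original transformer paper divides the scores by the square root of the key dimension and applies a row-wise softmax; the following is a minimal sketch of that textbook formulation (an illustration only, not a reimplementation of ComputeWeights()):

# textbook scaled dot-product weights, for comparison only:
# scale by the square root of the key dimension, then apply a row-wise softmax
scaled = scores / sqrt(ncol(K))
softmax_weights = exp(scaled) / rowSums(exp(scaled))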

Finally, we compute the attention as a weighted sum of the value vectors (which are combined in the matrix V).

# computing the attention by a weighted sum of the value vectors
attention = weights %*% V

Now we can view the results using:

print(attention)
#>          [,1]     [,2]       [,3]
#> [1,] 7.167252 6.868617 -1.4931776
#> [2,] 4.993369 1.872514 -0.6241712
#> [3,] 6.469151 4.690134 -1.1321014
#> [4,] 7.300402 8.517135 -1.2167336
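
As a final illustrative check (again an addition to the original code), the first row of attention is exactly the weighted sum of the rows of V using the first row of weights:

# attention for word 1 = weighted sum of the value vectors
all.equal(attention[1, ], drop(weights[1, ] %*% V))
#> [1] TRUE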