import torch
import matplotlib.pyplot as plt
import numpy as np
def generate_data(N):
    # Each sample is a sequence of 10 random digits 0-9; the label is 1 when
    # the sequence contains more 4s than 2s (ties count as 0).
    # Note: torch.randint's upper bound is exclusive, so 10 covers 0-9.
    X = torch.randint(0, 10, size=(N, 10))
    num2s = torch.count_nonzero(X == 2, dim=-1)
    num4s = torch.count_nonzero(X == 4, dim=-1)
    labels = num4s > num2s
    return X, labels.reshape(-1, 1).float()
X, y = generate_data(123)
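# Quick sanity check (an added sketch): by symmetry P(more 4s) == P(more 2s),
# and ties are labeled 0, so the positive rate sits somewhat below 0.5.
assert X.shape == (123, 10) and y.shape == (123, 1)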
class AttentionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A single learned query vector attends over all 10 positions.
        self.query = torch.nn.Parameter(torch.randn(1, 32))
        # One 16-dim embedding per digit 0-9.
        self.embed_func = torch.nn.Embedding(10, embedding_dim=16)
        # Keys share the query's 32-dim space.
        self.key_func = torch.nn.Linear(16, 32)
        # Each token emits a scalar value ("message").
        self.value_func = torch.nn.Sequential(
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )
        # Maps the pooled scalar summary to a probability.
        self.head_mlp = torch.nn.Sequential(
            torch.nn.Linear(1, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
            torch.nn.Sigmoid(),
        )
    def forward(self, X):
        embedX = self.embed_func(X)                    # [B, 10, 16]
        keys = self.key_func(embedX)                   # [B, 10, 32]
        # Scaled dot-product attention: [1, 32] x [B, 10, 32] -> [B, 1, 10]
        qk = torch.einsum('ie, bje -> bij', self.query, keys)
        qk = qk / (32 ** 0.5)
        att = torch.nn.functional.softmax(qk, dim=-1)  # [B, 1, 10]
        values = self.value_func(embedX)               # [B, 10, 1]
        # Attention-weighted sum of the per-token values -> [B, 1]
        summary = torch.einsum('bij, bje -> bie', att, values)[:, 0, :]
        pred = self.head_mlp(summary)                  # [B, 1]
        return pred, att, values
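# Shape check (an added sketch; `check_shapes` is a hypothetical helper):
# single-query attention yields [B, 1, 10] weights, [B, 10, 1] values,
# and [B, 1] predictions. Call check_shapes() to verify.
def check_shapes(batch=X):
    p, a, v = AttentionModel()(batch)
    assert a.shape == (batch.shape[0], 1, 10)
    assert v.shape == (batch.shape[0], 10, 1)
    assert p.shape == (batch.shape[0], 1)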
def train():
    # Trains on the module-level synthetic batch (X, y).
    model = AttentionModel()
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)
    losses = []
    for idx in range(5000):
        p, a, v = model(X)
        loss = torch.nn.functional.binary_cross_entropy(p, y)
        losses.append(float(loss))
        if idx % 100 == 0:
            print(idx, float(loss))
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Plot the full loss curve once training is done.
    plt.plot(losses)
    plt.gcf().set_size_inches(2, 2)
    plt.show()
    return model
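# Held-out accuracy (a minimal sketch; `evaluate` is a hypothetical helper):
# scores the trained model on fresh synthetic data from the same generator.
def evaluate(model, n=1000):
    Xt, yt = generate_data(n)
    with torch.no_grad():
        p, _, _ = model(Xt)
    return ((p > 0.5).float() == yt).float().mean().item()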
if __name__ == "__main__":
    model = train()
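    # Usage of the hypothetical `evaluate` sketch above:
    print(f"held-out accuracy: {evaluate(model):.3f}")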
    with torch.no_grad():
        # Probe sequence: two 2s and one 4, so the label should be 0.
        X_test = torch.LongTensor([[1, 7, 2, 0, 2, 1, 3, 4, 8, 6]])
        p, a, v = model(X_test)
        # Attention weights over the 10 positions, digits overlaid
        # (2s and 4s in red).
        plt.imshow(a[0], vmin=0, vmax=1)
        for col, d in zip(np.arange(10), X_test[0]):
            plt.text(col, 0, int(d), c='r' if int(d) in [4, 2] else 'w')
        plt.gcf().set_size_inches(10, 1)
        plt.show()
        # Per-token values ("messages"), masked to the positions that get
        # non-trivial attention; use `message = v[:, :, 0].numpy()` to show
        # all of them.
        message = np.where(a[0].numpy() > 0.1, v[:, :, 0].numpy(), np.nan)
        plt.imshow(message)
        for col, d in zip(np.arange(10), message[0]):
            plt.text(col - 0.5, 0, f'{d:.2f}', c='w')
        plt.gcf().set_size_inches(10, 1)
        plt.show()
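        # Expectation (not a measured result): a successfully trained model
        # concentrates attention on the 2s and 4s, and their surviving
        # messages should differ enough for the head MLP to compare counts.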