import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# DDPM noise schedule: T steps with linearly increasing variance beta_t.
T = 500
betas = torch.linspace(1e-4, 0.02, steps=T, device=device)
alphas = 1.0 - betas
# alpha_bar[t] = alpha_1 * ... * alpha_t, stored with a leading 1 so that
# alpha_bar[0] == 1 (clean data) and alpha_bar[t] is indexed by the timestep
# t in 1..T directly, while beta_t / alpha_t live at index t - 1.
alpha_bar = torch.cat((alphas.new_ones(1), torch.cumprod(alphas, dim=0)))
# dataset
def get_batch(n=512):
    """Draw a fresh batch of two-moons points, standardized per feature.

    Args:
        n: number of points to sample. Must be at least 2 because the batch
            is standardized with its own (unbiased) standard deviation.

    Returns:
        A float32 tensor of shape (n, 2) on ``device``. Each column is shifted
        to zero mean and scaled to unit standard deviation over this batch.

    Raises:
        ValueError: if ``n < 2``. The std of a single sample is undefined, and
            it would otherwise silently turn the whole batch into NaNs.
    """
    if n < 2:
        raise ValueError(f"get_batch needs n >= 2 to standardize, got n={n}")
    x, _ = make_moons(n_samples=n, noise=0.05)
    x = torch.tensor(x, dtype=torch.float32, device=device)
    # Standardize with this batch's own statistics for more stable training.
    return (x - x.mean(0)) / x.std(0)
# sinusoidal timestep embedding
EMB_DIM = 64
half = EMB_DIM // 2
# Geometric ladder of `half` angular frequencies, from 1 down to 1/10000.
freqs = torch.exp(-torch.arange(half, device=device) * math.log(10000) / (half - 1))


def timestep_embedding(t):
    """Encode timesteps as concatenated sine and cosine features.

    Args:
        t: 1-D tensor of timesteps, one per batch element.

    Returns:
        For ``t`` of length B, a tensor of shape (B, 2 * half). It holds
        sin(t * freqs) followed by cos(t * freqs).
    """
    phases = t.float().unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat((phases.sin(), phases.cos()), dim=-1)
# epsilon predictor
class EpsNet(nn.Module):
    """MLP that predicts the noise added to ``x`` at timestep ``t``.

    The timestep is encoded with ``timestep_embedding`` and concatenated to
    the input. Three Linear+SiLU hidden layers of width ``hidden`` feed a
    final Linear layer with ``data_dim`` outputs.
    """

    def __init__(self, data_dim=2, t_dim=EMB_DIM, hidden=128):
        super().__init__()
        widths = [data_dim + t_dim, hidden, hidden, hidden]
        layers = []
        # Same layer order as a hand-written Sequential:
        # Linear, SiLU, Linear, SiLU, Linear, SiLU, Linear.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.SiLU()]
        layers.append(nn.Linear(hidden, data_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Return the predicted noise for samples ``x`` at timesteps ``t``."""
        features = torch.cat((x, timestep_embedding(t)), dim=-1)
        return self.net(features)
# Build the noise predictor, move its parameters to the target device
# (nn.Module.to works in place), and attach an Adam optimizer.
model = EpsNet()
model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# forward process
def q_sample(x0, t, eps=None):
    """Sample x_t ~ q(x_t | x_0) in closed form.

        x_t = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * eps

    Args:
        x0: clean data with the batch on dim 0 and any number of trailing dims.
        t: integer tensor of shape (batch,) holding timesteps in 0..T.
            alpha_bar[0] == 1, so t = 0 returns x0 unchanged.
        eps: optional noise shaped like ``x0``. Drawn from N(0, I) when None.

    Returns:
        The noised samples, with the same shape as ``x0``.
    """
    if eps is None:
        eps = torch.randn_like(x0)
    # Reshape the per-sample coefficient to (batch, 1, ..., 1) so it broadcasts
    # over every non-batch dim of x0. A fixed [:, None] is only correct for
    # 2-D x0; with 1-D x0 it would silently broadcast to (batch, batch).
    ab = alpha_bar[t].reshape(-1, *([1] * (x0.dim() - 1)))
    return ab.sqrt() * x0 + (1 - ab).sqrt() * eps
# training
# Fit the noise predictor with the simple DDPM objective: MSE between the
# injected noise and the model's prediction at a uniformly drawn timestep.
for step in range(5000):
    x0 = get_batch()
    # randint's upper bound is exclusive, so timesteps are uniform over 1..T.
    t = torch.randint(low=1, high=T + 1, size=(x0.shape[0],), device=device)
    eps = torch.randn_like(x0)
    xt = q_sample(x0, t, eps)
    eps_pred = model(xt, t)
    loss = F.mse_loss(eps_pred, eps)

    opt.zero_grad()
    loss.backward()
    opt.step()

    if not step % 1000:
        print(f"step {step:5d} loss {loss.item():.4f}")
# visualize forward diffusion
@torch.no_grad()
def plot_forward_process(n=2000, timesteps=None):
    """Scatter-plot one data batch after forward diffusion to several timesteps.

    Args:
        n: number of points drawn with ``get_batch``.
        timesteps: timesteps in 0..T to show, one panel each. Defaults to
            0, T/20, T/10, T/5 and T (truncated to ints).
    """
    if timesteps is None:
        timesteps = [0, int(T / 20), int(T / 10), int(T / 5), int(T)]
    x0 = get_batch(n)
    # squeeze=False keeps `axes` a 2-D array even when there is only one
    # panel, so the zip below works for any number of timesteps.
    _, axes = plt.subplots(1, len(timesteps), figsize=(15, 3), squeeze=False)
    for ax, t in zip(axes[0], timesteps):
        tt = torch.full((len(x0),), t, device=device, dtype=torch.long)
        xt = q_sample(x0, tt).cpu()
        ax.scatter(xt[:, 0], xt[:, 1], s=2)
        ax.set_title(f"t={t}")
        ax.axis("equal")
        ax.axis("off")
    plt.tight_layout()
    plt.show()


plot_forward_process()
# reverse process
@torch.no_grad()
def sample(n=2000):
    """Generate ``n`` points by running the learned reverse DDPM chain.

    Starts from x_T ~ N(0, I). For t = T..1 it computes

        mean    = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_t)
        var_t   = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        x_{t-1} = mean + sqrt(var_t) * z,   z ~ N(0, I)

    No noise is added on the final step (t = 1).

    The model is put in eval mode for sampling. Its previous train/eval mode
    is restored afterwards, even if an exception is raised.

    Returns:
        CPU tensor of shape (n, 2).
    """
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, 2, device=device)
        for t in reversed(range(1, T + 1)):
            tt = torch.full((n,), t, device=device, dtype=torch.long)
            eps_pred = model(x, tt)
            # betas/alphas are 0-indexed; alpha_bar carries a leading 1, so it
            # is indexed by the timestep itself.
            beta_t = betas[t - 1]
            alpha_t = alphas[t - 1]
            ab_t = alpha_bar[t]
            ab_prev = alpha_bar[t - 1]
            mean = (x - beta_t / torch.sqrt(1 - ab_t) * eps_pred) / torch.sqrt(alpha_t)
            if t > 1:
                posterior_var = (1 - ab_prev) / (1 - ab_t) * beta_t
                x = mean + torch.sqrt(posterior_var) * torch.randn_like(x)
            else:
                x = mean
    finally:
        # Restore the caller's mode instead of unconditionally switching to train.
        model.train(was_training)
    return x.cpu()
# generate samples
# Draw samples from the trained model and show them as a scatter plot.
samples = sample()
plt.figure(figsize=(6, 6))
# samples is (n, 2): transposing and unpacking yields the x and y columns.
plt.scatter(*samples.T, s=4)
plt.title("DDPM samples")
plt.axis("equal")
plt.show()