# =========================
# GAN on MNIST — Code à trous
# =========================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import os

# --------- Hyperparams ----------
latent_dim = 100                # dimension of the noise vector z fed to the generator
hidden_dim = 256                # base width of the MLP hidden layers
img_size = 28 * 28              # MNIST images flattened to 784-dim vectors
batch_size = 128
num_epochs = 10
lr = 2e-4                       # Adam learning rate commonly used for GANs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------- Data ----------
# ToTensor yields pixels in [0, 1]; Normalize with mean=0.5, std=0.5 maps
# them to [-1, 1], matching the Tanh output range of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))   # normalize to [-1, 1]
])
train_set = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
# drop_last=True keeps every batch at full size, so label tensors always match.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)

# --------- Models ----------
class Generator(nn.Module):
    """MLP generator: latent noise z (z_dim) -> flattened image (img_dim) in [-1, 1]."""

    def __init__(self, z_dim=latent_dim, img_dim=img_size, h=hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, h),             # input is the latent vector z
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(h, h * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(h * 2, img_dim),
            nn.Tanh()                        # output in [-1, 1] to match normalized data
        )

    def forward(self, z):
        """Map a batch of latent vectors (N, z_dim) to flat images (N, img_dim)."""
        return self.net(z)

class Discriminator(nn.Module):
    """MLP discriminator: flattened image (img_dim) -> single real/fake logit."""

    def __init__(self, img_dim=img_size, h=hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, h),           # input is the flattened image
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(h, h // 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(h // 2, 1)             # raw logit (no Sigmoid; BCEWithLogitsLoss applies it)
        )

    def forward(self, x):
        """Map a batch of flat images (N, img_dim) to logits (N, 1)."""
        return self.net(x)

G = Generator().to(device)
D = Discriminator().to(device)

# --------- Loss & Optimizers ----------
# BCEWithLogitsLoss fuses Sigmoid + BCE, which is why D outputs raw logits.
criterion = nn.BCEWithLogitsLoss()
# Each network is optimized separately; beta1=0.5 is the usual GAN setting.
optim_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optim_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# --------- Utils ----------
os.makedirs("samples_gan", exist_ok=True)
# Fixed noise batch reused at every checkpoint to visualize G's progress over time.
fixed_z = torch.randn(64, latent_dim, device=device)

def save_samples(tensor, step):
    """Rescale a batch of generated images from [-1, 1] to [0, 1] and save an 8-wide grid PNG."""
    rescaled = tensor.add(1).div(2.0)
    image_grid = utils.make_grid(rescaled, nrow=8)
    utils.save_image(image_grid, f"samples_gan/step_{step:06d}.png")

# --------- Training ----------
step = 0
for epoch in range(num_epochs):
    for real, _ in train_loader:
        # Flatten images: (N, 1, 28, 28) -> (N, 784) for the MLP networks.
        real = real.view(real.size(0), -1).to(device)

        # ===== Train Discriminator =====
        # Sample noise and generate fakes; detach so no gradient flows into G here.
        z = torch.randn(real.size(0), latent_dim, device=device)
        fake = G(z).detach()
        logits_real = D(real)
        logits_fake = D(fake)

        # labels: 1 for real images, 0 for generated ones
        real_labels = torch.ones(real.size(0), 1, device=device)
        fake_labels = torch.zeros(real.size(0), 1, device=device)

        loss_D_real = criterion(logits_real, real_labels)
        loss_D_fake = criterion(logits_fake, fake_labels)
        loss_D = (loss_D_real + loss_D_fake) / 2

        optim_D.zero_grad()
        loss_D.backward()
        optim_D.step()

        # ===== Train Generator =====
        # Fresh noise; this time do NOT detach: gradients must reach G's weights.
        z = torch.randn(real.size(0), latent_dim, device=device)
        fake = G(z)
        logits_fake_for_G = D(fake)

        # Non-saturating GAN loss: G wants D to label its fakes as real (label 1).
        loss_G = criterion(logits_fake_for_G, real_labels)
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        # ---- Logging / save ----
        if step % 200 == 0:
            with torch.no_grad():
                # Supports VAE-style models exposing .decode; plain G otherwise.
                samples = G.decode if hasattr(G, "decode") else G
                imgs = samples(fixed_z).view(-1, 1, 28, 28).detach().cpu()
                save_samples(imgs, step)
            print(f"Epoch [{epoch+1}/{num_epochs}] Step {step} | D: {loss_D.item():.4f} | G: {loss_G.item():.4f}")
        step += 1

print("Training done. Check samples in folder: samples_gan")