import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from skimage.util.shape import view_as_blocks

import torch
import torch.nn as nn

# Demonstrace první části ViT

########################################################################
# 1. Načtení jednoho obrázku (H, W, C) a rozdělení na bloky (patches)

img = io.imread("dog.jpg")      # očekává se cca 224x224x3
H, W, C = img.shape
P = 16                              # velikost patche

# view_as_blocks: (H, W, C) -> (H//P, W//P, P, P, C)
blocks_with_extra_dim = view_as_blocks(img, block_shape=(P, P, C))
# Odstranění přebytečné dimenze (osy 2)
blocks = blocks_with_extra_dim.squeeze(axis=2)
print(f"Shape po opravě (squeeze): {blocks.shape}") # Očekáváme (14, 14, 16, 16, 3)


Hn, Wn, Ph, Pw, C = blocks.shape
all_blocks_for_viz = blocks.reshape(-1, Ph, Pw, C) # (196, 16, 16, 3)

def show_all_patches(patches, patches_per_row):
    """
    Zobrazí všechny patche v mřížce, která odpovídá jejich pozici v obrázku.
    """
    n_patches = len(patches)
    # Počet řádků a sloupců mřížky pro vizualizaci
    cols = patches_per_row
    rows = int(np.ceil(n_patches / cols))

    # Zvětšíme obrázek, aby se vešlo 14x14 malých obrázků
    plt.figure(figsize=(10, 10))

    for i in range(n_patches):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(patches[i])
        plt.axis('off')

    plt.suptitle("Všech 196 patchů z obrázku", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()
    
show_all_patches(all_blocks_for_viz, patches_per_row=14)  

########################################################################
# 2. Další krok je vzít každý ten patch o rozměru (16, 16, 3) a "narovnat" ho do jednoho dlouhého vektoru - mechanické zploštění (flattening).

# Počet patchů
num_patches = Hn * Wn  # 196

# Dimenze jednoho zploštěného patch vektoru
patch_dim = Ph * Pw * C # 16 * 16 * 3 = 768

# Tvar se změní z (14, 14, 16, 16, 3) na (196, 768)
patch_vectors = blocks.reshape(num_patches, patch_dim)

print(f"Tvar výsledného pole patch vektorů: {patch_vectors.shape}")
print(f"Dimenze jednoho vektoru: {patch_vectors.shape[1]}")

patch_vectors_normalized = patch_vectors.astype(np.float32) / 255.0

########################################################################
# 3. Lineární projekce (nn.Linear)
#    Dalším krokem bude tyto vektory o délce 768 poslat do lineární vrstvy (projekce), 
#    aby se z nich staly finální tokeny (embeddingy) o dimenzi D, což je také obvykle 768.
#    Vstup: Surový vektor o délce 768 z předchozí fáze.
#    Výstup: "Embedding" vektor, typicky také o délce 768 (v případě ViT-Base). Tento nový vektor už je "chytřejší" reprezentací patche.
#    Výstup z nn.Linear by mohl být i např. 192 (embedding dimenze pro ViT-Tiny)

#   Poznámka: pro rozdělení obrazu do bloků (patches) a vytvoření "Embedding" vektoru je možné také použít nn.Conv2d vrstvu

# Pro ViT potřebujeme batch dimenzi, takže přidáme `unsqueeze(0)`
patch_vectors_tensor = torch.from_numpy(patch_vectors_normalized).unsqueeze(0)
print(f"Tvar tensoru před lineární vrstvou: {patch_vectors_tensor.shape}") # (1, 196, 768)

# Vytvoříme lineární vrstvu
# nn.Linear(vstupní_dimenze, výstupní_dimenze) 768x768
embedding_dim = 768
input_dim = patch_vectors_tensor.shape[-1]
linear_projection = nn.Linear(input_dim, embedding_dim)

# Aplikujeme lineární projekci na naše patch vektory
# Tím získáme finální patch embeddings (tokeny)
patch_embeddings = linear_projection(patch_vectors_tensor)

print(f"Tvar patch embeddings po lineární vrstvě: {patch_embeddings.shape}") # (1, 196, 768)


########################################################################
# 4. K tomu, abychom z našich patch_embeddings vytvořili finální vstupní sekvenci z0 nám chybí CLS (Classification) Token a Position (Poziční) Embedding
# CLS (Classification) Token: Speciální, naučitelný vektor, který se přidá na začátek sekvence. Jeho úkolem je v průběhu průchodu sítí "nasbírat" globální informaci o celém obrázku. Pro finální klasifikaci se použije pouze tento jeden token.
# Position (Poziční) Embedding: Sada naučitelných vektorů, které se přičtou ke každému tokenu v sekvenci (včetně CLS tokenu). Protože mechanismus self-attention sám o sobě nerozlišuje pořadí tokenů, tyto embeddingy dodají modelu klíčovou informaci o tom, kde se který patch v původním obrázku nacházel

# Počet patchů a dimenze embeddingu
num_patches = patch_embeddings.shape[1]  # 196
embedding_dim = patch_embeddings.shape[2] # 768
batch_size = patch_embeddings.shape[0]   # 1

# CLS Token - je to naučitelný parametr. Vytvoříme ho s tvarem (1, 1, embedding_dim).
cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
print(f"Tvar CLS tokenu: {cls_token.shape}")

# Přidání CLS tokenu na začátek sekvence
cls_token_expanded = cls_token.expand(batch_size, -1, -1)
tokens_with_cls = torch.cat((cls_token_expanded, patch_embeddings), dim=1)
print(f"Tvar sekvence po přidání CLS tokenu: {tokens_with_cls.shape}") # (1, 197, 768)

# Vytvoření Pozičních Embeddingů
# Potřebujeme jeden embedding pro každý token v sekvenci (CLS + patche).
# Celkem tedy 1 + 196 = 197 tokenů.
# Opět je to naučitelný parametr.

positional_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
print(f"Tvar pozičních embeddingů: {positional_embeddings.shape}") # (1, 197, 768)

# 4. Přičtení pozičních embeddingů k tokenům
# Tím získáme finální z_0
z0 = tokens_with_cls + positional_embeddings
print(f"Finální tvar z0 (vstup do Transformeru): {z0.shape}") # (1, 197, 768)
print(f"Nyní je tensor z0 připravený ke vstupu do první vrstvy Transformer Encoderu")
