import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import timm  # <- 1. Import knihovny timm

# Prefer the CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Trénování bude probíhat na zařízení: {device}')

# Side length (in pixels) of the square model input;
# ViT models typically expect 224x224.
IMG_SIZE = 224

# Human-readable GTSRB class names, indexed by class id (0-42).
# Kept as a module-level tuple so the table is built once instead of on
# every lookup.
_GTSRB_LABEL_NAMES = (
    "Speed limit (20km/h)", "Speed limit (30km/h)", "Speed limit (50km/h)",
    "Speed limit (60km/h)", "Speed limit (70km/h)", "Speed limit (80km/h)",
    "End of speed limit (80km/h)", "Speed limit (100km/h)", "Speed limit (120km/h)",
    "No passing", "No passing veh over 3.5 tons", "Right-of-way at intersection",
    "Priority road", "Yield", "Stop", "No vehicles", "Veh > 3.5 tons prohibited",
    "No entry", "General caution", "Dangerous curve left", "Dangerous curve right",
    "Double curve", "Bumpy road", "Slippery road", "Road narrows on the right",
    "Road work", "Traffic signals", "Pedestrians", "Children crossing",
    "Bicycles crossing", "Beware of ice/snow", "Wild animals crossing",
    "End speed + passing limits", "Turn right ahead", "Turn left ahead",
    "Ahead only", "Go straight or right", "Go straight or left", "Keep right",
    "Keep left", "Roundabout mandatory", "End of no passing",
    "End no passing veh > 3.5 tons",
)


def get_class_name(class_id: int) -> str:
    """Return the human-readable name of a GTSRB class.

    Args:
        class_id: Integer-like class index; valid values are 0..42.

    Returns:
        The label string stored at position ``class_id``.

    Raises:
        IndexError: If ``class_id`` is negative or not smaller than the
            number of known classes.
    """
    # Reject negative ids explicitly: plain sequence indexing would wrap
    # around (-1 -> last label) and silently return a wrong class name.
    if not 0 <= class_id < len(_GTSRB_LABEL_NAMES):
        raise IndexError(
            f"class_id {class_id} is out of range "
            f"(expected 0..{len(_GTSRB_LABEL_NAMES) - 1})"
        )
    return _GTSRB_LABEL_NAMES[class_id]


# Data transforms using the new image size: convert to a tensor, resize to
# IMG_SIZE x IMG_SIZE, then normalize.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),
    # Apply the same normalization that was used when training on ImageNet.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset (train and test splits, stored under ./data).
trainset = torchvision.datasets.GTSRB(root='./data', split="train", download=True, transform=transform)
testset = torchvision.datasets.GTSRB(root='./data', split="test", download=True, transform=transform)

# Mini-batch size; can be increased if your GPU memory allows it.
BATCH_SIZE = 8
# Backward-compatible alias for the original, misspelled constant name.
BATH_SIZE = BATCH_SIZE

# NOTE(review): the loaders use worker processes (num_workers=2) and are
# created at module level. On platforms that start workers with "spawn"
# (Windows, macOS) this requires the script body to run under an
# `if __name__ == '__main__':` guard -- confirm the target platform.
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)


# Model from timm: use the weights pretrained on ImageNet; with
# num_classes=43 timm automatically replaces the last layer with a new one
# that has 43 outputs.
model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=43)
# Move the model to the selected device (the GPU when available).
model = model.to(device)


# Cross-entropy loss, optimized with SGD plus momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

print("Zahajuji trénování...")
for epoch in range(6):
    # Loss accumulated since the last report.
    running_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader, start=1):
        # Move the batch to the same device as the model.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear old gradients, run forward + backward, then update the weights.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Print the average loss every 100 mini-batches.
        if step % 100 == 0:
            print(f'[{epoch + 1}, {step:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Konec trénování')

# Evaluate the model on the test split.
correct = 0
total = 0
model.eval()  # Switch the model to evaluation mode
with torch.no_grad():  # Gradients are not needed for inference
    for data in testloader:
        # Move the data to the model's device during evaluation as well.
        images, labels = data[0].to(device), data[1].to(device)

        outputs = model(images)
        # Index of the highest-scoring class for every sample in the batch.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division keeps the fractional part; integer floor division would
# truncate e.g. 98.9 % down to 98 %.
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy:.2f} %')


