import torch
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
import matplotlib
matplotlib.use('webagg')

# Vytvoření jednoduché konvoluční vrstvy
# 1 vstupní kanál, několik výstupních kanálů (filtrů), kernel 3x3
out_channels = 3
conv_layer = nn.Conv2d(in_channels=1, out_channels=out_channels, kernel_size=3, bias=False)

# Přístup k váhám vrstvy
print("Tvar vah:", conv_layer.weight.shape)  
# torch.Size([3, 1, 3, 3]) - [out_channels, in_channels, height, width]
print("\nVáhy (filtry):")
print(conv_layer.weight.data)

# Načtení obrazu a převod na grayscale
image = cv2.imread('09839.ppm')
image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image_gray = cv2.resize(image_gray, (256, 256))

# Použití ToTensor() - automaticky převede na (1, H, W) a normalizuje na [0, 1]
transform = transforms.ToTensor()
image_tensor = transform(image_gray)  # shape: (1, 256, 256)

# Přidání batch dimenze
image_tensor = image_tensor.unsqueeze(0)  # shape: (1, 1, 256, 256)

# Aplikace konvoluce
output = conv_layer(image_tensor)

# Jednoduchá vizualizace: vstup + 3 výstupy
fig, axes = plt.subplots(1, out_channels+1, figsize=(16, 4))

# Vstupní obraz
axes[0].imshow(image_gray, cmap='gray')
axes[0].set_title('Vstupní obraz', fontsize=14, fontweight='bold')
axes[0].axis('off')

# Výstupní feature mapy
for i in range(out_channels):
    axes[i+1].imshow(output[0, i, :, :].detach().numpy(), cmap='gray')
    axes[i+1].set_title(f"kernel-{i}", fontsize=14, fontweight='bold')
    axes[i+1].axis('off')


plt.show()
