Created
June 26, 2025 16:24
-
-
Save MrRjxrby/b4580c44c611d481809b6b429663ca76 to your computer and use it in GitHub Desktop.
GAN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras import layers | |
# Генератор | |
def build_generator(): | |
model = tf.keras.Sequential() | |
model.add(layers.Dense(128, input_dim=100, activation='relu')) | |
model.add(layers.Dense(784, activation='sigmoid')) | |
return model | |
# Дискриминатор | |
def build_discriminator(): | |
model = tf.keras.Sequential() | |
model.add(layers.Dense(128, input_dim=784, activation='relu')) | |
model.add(layers.Dense(1, activation='sigmoid')) | |
return model | |
# Построение моделей | |
generator = build_generator() | |
discriminator = build_discriminator() | |
discriminator.compile(loss='binary_crossentropy', optimizer='adam') | |
# GAN (генератор + дискриминатор) | |
discriminator.trainable = False | |
gan_input = layers.Input(shape=(100,)) | |
gan_output = discriminator(generator(gan_input)) | |
gan = tf.keras.Model(gan_input, gan_output) | |
gan.compile(loss='binary_crossentropy', optimizer='adam') | |
# Тренировка GAN | |
import numpy as np | |
def train_gan(gan, generator, discriminator, epochs, batch_size=128): | |
for epoch in range(epochs): | |
# Генерация рандомного шума | |
noise = np.random.normal(0, 1, (batch_size, 100)) | |
generated_data = generator.predict(noise) | |
# Даем настоящие данные | |
real_data = np.random.rand(batch_size, 784) | |
# Тренировка дискриминатора | |
combined_data = np.concatenate([real_data, generated_data]) | |
labels = np.concatenate([np.ones(batch_size), np.zeros(batch_size)]) | |
discriminator.train_on_batch(combined_data, labels) | |
# Тренировка генератора | |
noise = np.random.normal(0, 1, (batch_size, 100)) | |
misleading_labels = np.ones(batch_size) | |
gan.train_on_batch(noise, misleading_labels) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment