.box_article .article_cont p code
StyleCLIP 구현하기!
Reference:
오늘 스터디 목표: 아래 포스팅도 함께 참고하여, Pytorch로 구현된 딥러닝 모델 구조를 이해해보고, github에 있는 딥러닝 모델 오픈소스를 활용하기 위한 간단한 테크닉들을 익혀보자!
텍스트 명령어를 이용해 이미지를 원하는 방식으로 수정할 수 있는 모델이다.
고해상도 이미지를 생성하기 위한 효율적인 아키텍처
: 주어진 이미지 x로부터 그와 매칭되는 적절한 latent vector인 w를 찾는다.
latent vector w는 GAN 네트워크에 forwarding 됨으로써 특정한 이미지를 만들어낼 수 있다. 이렇게 만들어진 이미지와 원본 이미지 x 사이에서 similarity를 계산한다. 이렇게 계산된 similarity loss를 이용해서 backporpagation을 진행하고 gradient를 구해서 latent vector를 업데이트하는 방식을 반복한다.
1) Latent vector dataset generation
latent vector를 랜덤하게 샘플링해서 아주 많은 양의 latent vectors와 image pair를 먼저 준비한다. 즉, 어떤 latent vector를 넣었을 때 어떤 이미지가 나올 것이다라는 것을 각각 가상으로 묶어서 데이터 셋을 만든 뒤에 이러한 데이터 셋을 이용해서 인코더 네트워크를 학습한다.
2) Training an encoder network
학습을 하는 데 매우 오래 걸리지만, 학습이 된 이후에는 특정한 이미지가 들어왔을 때 그 이미지에 대한 적절한 latent vector를 찾을 수 있다.
Interpolation은 a라는 포인트에서 b라는 포인트까지 서서히 이동을 시켜보는 것이다.
1) Learns a boundary of an attribute (such a gender, age)
GAN의 latent space에서 특정한 attribute에 대한 decision boundary를 학습한다. 위 예시에서는 남자와 여자를 분류하는 decision boundary를 학습하고 있다.
decision boundary를 학습한 이후에는 두 개의 attribute를 가로지르는 하나의 direction vector를 찾는다. 이 direction vector를 따라 주어진 latent vector를 업데이트하여 정해진 attribute를 바꿀 수 있다.
2) Update a latent vector across the boundary
위에서 남성과 여성을 구분하는 classification boundary를 학습했기 때문에 이 boundary를 따라 이동을 하면 특정 이미지를 여자에서 남자로 바꿀 수 있다.
image encoder와 text encdoer를 jointly train시킨 네트워크
이미지와 텍스트를 같은 space에 인코딩 시킬 수 있다.
CLIP은 매우 큰 크기의 데이터 셋을 이용해서 두 개의 인코더를 학습했고, 특정한 이미지와 텍스트 사이에서 similarity를 구하기 위해 효과적으로 사용할 수 있다. 즉, 어떤 이미지와 어떤 텍스트가 얼마나 닮아있고, 얼마나 유사한 시멘틱한 의미를 가지고 있는지 판단할 수 있다.
어떤 텍스트를 넣었을 때, 그 텍스트에 맞게 이미지를 바꿀 수 있다.
예를 들어, text를 "without makeup"로 줬을 때, 화장한 이미지가 화장기가 없는 이미지와 유사해지는 방향으로 embedding vector가 업데이트 되도록 latent vector가 optimize된다.
Latent optimization: a simple approach for leveraging CLIP to guide image manipulation.
latent vector를 업데이트해서 특정한 인풋 이미지와 유사한 형태를 갖도록 만들면서 그와 동시에 우리가 넣은 텍스트 프롬프트와 유사한 시멘틱 정보를 갖는 이미지를 얻을 수 있는 방향으로 latent vector를 업데이트한다.
Latent mapper is trained to manipulate the desired attributes of the image as indicated by the text prompt t, while preserving the other visual attributes of the input image.
이와 같은 인코더 네트워크를 학습하는 데는 열 몇 시간 정도 소요될 수 있다. 그러한 한 번 학습이 되고 나면 특정한 이미지에 대해 결과를 한 번의 forward로 얻을 수 있다.
Global Direction은 말 그대로 글로벌하게 사용할 수 있는 방향이기 때문에 어떠한 입력이 들어와도 사용할 수 있다.
먼저 Curly Hair, Pale, Hi-top Fade, Make up과 같은 다양한 특징에 대한 Global direction을 하나씩 찾은 다음, 특정한 이미지 s가 들어왔을 때, 단순히 s를 그 방향대로 이동한 시키면 원하는 attribute만 바뀐다.
이러한 방향들을 미리 많이 찾아 놓기만 하면 나중에 어떤 이미지가 들어와도 곧바로 그 방향대로 이동시키면 이미지를 변화시킬 수 있다.
* Colab 환경에서 [런타임 유형]을 GPU로 바꾸고 실행해주세요!
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!git clone https://github.com/ndb796/StyleCLIP-Tutorial
%cd StyleCLIP-Tutorial
위 코드를 실행하면 아래 Github 저장소를
내 컴퓨터(혹은 Colab 환경)에 복제하여 가져올 수 있다.
!wget https://postechackr-my.sharepoint.com/:u:/g/personal/dongbinna_postech_ac_kr/EVv6yusEt1tFhrL3TCu0Ta4BlpzW3eBMTS0yTPKodNHsNA?download=1 -O stylegan2-ffhq-config-f.pt
import torch
from stylegan2.model import Generator
g_ema = Generator(1024, 512, 8)
g_ema.load_state_dict(torch.load('stylegan2-ffhq-config-f.pt')["g_ema"], strict=False)
g_ema.eval()
g_ema = g_ema.cuda()
모델 가중치를 불러오기 위해서는, 먼저 모델의 인스턴스(instance)를 생성한 다음에 load_state_dict() 메서드를 사용해서 매개변수들을 불러온다.
학습할 때만 사용하는 개념인 Dropout이나 Batchnorm 등을 비활성화시킨다. 즉, evaluation 과정에서 사용하지 않아야 하는 layer들을 알아서 off 시켜주는 함수이다.
import clip
class CLIPLoss(torch.nn.Module):
def __init__(self):
super(CLIPLoss, self).__init__()
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
self.upsample = torch.nn.Upsample(scale_factor=7)
self.avg_pool = torch.nn.AvgPool2d(kernel_size=32)
def forward(self, image, text):
image = self.avg_pool(self.upsample(image))
similarity = 1 - self.model(image, text)[0] / 100
return similarity
CLIP 네트워크에 포함되어 있는 Text Encoder와 Image Encoder, 그리고 StyleGAN을 이용해서 이미지 manipulation을 진행한다. latent vector w를 업데이트 할 때, latent vector w로 만들어진 이미지 임베딩이 입력한 text prompt와의 similarity가 높아지는 방향으로 업데이트 한다.
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
mean_latent = g_ema.mean_latent(4096)
latent_code_init_not_trunc = torch.randn(1, 512).cuda()
with torch.no_grad():
img_orig, latent_code_init = g_ema([latent_code_init_not_trunc], return_latents=True,
truncation=0.7, truncation_latent=mean_latent)
# Visualize a random latent vector.
image = ToPILImage()(make_grid(img_orig.detach().cpu(), normalize=True, scale_each=True, range=(-1, 1), padding=0))
h, w = image.size
image.resize((h // 2, w // 2))
Output:
from argparse import Namespace
args = Namespace()
args.description = 'A really sad face'
args.lr_rampup = 0.05
args.lr = 0.1
args.step = 150
args.l2_lambda = 0.005 # The weight for similarity to the original image.
args.save_intermediate_image_every = 1
args.results_dir = 'results'
import os
import math
import torchvision
from torch import optim
# The learning rate adjustment function.
def get_lr(t, initial_lr, rampdown=0.50, rampup=0.05):
lr_ramp = min(1, (1 - t) / rampdown)
lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
lr_ramp = lr_ramp * min(1, t / rampup)
return initial_lr * lr_ramp
The learning rate adjustment function
text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
os.makedirs(args.results_dir, exist_ok=True)
# Initialize the latent vector to be updated.
latent = latent_code_init.detach().clone()
latent.requires_grad = True
clip_loss = CLIPLoss()
optimizer = optim.Adam([latent], lr=args.lr)
Initialize the latent vector to be updated
for i in range(args.step):
# Adjust the learning rate.
t = i / args.step
lr = get_lr(t, args.lr)
optimizer.param_groups[0]["lr"] = lr
# Generate an image using the latent vector.
img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False)
# Calculate the loss value.
c_loss = clip_loss(img_gen, text_inputs)
l2_loss = ((latent_code_init - latent) ** 2).sum()
loss = c_loss + args.l2_lambda * l2_loss
# Get gradient and update the latent vector.
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Log the current state.
print(f"lr: {lr}, loss: {loss.item():.4f}")
if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
with torch.no_grad():
img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False)
torchvision.utils.save_image(img_gen, f"results/{str(i).zfill(5)}.png", normalize=True, range=(-1, 1))
with torch.no_grad():
img_orig, _ = g_ema([latent_code_init], input_is_latent=True, randomize_noise=False)
# Display the initial image and result image.
final_result = torch.cat([img_orig, img_gen])
torchvision.utils.save_image(final_result.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), normalize=True, scale_each=True, range=(-1, 1))
result_image = ToPILImage()(make_grid(final_result.detach().cpu(), normalize=True, scale_each=True, range=(-1, 1), padding=0))
h, w = result_image.size
result_image.resize((h // 2, w // 2))
!ffmpeg -r 15 -i results/%05d.png -c:v libx264 -vf fps=25 -pix_fmt yuv420p out.mp4
from google.colab import files
files.download('out.mp4')
[코드 분석 스터디] Top 100 Korean Dramas (0) | 2021.11.23 |
---|---|
[코드 분석 스터디] Spooky Author Prediction_NLP tutorial (0) | 2021.11.14 |
[코드 분석 스터디] Natural Language Processing : Bag of Words for IMDB movie review (0) | 2021.11.13 |
[코드 분석 스터디] Segmentation : Sementic Segmentation - CARLA Image Road segmentation (1) | 2021.11.04 |
[코드 분석 스터디] Time Series Regression - Predict Future Sales 커널 필사 (2) | 2021.10.01 |
댓글 영역