This case study shows how to transfer the style of an image using a deep generative model. Using the third-party deep learning library PyTorch, we implement a Multi-style Generative Network (MSG-Net), a model composed of a descriptive network, a transformation network, and a loss network. The case study then shows how to define each of MSG-Net's components and assemble them into a complete model. To keep training manageable, it uses a pre-trained model to fit the content and style images and generate a new image in the given style.


Note

Deep learning case studies, including both data and code, usually involve creating and working with nested folders. At the moment, the case-study module of the 数据嗨客 platform does not support uploading nested folders, so the code for this case cannot be run online; the 数据嗨客 development team is working on this feature. The code and data files can be downloaded from the network drive: https://pan.baidu.com/s/1nvT567R password: vnvn

In this case study, we show how to perform image style transfer with PyTorch. First, we import the third-party libraries we need.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
from torchvision import transforms

from torch.optim import Adam
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.serialization import load_lua


import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from IPython.display import Image
from IPython.display import display

%matplotlib inline

Next, we need to build a neural network to carry out the style-transfer task. The case study constructs an MSG-Net (Multi-style Generative Network) model to perform the transfer. A schematic of the network is shown below:

[figure: schematic of the MSG-Net model]

The relationship between MSG-Net and GANs is described by the authors in the paper:

Relation to Generative Networks and Adversarial Training. Generative Adversarial Network (GAN), which jointly trains an adversarial generator and discriminator simultaneously, has catalyzed a surge of interest in the study of image generation. Recent work on image-to-image GAN adopts a conditional GAN to provide a general solution for some image-to-image generation problems. For those problems, it was previously hard to define a loss function. However, the style transfer problem cannot be tackled using the conditional GAN framework, due to missing ground-truth image pairs. Instead, we adopt a discriminator / loss network that minimizes the perceptual difference of synthesized images with content and style targets and provides the supervision of the generative network learning. The initial idea of employing Gram Matrix to trigger the styles synthesis is inspired by a recent work that suggests using an encoder instead of random vector in GAN framework.

Next, we construct the components of the model one at a time.

First, we define the class GramMatrix, which returns the Gram matrix of its input feature map.

In [139]:
class GramMatrix(nn.Module):
    def forward(self, y):
        # y: feature map of shape (batch, channels, height, width)
        (b, ch, h, w) = y.size()
        features = y.view(b, ch, w * h)
        features_t = features.transpose(1, 2)
        # batched product of the features with their transpose,
        # normalized by the number of elements in the layer
        gram = features.bmm(features_t) / (ch * h * w)
        return gram
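
As a quick sanity check (using a hypothetical random tensor), a feature map of shape (b, ch, h, w) should produce a batch of ch × ch Gram matrices:

In [ ]:
## sanity check with a random feature map (hypothetical shapes)
gram = GramMatrix()
features = Variable(torch.randn(4, 64, 32, 32))
print(gram(features).size())  # expected: torch.Size([4, 64, 64])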

Next, we define a custom convolution layer that applies reflection padding before the convolution.

In [140]:
class ConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out
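
Because the reflection padding is floor(kernel_size / 2), a stride-1 ConvLayer preserves the height and width of its input. A minimal check with a random tensor:

In [ ]:
## sanity check: a stride-1 ConvLayer keeps the spatial size unchanged
conv = ConvLayer(3, 64, kernel_size=7, stride=1)
x = Variable(torch.randn(1, 3, 128, 128))
print(conv(x).size())  # expected: torch.Size([1, 64, 128, 128])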

Next, we build an upsampling convolution layer, which upsamples its input and then applies a convolution. Compared with ConvTranspose2d, this gives better results.

In [141]:
class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        if upsample:
            self.upsample_layer = torch.nn.Upsample(scale_factor=upsample)
        self.reflection_padding = int(np.floor(kernel_size / 2))
        if self.reflection_padding != 0:
            self.reflection_pad = nn.ReflectionPad2d(self.reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        if self.upsample:
            x = self.upsample_layer(x)
        if self.reflection_padding != 0:
            x = self.reflection_pad(x)
        out = self.conv2d(x)
        return out
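
With upsample=2, the layer doubles the spatial resolution before convolving. A minimal check with a random tensor:

In [ ]:
## sanity check: UpsampleConvLayer doubles the spatial resolution
up = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
x = Variable(torch.randn(1, 64, 16, 16))
print(up(x).size())  # expected: torch.Size([1, 32, 32, 32])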

Next, we build a pre-activation residual block, following the identity-mapping design for deep residual networks.

In [142]:
class Bottleneck(nn.Module):
    
    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=nn.BatchNorm2d):
        super(Bottleneck, self).__init__()
        self.expansion = 4
        self.downsample = downsample
        if self.downsample is not None:
            self.residual_layer = nn.Conv2d(inplanes, planes * self.expansion,
                                                        kernel_size=1, stride=stride)
        conv_block = []
        conv_block += [norm_layer(inplanes),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]
        conv_block += [norm_layer(planes),
                                    nn.ReLU(inplace=True),
                                    ConvLayer(planes, planes, kernel_size=3, stride=stride)]
        conv_block += [norm_layer(planes),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1)]
        self.conv_block = nn.Sequential(*conv_block)
        
    def forward(self, x):
        if self.downsample is not None:
            residual = self.residual_layer(x)
        else:
            residual = x
        return residual + self.conv_block(x)
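
When downsample is set, the block reduces the spatial size by the stride and expands the channel count by self.expansion (a factor of 4). A minimal check with a random tensor:

In [ ]:
## sanity check: a downsampling Bottleneck halves the spatial size
## and multiplies the channels by 4
block = Bottleneck(64, 32, stride=2, downsample=1, norm_layer=nn.InstanceNorm2d)
x = Variable(torch.randn(1, 64, 64, 64))
print(block(x).size())  # expected: torch.Size([1, 128, 32, 32])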

Next, we build an upsampling residual block, which allows the model to be made deeper and to converge faster.

In [143]:
class UpBottleneck(nn.Module):
 
    def __init__(self, inplanes, planes, stride=2, norm_layer=nn.BatchNorm2d):
        super(UpBottleneck, self).__init__()
        self.expansion = 4
        self.residual_layer = UpsampleConvLayer(inplanes, planes * self.expansion,
                                                      kernel_size=1, stride=1, upsample=stride)
        conv_block = []
        conv_block += [norm_layer(inplanes),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)]
        conv_block += [norm_layer(planes),
                                    nn.ReLU(inplace=True),
                                    UpsampleConvLayer(planes, planes, kernel_size=3, stride=1, upsample=stride)]
        conv_block += [norm_layer(planes),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1)]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return  self.residual_layer(x) + self.conv_block(x)
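
The upsampling counterpart doubles the spatial size while likewise expanding the channels by a factor of 4. A minimal check:

In [ ]:
## sanity check: an UpBottleneck with stride=2 doubles the spatial size
upblock = UpBottleneck(256, 32, stride=2, norm_layer=nn.InstanceNorm2d)
x = Variable(torch.randn(1, 256, 16, 16))
print(upblock(x).size())  # expected: torch.Size([1, 128, 32, 32])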

Next, we construct the Inspiration layer, which captures the feature statistics of a given style image.

In [144]:
class Inspiration(nn.Module):
   
    def __init__(self, C, B=1):
        super(Inspiration, self).__init__()
        # B is equal to 1 or input mini_batch
        self.weight = nn.Parameter(torch.Tensor(1,C,C), requires_grad=True)
        # non-parameter buffer
        self.G = Variable(torch.Tensor(B,C,C), requires_grad=True)
        self.C = C
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.uniform_(0.0, 0.02)

    def setTarget(self, target):
        self.G = target

    def forward(self, X):
        # input X is a 3D feature map
        self.P = torch.bmm(self.weight.expand_as(self.G),self.G)
        return torch.bmm(self.P.transpose(1,2).expand(X.size(0), self.C, self.C), X.view(X.size(0),X.size(1),-1)).view_as(X)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'N x ' + str(self.C) + ')'
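
The intended usage, sketched here with random tensors, is to pass the Gram matrix of the style features to setTarget; forward then matches any content feature map with the same channel count against that target:

In [ ]:
## usage sketch with random tensors (hypothetical shapes)
ins = Inspiration(C=256)
style_gram = Variable(torch.randn(1, 256, 256))  # e.g. GramMatrix() of style features
ins.setTarget(style_gram)
content_features = Variable(torch.randn(1, 256, 32, 32))
print(ins(content_features).size())  # expected: torch.Size([1, 256, 32, 32])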

Having built the components the model needs, we can now assemble the MSG-Net model itself. MSG-Net uses a pre-trained 16-layer VGG network as the descriptive network. The transformation network is constructed as shown in the figure below:

[figure: structure of the transformation network]

In [145]:
'''
Construct the neural network that transfers an image into a different style.
'''

class Net(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.InstanceNorm2d, n_blocks=6, gpu_ids=[]):
        super(Net, self).__init__()
        self.gpu_ids = gpu_ids
        self.gram = GramMatrix()

        block = Bottleneck
        upblock = UpBottleneck
        expansion = 4

        model1 = []
        model1 += [ConvLayer(input_nc, 64, kernel_size=7, stride=1),
                            norm_layer(64),
                            nn.ReLU(inplace=True),
                            block(64, 32, 2, 1, norm_layer),
                            block(32*expansion, ngf, 2, 1, norm_layer)]
        self.model1 = nn.Sequential(*model1)

        model = []
        self.ins = Inspiration(ngf*expansion)
        model += [self.model1]
        model += [self.ins]    

        for i in range(n_blocks):
            model += [block(ngf*expansion, ngf, 1, None, norm_layer)]
        
        model += [upblock(ngf*expansion, 32, 2, norm_layer),
                            upblock(32*expansion, 16, 2, norm_layer),
                            norm_layer(16*expansion),
                            nn.ReLU(inplace=True),
                            ConvLayer(16*expansion, output_nc, kernel_size=7, stride=1)]

        self.model = nn.Sequential(*model)

    def setTarget(self, Xs):
        F = self.model1(Xs)
        G = self.gram(F)
        self.ins.setTarget(G)

    def forward(self, input):
        return self.model(input)
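
As a quick smoke test with random tensors (hypothetical shapes), we can set a style target and then run a content batch through the transformation network; the output has the same shape as the input:

In [ ]:
## smoke test: set a style target, then transform a content batch
net = Net()
style = Variable(torch.randn(1, 3, 256, 256))
content = Variable(torch.randn(1, 3, 256, 256))
net.setTarget(style)
print(net(content).size())  # expected: torch.Size([1, 3, 256, 256])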

After building the network, we could train the model ourselves. To save time while still demonstrating the effect of style transfer, this case study instead uses the pre-trained model 21styles.model to learn the content of content_img and the style of style_img, and generates a new image, test.jpg.

We first need to define a few helper functions for reading, writing, and processing images on disk and in memory.

In [176]:
## Reverse the channel order of a batch of images (RGB <-> BGR)
def preprocess_batch(batch):
    batch = batch.transpose(0, 1)
    (r, g, b) = torch.chunk(batch, 3)
    batch = torch.cat((b, g, r))
    batch = batch.transpose(0, 1)
    return batch

## Load an image from disk as a float tensor
def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):
    from PIL import Image
    img = Image.open(filename).convert('RGB')
    if size is not None:
        if keep_asp:
            size2 = int(size * 1.0 / img.size[0] * img.size[1])
            img = img.resize((size, size2), Image.ANTIALIAS)
        else:
            img = img.resize((size, size), Image.ANTIALIAS)

    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)
    img = np.array(img).transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    return img

## Save a tensor to disk as an image file
def tensor_save_rgbimage(tensor, filename, cuda=False):
    from PIL import Image  # use PIL's Image, not IPython.display.Image
    if cuda:
        img = tensor.clone().cpu().clamp(0, 255).numpy()
    else:
        img = tensor.clone().clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype('uint8')
    img = Image.fromarray(img)
    img.save(filename)
    

## Display one image, or a list of images side by side
def plot_fig(source, size):
    
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    
    if isinstance(source, list):
        
        len_list = len(source)
        col = 2
        row = len_list // 2
        
        plt.figure(figsize=size)

        for index in range(len_list):
    
            plt.subplot(row, col, index + 1)
            plt.axis('off')
            plt.imshow(mpimg.imread(source[index]))
    
    else:
        plt.figure(figsize=size)
        plt.axis('off')
        plt.imshow(mpimg.imread(source))
    
    plt.show()
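
Note that preprocess_batch simply reverses the channel order of a batch (RGB to BGR and back, presumably the order the pre-trained weights expect), so applying it twice restores the original tensor:

In [ ]:
## sanity check: the channel swap is its own inverse
batch = torch.randn(2, 3, 4, 4)
print(torch.equal(preprocess_batch(preprocess_batch(batch)), batch))  # expected: True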

Next, we prepare the input parameters.

In [147]:
## Whether to use the GPU
cuda = 0

## Image and model paths
img_dir = 'images/content/'
style_dir = 'images/21styles/'
model_dir = 'models/'

content_img = img_dir + 'venice-boat.jpg'
style_img = style_dir + 'starry_night.jpg'
use_model = model_dir + '21styles.model'

dst_img = 'test.jpg'

Let's look at the content image and the style image.

In [164]:
## Visualize the content image
plot_fig(content_img, (10, 20))
In [165]:
## Visualize the style image
plot_fig(style_img, (10, 20))
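
Finally, with all the pieces in place, a minimal sketch of the generation step might look as follows. This assumes that 21styles.model stores a state dict compatible with Net(ngf=128); the exact loading code depends on how the checkpoint was saved.

In [ ]:
## sketch of the generation step (assumes a state-dict checkpoint and ngf=128)
style_model = Net(ngf=128)
style_model.load_state_dict(torch.load(use_model))

## load the content and style images and swap them to BGR channel order
content_image = tensor_load_rgbimage(content_img, size=512, keep_asp=True).unsqueeze(0)
style_image = tensor_load_rgbimage(style_img, size=512).unsqueeze(0)
content_v = Variable(preprocess_batch(content_image), volatile=True)
style_v = Variable(preprocess_batch(style_image), volatile=True)

## set the style target, then transform the content image
style_model.setTarget(style_v)
output = style_model(content_v)

## swap back to RGB and write the result to disk
output = preprocess_batch(output.data)
tensor_save_rgbimage(output[0], dst_img, cuda=False)
plot_fig(dst_img, (10, 20))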