在很久很久以前,看到过一个精心设计的学习图片颜色分布的模型。那时候觉得上色是一个很有趣很好玩的事情,再后来很多人尝试给漫画自动添加颜色,但是到目前为止,很少看到用GAN来做自动上色的。今天我们就来实现这么一个教程。话不多说,先看看上色的效果怎么样:

这手笔,牛逼点的UI设计师怎么着也得花个30-40分钟来把这些背景颜色填充,即便使用ps的魔法棒工具也得修一修边角,一时半会搞不定,而我们的Ai插画师,预测一张图片0.012s,也就是12ms就完成了一张图片的上色工作。用今天比较流行的话来说就是,拜托,你太弱了…. 人类还没有动笔,AI已经完成了一切。

仔细看看这个色彩配置,天空的蔚蓝和草地的翠绿形成了鲜明的对比。最起码这个颜色使用不是过于前卫,如果把天空绘制成了黄色那就略微有点尴尬了,效果看着还行。

其实在这之前我曾经发过一篇AI上色的文章,那时候很多初学者就问我怎么实现不同物体上色,那会还只是一个十分简单的网络,并且仅仅只用了一张图片训练,如果你想要一个全能的AI上色教程,同时你又像学点GAN,那么本篇教程你值得一看。

当你看完本教程之后,你可以做这些事情:

  • 自己爬一堆妹子的anime,动漫美少女的自动上色器就来了!
  • 自己爬一堆xxx的图片,自动上色器就来了!

请注意,本模型完全是非监督,你不需要任何label,因为label就是它本身。。所以绝对是值得大家学习和上手!和其他的一些实现相比,我们做了些许的改进:

  • G的输入其实是支持任意尺寸输入的,意味着你既可以训练1080×720的大图,也可以训练256×256的动漫头像图片;
  • G本身是一个encoder和decoder的结构,另外参考了一些Resnet的连接方式设计,关于encoder和decoder结构的强大威力,可以参考我之前实现的Deepfakes(我的版本比原始版本以及其他版本都支持更高的分辨率,64×64 vs 128×128),传送门在此

原理

好了,现在是枯燥的理论环节。为了使我们的教程显得不那么枯燥,我们先上更枯燥的代码吧~ 开个玩笑,我们先看一下这的网络设计。 要实现一个GAN来生成彩色图,那么G是什么?生成器应该输入的是黑白图片,输出是彩色图片,如果你这样去做,那么大概率你的网络会发散。甚至可能重构出来的图片看起来什么都不是。 最好的方式是:将图片的颜色空间按照YUV分离出来,我们的网络仅仅只预测UV分量, 然后通过网络拿到输出后在通过原图的Y分量和UV分量重构为彩色图片。 那么YUV是什么呢?

Y’代表明亮度(luma;brightness)而U与V存储色度(色讯;chrominance;color)部分;亮度(luminance)记作Y,而Y’的prime符号记作伽玛校正。

摘自百度百科,可以看到,实际上我们就是将颜色单独拿出来了,用GAN来预测这个颜色,然后再进行重构,最后达到合成彩色图片的效果。

核心代码

大概思路是有了,但相比读者和我一样有两个问题:

  • G的输入是黑白图片,输出是啥?
  • D怎么区分G的输出?

在不考虑如何搭建G和D之前,这两个问题是要知道的,其实也很简单,首先G的输入是黑白图片,也就是一个单通道的图片,输出是UV分量。 而D的任务呢,就是从原图分离出UV分量,与G的生成来做区分。 D和G的训练是一个博弈的过程,我们来脑补一下G和D是如何训练的:

起初D和G都很弱… D和G自从出生起,就注定具有不共戴天之仇…. 他们的目的,就是干掉对方的KPI,让他们出岔子,从而赢得自己在上司面前的信任.. 有一天,G的参谋长(参谋长名字叫 loss)说: “怎么办,这几天我们伪造的办公文件,全被D识破了!” G表示: “不慌,我们重新调整一下人员配置,把不干活的神经元砍掉,力求下次伪造的文件一定不能被他们识破..” 过了几天,D的参谋长(名字也叫loss,外号loss_D) 说: “这几天他们伪造的文件越来越像了!我们这边差点好几个让他们溜过去!” D表示: “不慌,调整人员配置,步骤要稳打稳扎,我希望我们的辨别技术这个季度要提升至少8个百分点!” 参谋长说: “老大英明,小的这就调整部分经费配置,谁干活多谁拿钱..” …. 就这样,两个部分变的越来越强,G伪造的文件越来越接近真实的文件,D的甄别能力也越来越强。。 此时,主导一切的幕后黑手真窃窃自喜:“这个G成熟了,该学会自动上色了”

OK,编不下去了。我们直接看代码把!首先是数据的预处理,其实这个才是重中之重,最深度学习体会最深的可能就是,数据预处理的方式如果不对,结果就可能天差地别。笔者最近做3D点云检测,训练了一个模型预测的时候总不对,最后发现竟然是点云的强度也要归一化。。。 上色的首要条件就是对图片进行预处理:

class PairImageDataset(data.Dataset):
    def __init__(self, path):
        files = os.listdir(path)
        self.files = [os.path.join(path, x) for x in files]
​
    def __len__(self):
        return len(self.files)
​
    def __getitem__(self, index):
        img = Image.open(self.files[index])
        yuv = rgb2yuv(img)
        y = yuv[..., 0] - 0.5
        u_t = yuv[..., 1] / 0.43601035
        v_t = yuv[..., 2] / 0.61497538
        return torch.Tensor(np.expand_dims(y, axis=0)), torch.Tensor(
            np.stack([u_t, v_t], axis=0))
​

这段小巧的代码,就是我们定义的数据输入器,使用pytorch的dataset API编写。通过读取图片,RGB转到YUV,然后分离Y和UV通道,就可以构建我们的输入数据了!

在之前我写的30行代码自动上色的程序里面,我们用很少的代码实现了一个自动上色程序,这次使用GAN方法略显复杂,但实际代码并不多:

train_ds = PairImageDataset(args.training_dir)
    logging.info('loaded dataset from: {}, data length: {}'.format(args.training_dir, train_ds.__len__()))
    train_dataloader = data.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
​
    i = 0
    adversarial_loss = torch.nn.BCELoss()
    optimizer_G = torch.optim.Adam(G.parameters(),
                                   lr=args.g_lr,
                                   betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D.parameters(),
                                   lr=args.d_lr,
                                   betas=(0.5, 0.999))
    for epoch in range(start_epoch, args.epoch):
        for i, (y, uv) in enumerate(train_dataloader):
            try:
                # Adversarial ground truths
                valid = Variable(torch.Tensor(y.size(0), 1).fill_(1.0),
                                requires_grad=False).to(device)
                fake = Variable(torch.Tensor(y.size(0), 1).fill_(0.0),
                                requires_grad=False).to(device)
​
                yvar = Variable(y).to(device)
                uvvar = Variable(uv).to(device)
                real_imgs = torch.cat([yvar, uvvar], dim=1)
​
                optimizer_G.zero_grad()
                uvgen = G(yvar)
                # Generate a batch of images
                gen_imgs = torch.cat([yvar.detach(), uvgen], dim=1)
​
                # Loss measures generator's ability to fool the discriminator
                g_loss_gan = adversarial_loss(D(gen_imgs), valid)
                g_loss = g_loss_gan + args.pixel_loss_weights * torch.mean(
                    (uvvar - uvgen)**2)
                if i % args.g_every == 0:
                    g_loss.backward()
                    optimizer_G.step()
​
                optimizer_D.zero_grad()
                # Measure discriminator's ability to classify real from generated samples
                real_loss = adversarial_loss(D(real_imgs), valid)
                fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                optimizer_D.step()
                if i % 300 == 0:
                    logging.info("Epoch: %d, iter: %d, D loss: %f, G total loss: %f, GAN Loss: %f"
                        % (epoch, i, d_loss.item(), g_loss.item(), g_loss_gan.item()))
                    save_weights(
                        {'D': D.state_dict(), 'G': G.state_dict(), 'epoch': epoch},
                        epoch
                    )
                    
                    # snap some images from dir
                    test_imgs = glob.glob('images/*.jpeg')
                    for test_img in test_imgs:
                        snap_image_result_from_file(test_img, G)
​
            except KeyboardInterrupt:
                logging.info('interrupted. try saving model now..')
                save_weights(
                    {'D': D.state_dict(), 'G': G.state_dict(), 'epoch': epoch}, 0
                )
                logging.info('saved.')
                exit(0)

其中最核心是D和G的loss传递过程,首先我们定义了D的loss是BCEloss,也就是两个类别的交叉商,然后将返回的插值用来更新G,而D的loss呢则是生成的BCE和真实值的BCE二者的均值。

代码几乎没有什么比较难以理解的地方,唯一复杂的就是训练的步骤和方式。最后本教程的所有代码在下方可以看到。 我们总结一下完成这个任务的一些心得体会:

  • GAN其实可以很强了,我们没有去训练小图,但是肯定小图效果会很不错;
  • 上色有一些噪点,这可能是由于不够导致,也可能是我们的图片太杂,不够纯净。
  • 对于大面积的背景上色效果不错,对于比较细节的地方,上色能力不足。