zl程序教程

您现在的位置是:首页 >  云平台

当前栏目

Pytorch入门与实践——AI插画师:生成对抗网络数据集制作

网络AI数据PyTorch入门 实践 生成 制作
2023-09-14 09:05:44 时间

目录

摘要

1、用爬虫爬取二次元妹子的图片

2、获取图片中的头像


摘要

最近想搞一搞GAN,但是发现《Pytorch入门与实践——AI插画师:生成对抗网络》,但是发现数据集的链接失效了,所以自己制作一份。

代码来自https://www.zhihu.com/people/he-zhi-yuan-16,我做了一些修改。

1、用爬虫爬取二次元妹子的图片

数据从https://konachan.net/网站中下载的,是一个非常著名的动漫网站(不过我不知道)代码如下:

import requests
from bs4 import BeautifulSoup
import os
import traceback

def download(url, filename):
    if os.path.exists(filename):
        print('file exists!')
        return
    try:
        r = requests.get(url, stream=True, timeout=60)
        r.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    f.flush()
        return filename
    except KeyboardInterrupt:
        if os.path.exists(filename):
            os.remove(filename)
        raise KeyboardInterrupt
    except Exception:
        traceback.print_exc()
        if os.path.exists(filename):
            os.remove(filename)


if os.path.exists('imgs') is False:
    os.makedirs('imgs')

start =1
end = 8000
for i in range(start, end + 1):
    url = 'https://konachan.net/post?page=%d&tags=' % i
    html = requests.get(url).text
    soup = BeautifulSoup(html, 'html.parser')
    for img in soup.find_all('img', class_="preview"):
        target_url =img['src']
        filename = os.path.join('imgs', target_url.split('/')[-1])
        download(target_url, filename)
    print('%d / %d' % (i, end))

运行代码后就能在imgs文件夹看到二次元妹子的照片,各种各样的,目不暇接、眼花缭乱。。。。。

 

2、获取图片中的头像

截取头像和原文一样,直接使用github上一个基于opencv的工具,地址:https://github.com/nagadomi/lbpcascade_animeface,将lbpcascade_animeface.xml(准确率挺高的,不过有点猥琐,大家试一下就知道了。。。。。。)文件,放到根目录下。

然后运行下面的代码:

import cv2
import sys
import os.path
from glob import glob

def detect(filename, cascade_file="lbpcascade_animeface.xml"):
    if not os.path.isfile(cascade_file):
        raise RuntimeError("%s: not found" % cascade_file)

    cascade = cv2.CascadeClassifier(cascade_file)
    image = cv2.imread(filename)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)

    faces = cascade.detectMultiScale(gray,
                                     # detector options
                                     scaleFactor=1.1,
                                     minNeighbors=5,
                                     minSize=(48, 48))
    for i, (x, y, w, h) in enumerate(faces):
        face = image[y: y + h, x:x + w, :]
        face = cv2.resize(face, (96, 96))
        save_filename = '%s-%d.jpg' % (os.path.basename(filename).split('.')[0], i)
        cv2.imwrite("faces/" + save_filename, face)


if __name__ == '__main__':
    if os.path.exists('faces') is False:
        os.makedirs('faces')
    file_list = glob('imgs/*.jpg')
    for filename in file_list:
        detect(filename)

 

随便放几张截取后的头像:

 

连接是我制作的数据:https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/15939189

运行上面的代码就可以截取二次元妹子的头像了,到这里数据集制作完成了,然后我们一起GAN。如果你觉得有帮助请收藏、点赞,也可以打赏,多少随意。