VAE

Emor zhong

Emor zhong

Unknown

1 0
  • 0 Collaborators

使用PyTorch构建VAE模型,并使用oneAPI中的oneDNN优化。 ...learn more

Project status: Published/In Market

oneAPI

Intel Technologies
oneAPI, DevCloud

Code Samples [1]

Overview / Usage

用VAE构建的自编码器,泛化能力比一般的AE更强。

Methodology / Approach

使用PyTorch构建VAE模型,并使用oneAPI中的oneDNN优化。

Technologies Used

  1. 构建VAE模式。

使用一层784*400的线性网络作为编码器的主干,用两个400*20的线性网络作为VAE的均值和方差生成部分。用relu激活函数连接这两个网络,作为Encoder。再用一个20*400的线性网络和一个400*784的线性网络组成Decoder,为了保证得到的输出范围在0到1,用softmax激活函数进行归一化。

  1. 构建ELBO损失函数

ELBO分为两个部分,BCE部分是重构损失,因为手写数字是灰度图像,因此可以用交叉熵构建。第二部分是KL散度,这是为了让生成的均值和方差接近0和1,尽可能接近标准正态分布。

  1. 构建训练函数

训练函数需要多次调用,每次调用都是一个完整的epoch。需要将数据移动到同一设备,然后调用优化器的梯度清空函数,在将数据喂给VAE模型并将梯度反向传播,最后用优化器优化。

  1. 构建测试函数

测试阶段与训练阶段类似,只需要将数据喂给模型即可。这样就可以得到模型的输出,将其转换为图像便可以将重构的图像保存下来。

  1. 采样生成

生成的主要依赖Decoder的功能,在生成标准正态分布的随机数之后将其喂给模型,即可生成手写数字图像。

Repository

https://github.com/EmorZz1G/my_vae

Comments (0)