AI image - style transfer

Neural style transfer combines a content image and a style image into a generated image that keeps the content of the first and the style of the second. It lets a machine act as an artist: there is no single correct answer, only human aesthetic judgment. In this introduction, you will learn the theory and how to use a pre-trained model to stylize any content image.

Theory

The method uses VGG19, a classic CNN whose feature extractor consists of five convolution-pooling blocks (groups of convolution layers, each group followed by a pooling layer).

The main difficulty of style transfer is balancing how closely the generated image G resembles the content image C against how closely it resembles the style image S. The loss function is therefore a weighted sum of a content loss and a style loss, with weights α and β:

$$\mathcal{L} = \alpha\,\mathcal{L}_{content}(C,G) + \beta\,\mathcal{L}_{style}(S,G)$$

To compute the loss, we need the output $C_3$ of the content image at the third convolution-pooling block of the CNN, the outputs $S_1, S_2, S_3, S_4, S_5$ of the style image at all five blocks, and the outputs of the generated image at the same layers.
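A minimal sketch of such a feature extractor with the VGG19 model bundled in Keras. It assumes the first convolution of each block (`block1_conv1` to `block5_conv1`, the layers used by Gatys et al.) stands for the five style layers and that `block3_conv1`, the third of them, is the content layer; the function names are hypothetical. Use `weights='imagenet'` in practice (this downloads the pre-trained weights on first use); `weights=None` only builds the architecture.

```python
import tensorflow as tf

# Assumption: the first convolution of each of the five VGG19 blocks is a style layer.
STYLE_LAYERS = ['block1_conv1', 'block2_conv1', 'block3_conv1',
                'block4_conv1', 'block5_conv1']
# Assumption: the content layer is the third of these (block3_conv1).
CONTENT_INDEX = 2


def build_extractor(weights='imagenet'):
    """Frozen model mapping an image batch to the five style-layer outputs."""
    vgg = tf.keras.applications.VGG19(include_top=False, weights=weights)
    vgg.trainable = False
    outputs = [vgg.get_layer(name).output for name in STYLE_LAYERS]
    return tf.keras.Model(inputs=vgg.input, outputs=outputs)


def extract_features(extractor, image):
    """image: float32 tensor of shape [1, H, W, 3] with values in [0, 1].

    Returns (content_feature, list_of_five_style_features).
    """
    # VGG19 expects its own preprocessing applied to 0-255 RGB input.
    x = tf.keras.applications.vgg19.preprocess_input(image * 255.0)
    style_feats = extractor(x)
    content_feat = style_feats[CONTENT_INDEX]
    return content_feat, style_feats
```

With a 224×224 input the five outputs have shapes 224×224×64, 112×112×128, 56×56×256, 28×28×512 and 14×14×512, which are the $N_l$ and $M_l$ values used below for the style loss.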

The losses are

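$$\mathcal{L}_{content} = \frac{1}{2}\sum_{i,j}\left(C_{3,ij} - G_{3,ij}\right)^2$$

$$\mathcal{L}_{style} = \sum_{l=1}^{5} \frac{1}{4N_l^2 M_l^2}\sum_{i,j}\left(\mathrm{Gram}(S_l)_{ij} - \mathrm{Gram}(G_l)_{ij}\right)^2$$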

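where $G_l$ is the generated image's output at layer $l$ (so $G_3$ is compared with $C_3$), $N_l$ is the number of channels at layer $l$, i.e. $N_l \in \{64, 128, 256, 512, 512\}$, and $M_l$ is the number of pixels in each channel (height × width), i.e. $M_l \in \{224\times224, 112\times112, 56\times56, 28\times28, 14\times14\}$ for a 224×224 input. "Gram" denotes the Gram matrix, which has the following form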

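$$\mathrm{Gram}(X_l) = X_l X_l^{\top}, \qquad X_l \in \mathbb{R}^{N_l \times M_l}$$

where $X_l$ is a layer's output reshaped into a matrix ($X = S$ or $G$), and $X_{l,ij}$ is the activation of the $i$-th channel at the $j$-th pixel of the plane in layer $l$.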
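As a small worked example of this definition, the NumPy sketch below (with a hypothetical `gram_matrix` helper) flattens a channels-last feature map into a matrix with one row per channel and one column per pixel, then multiplies it by its transpose.

```python
import numpy as np


def gram_matrix(feature_map):
    """Gram matrix of one layer's activations.

    feature_map: array of shape (H, W, N), i.e. N channels over M = H * W pixels.
    Returns an (N, N) array whose (i, k) entry is the inner product of
    channel i and channel k taken over all pixels.
    """
    h, w, n = feature_map.shape
    f = feature_map.reshape(h * w, n).T   # (N, M): row i = channel i, column j = pixel j
    return f @ f.T                        # (N, N)


# Worked example: two channels over a 2x2 plane (M = 4 pixels).
channel0 = np.array([[1.0, 2.0], [3.0, 4.0]])
channel1 = np.array([[1.0, 0.0], [0.0, 1.0]])
fm = np.stack([channel0, channel1], axis=-1)   # shape (2, 2, 2)
print(gram_matrix(fm))
# [[30.  5.]
#  [ 5.  2.]]
```

The diagonal entries are each channel's squared norm (1 + 4 + 9 + 16 = 30 and 1 + 1 = 2), and the off-diagonal entry 1·1 + 4·1 = 5 measures how strongly the two channels fire at the same pixels.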

The content loss compares $C_3$ with the generated image's features element by element, because content features are finer-grained and more tied to spatial position than style features. In contrast, the style loss compares Gram matrices: each entry is the inner product of two channels over all pixels, which discards where things are and keeps which features occur together, capturing style elements such as textures and lines.
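The two losses and their weighted sum can be sketched in NumPy as follows. This is a sketch under assumptions: feature maps are channels-last arrays (one for the content layer, a list for the style layers), the style term uses the $1/(4N_l^2M_l^2)$ normalization of Gatys et al. with equal weight per layer, and the default α and β are placeholders, not recommended values. The shuffling check at the end illustrates the point above: rearranging pixels changes the content loss but leaves the Gram-based style loss unchanged.

```python
import numpy as np


def gram_matrix(feature_map):
    """(H, W, N) activations -> (N, N) Gram matrix F F^T with F of shape (N, H*W)."""
    h, w, n = feature_map.shape
    f = feature_map.reshape(h * w, n).T
    return f @ f.T


def content_loss(content_feat, generated_feat):
    """0.5 * sum of squared element-wise differences at the content layer."""
    return 0.5 * np.sum((content_feat - generated_feat) ** 2)


def style_loss(style_feats, generated_feats):
    """Sum over layers of squared Gram differences, each normalized by 4 N_l^2 M_l^2."""
    total = 0.0
    for s, g in zip(style_feats, generated_feats):
        h, w, n = s.shape
        m = h * w
        diff = gram_matrix(s) - gram_matrix(g)
        total += np.sum(diff ** 2) / (4.0 * n ** 2 * m ** 2)
    return total


def total_loss(content_feat, style_feats, generated_content, generated_styles,
               alpha=1.0, beta=1.0):
    """L = alpha * L_content + beta * L_style (alpha and beta are placeholders)."""
    return (alpha * content_loss(content_feat, generated_content)
            + beta * style_loss(style_feats, generated_styles))


rng = np.random.default_rng(0)
feat = rng.standard_normal((8, 8, 4))
# Shuffle pixel positions while keeping each pixel's channel vector intact.
flat = feat.reshape(-1, 4)
shuffled = flat[rng.permutation(flat.shape[0])].reshape(8, 8, 4)
print(content_loss(feat, shuffled) > 0)                 # layout changed
print(np.isclose(style_loss([feat], [shuffled]), 0.0))  # Gram matrices unchanged
```

In a full implementation the features would come from the VGG19 layers, and G would be updated by gradient descent on this loss.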

Implementations

This tutorial follows the TensorFlow tutorial "Neural style transfer". The system requirements are:

  • Python 3.6
  • TensorFlow 2.2.0
  • tensorflow_hub 0.11.0

First, import the packages and define some helper functions.

import os
# Load compressed models from TensorFlow Hub (as in the TensorFlow tutorial)
os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
import tensorflow as tf
print("tf.__version__=", tf.__version__)
import tensorflow_hub as hub
print("hub.__version__=", hub.__version__)
import IPython.display as display
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12, 12)
mpl.rcParams['axes.grid'] = False
import numpy as np
import PIL.Image
import time
import functools

def tensor_to_image(tensor):
    # Convert a float tensor with values in [0, 1] to an 8-bit PIL image
    tensor = tensor * 255
    tensor = np.array(tensor, dtype=np.uint8)
    if np.ndim(tensor) > 3:
        # Drop the batch dimension; only a batch of one image is supported
        assert tensor.shape[0] == 1
        tensor = tensor[0]
    return PIL.Image.fromarray(tensor)

def load_img(path_to_img):
    # Read an image as float32 in [0, 1], scale its longer side to max_dim
    # (keeping the aspect ratio) and add a batch dimension: [1, H, W, 3]
    max_dim = 512
    img = tf.io.read_file(path_to_img)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    long_dim = max(shape)
    scale = max_dim / long_dim
    new_shape = tf.cast(shape * scale, tf.int32)
    img = tf.image.resize(img, new_shape)
    img = img[tf.newaxis, :]
    return img

def imshow(image, title=None):
    # Show an image tensor, removing the batch dimension if present
    if len(image.shape) > 3:
        image = tf.squeeze(image, axis=0)
    plt.imshow(image)
    if title:
        plt.title(title)
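`load_img` scales every image so that its longer side becomes `max_dim` = 512 pixels while keeping the aspect ratio, so small images are enlarged as well as large ones shrunk. The sketch below mirrors that arithmetic in plain Python so the resulting size can be checked without TensorFlow; the helper `scaled_shape` is hypothetical, and it computes in double precision while `load_img` uses float32.

```python
def scaled_shape(height, width, max_dim=512):
    """Target (height, width) after load_img-style resizing.

    The longer side is scaled to max_dim and the aspect ratio is kept.
    int() truncates toward zero, like tf.cast from float to int32.
    """
    scale = max_dim / max(height, width)
    return int(height * scale), int(width * scale)


print(scaled_shape(768, 1024))   # (384, 512): a landscape photo is shrunk
print(scaled_shape(256, 128))    # (512, 256): a small portrait image is enlarged
```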

Next, prepare your own content image or download the example one, and download the example style image.

# download example content image
# content_path = tf.keras.utils.get_file('YellowLabradorLooking_new.jpg',
#     'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg')

# use prepared content image (replace with the path to your own image)
content_path = r'C:\Users\James\.keras\datasets\b.jpg'

# download the example style image
# (keep the URL on one line: a backslash continuation inside the string would embed spaces in it)
style_path = tf.keras.utils.get_file('kandinsky5.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg')

print(content_path, "\n", style_path)
# C:\Users\James\.keras\datasets\b.jpg
# C:\Users\James\.keras\datasets\kandinsky5.jpg

Then, plot the content image and the style image.

content_image = load_img(content_path)
style_image = load_img(style_path)

# plt.figure() returns a Figure, so name it accordingly
fig = plt.figure()
fig.set_facecolor('orange')
plt.subplot(1, 2, 1)
imshow(content_image, 'Content Image')
plt.subplot(1, 2, 2)
imshow(style_image, 'Style Image')

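Finally, load the pre-trained model from TensorFlow Hub and run it on the two images. Unlike the optimization described in the Theory section, this model (Magenta's arbitrary image stylization network) is a feed-forward network trained in advance, so it stylizes the content image in a single forward pass.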

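# Load Magenta's arbitrary image stylization model from TF Hub
hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')
# Inputs: float32 tensors [1, H, W, 3] in [0, 1]; output [0] is the stylized image batch
stylized_image = hub_model(tf.constant(content_image), tf.constant(style_image))[0]
tensor_to_image(stylized_image)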
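Two optional refinements, sketched under assumptions. First, the TF Hub "Fast Style Transfer for Arbitrary Styles" tutorial recommends keeping the style image at 256×256 for this model. Second, clipping the output to [0, 1] before converting to `uint8` avoids wrong pixel values from out-of-range float-to-`uint8` casts, which `np.array(..., dtype=np.uint8)` in `tensor_to_image` does not clip. The helper names `prepare_style` and `to_uint8` are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Recommended style-image size for arbitrary-image-stylization-v1-256
STYLE_SIZE = (256, 256)


def prepare_style(style_image):
    """Resize a [1, H, W, 3] float image in [0, 1] to 256x256 for the Hub model."""
    return tf.image.resize(style_image, STYLE_SIZE)


def to_uint8(stylized_image):
    """[1, H, W, 3] float output -> H x W x 3 uint8 array, clipped to [0, 1] first."""
    arr = np.clip(np.asarray(stylized_image)[0], 0.0, 1.0)
    return (arr * 255.0).round().astype(np.uint8)
```

Used with the code above, this would read `stylized = hub_model(tf.constant(content_image), prepare_style(style_image))[0]` followed by `PIL.Image.fromarray(to_uint8(stylized)).save('stylized.png')`, where the output filename is only an example.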

Reference