67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
from __future__ import print_function
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
def _canvas_to_rgb(canvas):
    '''
    Copy the pixels of an already-drawn canvas into an array.
    :param canvas: figure canvas on which ``draw()`` has been called
    :return: HxWx3 uint8 array (writable copy)
    '''
    if hasattr(canvas, 'buffer_rgba'):
        rgba = np.asarray(canvas.buffer_rgba())
        # Only trust the buffer when it exposes a (H, W, 4) view; otherwise
        # fall through to the byte-string path below.
        if rgba.ndim == 3:
            return rgba[..., :3].copy()

    # Fallback for canvases without a usable buffer_rgba().
    # np.frombuffer replaces the deprecated binary mode of np.fromstring;
    # it returns a read-only view, hence the copy.
    width, height = canvas.get_width_height()
    flat = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return flat.reshape((height, width, 3)).copy()


def plot_au(img, aus, title=None):
    '''
    Plot action units
    Draws ``img`` with the action-unit ids (red) and the matching entries
    of ``aus`` (blue) written over it in rows, then renders the figure
    into an array.
    :param img: HxWx3
    :param aus: N; when N == 11 the 11-unit layout is used, any other length
                uses the 17-unit layout. Values are paired with ids via
                ``zip``, so surplus values or ids are silently ignored.
    :param title: optional text drawn in red near the top of the axes
    :return: HxWx3 uint8 array holding the rendered figure
    '''
    # For each layout: the id labels in drawing order, plus a table mapping
    # the id that opens a new row to the x position where that row starts.
    if len(aus) == 11:
        au_ids = ['1', '2', '4', '5', '6', '9', '12', '17', '20', '25', '26']
        row_starts = {'9': 0.5, '12': 0.1}
    else:
        au_ids = ['1', '2', '4', '5', '6', '7', '9', '10', '12', '14', '15',
                  '17', '20', '23', '25', '26', '45']
        row_starts = {'9': 0.1, '20': 0.1}

    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.axis('off')
        fig.subplots_adjust(0, 0, 0.8, 1)  # get rid of margins

        # display img
        ax.imshow(img)

        # Keyword arguments shared by every text label (axes coordinates).
        text_style = dict(horizontalalignment='center',
                          verticalalignment='center',
                          transform=ax.transAxes, fontsize=20)

        x = 0.1
        y = 0.39
        col = 0
        for value, au_id in zip(aus, au_ids):
            if au_id in row_starts:
                # Start a new row: reset the column and move down.
                x = row_starts[au_id]
                y -= .15
                col = 0

            # id label on top, its value slightly below it
            ax.text(x + col * 0.2, y, au_id, color='r', **text_style)
            ax.text((x - 0.001) + col * 0.2, y - 0.07, value, color='b',
                    **text_style)
            col += 1

        if title is not None:
            ax.text(0.5, 0.95, title, color='r', **text_style)

        fig.canvas.draw()
        data = _canvas_to_rgb(fig.canvas)
    finally:
        # Always release the figure, even if drawing failed part-way.
        plt.close(fig)

    return data