mirror of
https://github.com/leigest519/ScreenCoder.git
synced 2026-02-13 02:02:48 +00:00
Initial commit
This commit is contained in:
125
UIED/cnn/CNN.py
Normal file
125
UIED/cnn/CNN.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import keras
|
||||
from keras.applications.resnet50 import ResNet50
|
||||
from keras.models import Model,load_model
|
||||
from keras.layers import Dense, Activation, Flatten, Dropout
|
||||
from sklearn.metrics import confusion_matrix
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from config.CONFIG import Config
|
||||
# Module-level configuration instance; CNN.load reads cfg.element_class
# as the class map of the 'Elements' classifier.
cfg = Config()
|
||||
|
||||
|
||||
class CNN:
    """Image classifier made of a frozen ResNet50 backbone and a dense head.

    ``load`` selects the model file and the class map for a classifier type,
    ``train`` fits and saves a new model, ``predict`` writes a label onto each
    component and ``evaluate`` prints a confusion matrix with micro-averaged
    precision and recall.
    """

    # Softmax width used when the training labels do not tell us the width.
    DEFAULT_OUTPUT_UNITS = 15

    def __init__(self, classifier_type, is_load=True):
        '''
        :param classifier_type: 'Text' or 'Noise' or 'Elements' or 'Image'
        :param is_load: when True, call load(classifier_type) immediately
        '''
        self.data = None
        self.model = None

        self.classifier_type = classifier_type

        self.image_shape = (32, 32, 3)
        self.class_number = None
        self.class_map = None
        self.model_path = None
        if is_load:
            self.load(classifier_type)

    def _output_units(self):
        """Return the width of the softmax layer.

        Uses the second dimension of ``self.data.Y_train`` when it is a 2-D
        array, so the head always matches the labels it is fitted on;
        otherwise falls back to DEFAULT_OUTPUT_UNITS (15).
        """
        label_shape = getattr(getattr(self.data, 'Y_train', None), 'shape', None)
        if label_shape is not None and len(label_shape) == 2:
            return int(label_shape[1])
        return self.DEFAULT_OUTPUT_UNITS

    def build_model(self, epoch_num, is_compile=True):
        """Build the network into ``self.model`` and optionally fit it.

        :param epoch_num: number of epochs passed to ``fit``
        :param is_compile: when True, compile the model and fit it on
            ``self.data`` (X_train/Y_train, validated on X_test/Y_test)
        """
        base_model = ResNet50(include_top=False, weights='imagenet', input_shape=self.image_shape)
        # Only the new dense head is trained; the backbone stays frozen.
        for layer in base_model.layers:
            layer.trainable = False

        head = Flatten()(base_model.output)
        head = Dense(128, activation='relu')(head)
        head = Dropout(0.5)(head)
        head = Dense(self._output_units(), activation='softmax')(head)

        self.model = Model(inputs=base_model.input, outputs=head)
        if is_compile:
            self.model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])
            self.model.fit(self.data.X_train, self.data.Y_train, batch_size=64, epochs=epoch_num, verbose=1,
                           validation_data=(self.data.X_test, self.data.Y_test))

    def train(self, data, epoch_num=30):
        """Fit a new model on ``data`` and save it to ``self.model_path``.

        :param data: object exposing X_train, Y_train, X_test and Y_test
        :param epoch_num: number of training epochs
        :raises ValueError: if ``self.model_path`` has not been set
        """
        # Check before training so a missing path does not waste a whole fit.
        if self.model_path is None:
            raise ValueError("model_path is not set; call load() or assign model_path before train()")
        self.data = data
        self.build_model(epoch_num)
        self.model.save(self.model_path)
        print("Trained model is saved to", self.model_path)

    def load(self, classifier_type):
        """Select model path and class map for ``classifier_type`` and load the model.

        If the model file cannot be loaded, the error is printed and
        ``self.model`` is set to None; ``predict`` and ``evaluate`` check for that.

        :param classifier_type: 'Text', 'Noise', 'Elements' or 'Image'
        :raises ValueError: if ``classifier_type`` is not one of the above
        """
        if classifier_type == 'Text':
            self.model_path = 'E:/Mulong/Model/rico_compos/cnn-textview-2.h5'
            self.class_map = ['Text', 'Non-Text']
        elif classifier_type == 'Noise':
            self.model_path = 'E:/Mulong/Model/rico_compos/cnn-noise-1.h5'
            self.class_map = ['Noise', 'Non-Noise']
        elif classifier_type == 'Elements':
            # Earlier checkpoints lived under E:/Mulong/Model/rico_compos/
            # (resnet-ele14-19.h5, resnet-ele14-28.h5, resnet-ele14-45.h5).
            self.model_path = 'UIED/cnn/model/cnn-rico-1.h5'  # Use local model
            self.class_map = cfg.element_class
            self.image_shape = (64, 64, 3)
        elif classifier_type == 'Image':
            # Redirect 'Image' classification to use the general 'Elements' model
            # as the specific model is not available in the project.
            # IMPORTANT: This requires the actual model file to be present for real classification.
            # NOTE(review): this branch keeps the current image_shape and a
            # 2-entry class map while pointing at the 'Elements' model file;
            # confirm the input size and output width actually match.
            print("Warning: 'Image' specific model not found. Redirecting to general 'Elements' classifier.")
            self.model_path = 'UIED/cnn/model/cnn-rico-1.h5'  # Use local model
            self.class_map = ['Image', 'Non-Image']  # Keep the class map for binary classification logic
        else:
            raise ValueError(
                "Unknown classifier_type %r; expected 'Text', 'Noise', 'Elements' or 'Image'" % (classifier_type,))

        self.class_number = len(self.class_map)
        try:
            self.model = load_model(self.model_path)
            print('Model Loaded From', self.model_path)
        except Exception as e:
            print(f"Error loading model: {e}")
            print("A dummy model file was created, but it's not a valid Keras model.")
            print("Please replace it with the actual model file for classification to work.")
            self.model = None

    def preprocess_img(self, image):
        """Resize ``image`` to the model input size and scale it to [0, 1].

        :param image: image array accepted by ``cv2.resize``
        :return: float32 array holding a batch of one image
        """
        height, width = self.image_shape[:2]
        # cv2.resize takes dsize as (width, height); image_shape is passed to
        # ResNet50 as input_shape, i.e. (height, width, channels).
        image = cv2.resize(image, (width, height))
        x = (image / 255).astype('float32')
        return np.array([x])

    def predict(self, imgs, compos, load=False, show=False):
        """
        Classify every image and store the label on the matching component.

        :param imgs: sequence of image arrays, one per component
        :param compos: sequence of objects; ``compos[i].category`` is set to
            the predicted class name of ``imgs[i]``
        :param load: when True, reload the model before predicting
        :param show: when True, print each label and display the image
        """
        if load:
            self.load(self.classifier_type)
        if self.model is None:
            print("*** No model loaded ***")
            return
        for i, img in enumerate(imgs):
            X = self.preprocess_img(img)
            Y = self.class_map[np.argmax(self.model.predict(X))]
            compos[i].category = Y
            if show:
                print(Y)
                cv2.imshow('element', img)
                cv2.waitKey()

    @staticmethod
    def _micro_precision_recall(matrix):
        """Return micro-averaged (precision, recall) for a confusion matrix.

        Rows are taken as true classes and columns as predicted classes, the
        layout produced by sklearn's ``confusion_matrix``. Returns 0.0 for a
        ratio whose denominator is zero.
        """
        matrix = np.asarray(matrix)
        if matrix.size == 0:
            return 0.0, 0.0
        diagonal = np.diag(matrix)
        tp = diagonal.sum()
        # Column sums count predictions of a class, row sums count its true samples.
        fp = (matrix.sum(axis=0) - diagonal).sum()
        fn = (matrix.sum(axis=1) - diagonal).sum()
        precision = float(tp / (tp + fp)) if (tp + fp) else 0.0
        recall = float(tp / (tp + fn)) if (tp + fn) else 0.0
        return precision, recall

    def evaluate(self, data, load=True):
        """Print the confusion matrix, precision and recall on ``data``'s test split.

        :param data: object exposing X_test and one-hot Y_test
        :param load: when True, reload the model before evaluating
        """
        if load:
            self.load(self.classifier_type)
        # load() leaves model as None on failure; mirror predict()'s guard.
        if self.model is None:
            print("*** No model loaded ***")
            return
        X_test = data.X_test
        Y_test = [np.argmax(y) for y in data.Y_test]
        Y_pre = [np.argmax(y_pre) for y_pre in self.model.predict(X_test, verbose=1)]

        matrix = confusion_matrix(Y_test, Y_pre)
        print(matrix)

        precision, recall = self._micro_precision_recall(matrix)
        print("Precision:%.3f, Recall:%.3f" % (precision, recall))
|
||||
21
UIED/cnn/Config.py
Normal file
21
UIED/cnn/Config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
class Config:
    """Dataset, model-path and class settings for the Text/Non-Text CNN setup.

    Attributes set here: DATA_PATH, MODEL_PATH, class_map, image_shape and
    class_number (always ``len(class_map)``).
    """

    def __init__(self):
        # Alternative configurations kept for reference.
        #
        # cnn 4 classes
        # self.MODEL_PATH = 'E:/Mulong/Model/ui_compos/cnn6_icon.h5' # cnn 4 classes
        # self.class_map = ['Image', 'Icon', 'Button', 'Input']

        # resnet 14 classes
        # self.DATA_PATH = "E:/Mulong/Datasets/rico/elements-14-2"
        # self.MODEL_PATH = 'E:/Mulong/Model/rico_compos/resnet-ele14.h5'
        # self.class_map = ['Button', 'CheckBox', 'Chronometer', 'EditText', 'ImageButton', 'ImageView',
        #                   'ProgressBar', 'RadioButton', 'RatingBar', 'SeekBar', 'Spinner', 'Switch',
        #                   'ToggleButton', 'VideoView', 'TextView'] # ele-14

        # Raw string: the backslashes are literal path separators, and a plain
        # literal would contain invalid escape sequences (\M, \D, \d, \C).
        self.DATA_PATH = r"E:\Mulong\Datasets\dataset_webpage\Components3"

        self.MODEL_PATH = 'E:/Mulong/Model/rico_compos/cnn2-textview.h5'
        self.class_map = ['Text', 'Non-Text']

        self.image_shape = (32, 32, 3)
        self.class_number = len(self.class_map)
|
||||
69
UIED/cnn/Data.py
Normal file
69
UIED/cnn/Data.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from os.path import join as pjoin
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
from Config import Config
|
||||
|
||||
# Module-level configuration instance; Data.__init__ copies image_shape,
# class_number, class_map and DATA_PATH from it.
cfg = Config()
|
||||
|
||||
|
||||
class Data:
    """Loads labelled PNG images from class-named folders and splits them.

    ``load_data`` fills ``images`` / ``labels`` from ``DATA_PATH/<class>/*.png``;
    ``generate_training_data`` shuffles them and produces X_train, Y_train,
    X_test and Y_test (pixel values scaled to [0, 1], labels one-hot).
    """

    def __init__(self):
        self.data_num = 0
        self.images = []
        self.labels = []
        self.X_train, self.Y_train = None, None
        self.X_test, self.Y_test = None, None

        self.image_shape = cfg.image_shape
        self.class_number = cfg.class_number
        self.class_map = cfg.class_map
        self.DATA_PATH = cfg.DATA_PATH

    def load_data(self, resize=True, shape=None, max_number=1000000):
        """Read the PNG files of every class folder under ``DATA_PATH``.

        :param resize: when True, resize every image to ``shape``
        :param shape: (height, width[, channels]); replaces ``self.image_shape``
            when given, otherwise ``self.image_shape`` is used
        :param max_number: maximum number of images read per class folder
        :raises ValueError: if a folder name is not listed in ``class_map``
        """
        # Local import: the module itself only imports os.path.join.
        import os

        # if customize shape
        if shape is not None:
            self.image_shape = shape
        else:
            shape = self.image_shape
        height, width = shape[:2]

        # load data
        for class_dir in glob.glob(pjoin(self.DATA_PATH, '*')):
            # basename works with the platform's separator, unlike split('\\').
            class_name = os.path.basename(class_dir)
            image_paths = glob.glob(pjoin(class_dir, '*.png'))
            print("*** Loading components of %s: %d ***" % (class_name, len(image_paths)))
            label = self.class_map.index(class_name)  # map to index of classes
            for image_path in tqdm(image_paths[:max_number]):
                image = cv2.imread(image_path)
                if image is None:
                    # cv2.imread returns None when the file cannot be decoded.
                    print('Skipping unreadable image:', image_path)
                    continue
                if resize:
                    # cv2.resize takes dsize as (width, height).
                    image = cv2.resize(image, (width, height))
                self.images.append(image)
                self.labels.append(label)

        assert len(self.images) == len(self.labels)
        self.data_num = len(self.images)
        print('%d Data Loaded' % self.data_num)

    def generate_training_data(self, train_data_ratio=0.8):
        """Shuffle the loaded samples and split them into train and test sets.

        The shuffle is seeded with 0, so the split is reproducible. One index
        permutation is applied to both images and labels so that they cannot
        drift apart.

        :param train_data_ratio: fraction of the samples used for training
        :raises ValueError: if the numbers of images and labels differ
        """
        images = np.asarray(self.images)
        labels = np.asarray(self.labels, dtype=int)
        if len(images) != len(labels):
            raise ValueError(
                "images and labels differ in length: %d vs %d" % (len(images), len(labels)))
        self.data_num = len(images)

        # reshuffle
        np.random.seed(0)
        order = np.random.permutation(self.data_num)
        self.images = images[order]
        self.labels = labels[order]

        # One-hot rows, always shaped (num_samples, class_number); no squeeze,
        # which would collapse the array when there is a single sample.
        Y = np.eye(self.class_number)[self.labels]

        # separate dataset
        cut = int(train_data_ratio * self.data_num)
        self.X_train = (self.images[:cut] / 255).astype('float32')
        self.X_test = (self.images[cut:] / 255).astype('float32')
        self.Y_train = Y[:cut]
        self.Y_test = Y[cut:]

        print('X_train:%d, Y_train:%d' % (len(self.X_train), len(self.Y_train)))
        print('X_test:%d, Y_test:%d' % (len(self.X_test), len(self.Y_test)))
|
||||
BIN
UIED/cnn/__pycache__/CNN.cpython-312.pyc
Normal file
BIN
UIED/cnn/__pycache__/CNN.cpython-312.pyc
Normal file
Binary file not shown.
BIN
UIED/cnn/__pycache__/CNN.cpython-35.pyc
Normal file
BIN
UIED/cnn/__pycache__/CNN.cpython-35.pyc
Normal file
Binary file not shown.
0
UIED/cnn/model/cnn-rico-1.h5
Normal file
0
UIED/cnn/model/cnn-rico-1.h5
Normal file
Reference in New Issue
Block a user