mirror of
https://github.com/leigest519/ScreenCoder.git
synced 2026-02-13 10:12:46 +00:00
125 lines
5.0 KiB
Python
125 lines
5.0 KiB
Python
import keras
|
|
from keras.applications.resnet50 import ResNet50
|
|
from keras.models import Model,load_model
|
|
from keras.layers import Dense, Activation, Flatten, Dropout
|
|
from sklearn.metrics import confusion_matrix
|
|
import numpy as np
|
|
import cv2
|
|
|
|
from config.CONFIG import Config
|
|
# Module-level project configuration; CNN.load reads cfg.element_class as the
# label list for the 'Elements' classifier.
cfg = Config()
|
|
|
|
|
|
class CNN:
    '''
    ResNet50-based image classifier for UI components.

    The classifier type ('Text', 'Noise', 'Elements' or 'Image') selects, in load(),
    the model file, the label list (class_map) and the network input shape.
    '''

    def __init__(self, classifier_type, is_load=True):
        '''
        :param classifier_type: 'Text', 'Noise', 'Elements' or 'Image'
        :param is_load: when True, call self.load(classifier_type) immediately
        '''
        self.data = None                # data object assigned by train()
        self.model = None               # keras Model, or None if loading failed
        self.classifier_type = classifier_type
        self.image_shape = (32, 32, 3)  # (height, width, channels) fed to the network
        self.class_number = None        # len(class_map), set by load()
        self.class_map = None           # softmax index -> label name, set by load()
        self.model_path = None
        if is_load:
            self.load(classifier_type)

    def build_model(self, epoch_num, is_compile=True):
        '''
        Build a frozen-ResNet50 classifier and optionally compile and fit it.

        :param epoch_num: number of epochs passed to model.fit
        :param is_compile: when True, compile the model and fit it on self.data
            (uses its X_train, Y_train, X_test and Y_test attributes)
        '''
        base_model = ResNet50(include_top=False, weights='imagenet', input_shape=self.image_shape)
        # Only the new head is trained; the ImageNet backbone stays frozen.
        for layer in base_model.layers:
            layer.trainable = False
        # The softmax width must match class_map, otherwise predict() indexes labels
        # that do not exist. 15 was the former hard-coded width and remains the
        # fallback when no class map has been loaded (e.g. is_load=False).
        num_classes = self.class_number if self.class_number else 15
        features = Flatten()(base_model.output)
        features = Dense(128, activation='relu')(features)
        features = Dropout(0.5)(features)
        outputs = Dense(num_classes, activation='softmax')(features)

        self.model = Model(inputs=base_model.input, outputs=outputs)
        if is_compile:
            self.model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])
            self.model.fit(self.data.X_train, self.data.Y_train, batch_size=64, epochs=epoch_num, verbose=1,
                           validation_data=(self.data.X_test, self.data.Y_test))

    def train(self, data, epoch_num=30):
        '''
        Train a new model on data and save it to self.model_path.

        :param data: object exposing X_train, Y_train, X_test and Y_test
        :param epoch_num: number of training epochs
        :raises ValueError: if model_path is not set (e.g. is_load=False); checked
            before training so a long run is not wasted on an unsavable model
        '''
        if self.model_path is None:
            raise ValueError("model_path is not set; call load() or assign model_path before training")
        self.data = data
        self.build_model(epoch_num)
        self.model.save(self.model_path)
        print("Trained model is saved to", self.model_path)

    def load(self, classifier_type):
        '''
        Configure paths/labels/input shape for classifier_type and load its model.

        On a loading failure, self.model is set to None and a message is printed
        instead of raising.

        :param classifier_type: 'Text', 'Noise', 'Elements' or 'Image'
        :raises ValueError: for any other classifier_type
        '''
        if classifier_type == 'Text':
            self.model_path = 'E:/Mulong/Model/rico_compos/cnn-textview-2.h5'
            self.class_map = ['Text', 'Non-Text']
            self.image_shape = (32, 32, 3)
        elif classifier_type == 'Noise':
            self.model_path = 'E:/Mulong/Model/rico_compos/cnn-noise-1.h5'
            self.class_map = ['Noise', 'Non-Noise']
            self.image_shape = (32, 32, 3)
        elif classifier_type == 'Elements':
            self.model_path = 'UIED/cnn/model/cnn-rico-1.h5'  # Use local model
            self.class_map = cfg.element_class
            self.image_shape = (64, 64, 3)
        elif classifier_type == 'Image':
            # Redirect 'Image' classification to use the general 'Elements' model
            # as the specific model is not available in the project.
            # IMPORTANT: This requires the actual model file to be present for real classification.
            print("Warning: 'Image' specific model not found. Redirecting to general 'Elements' classifier.")
            self.model_path = 'UIED/cnn/model/cnn-rico-1.h5'  # Use local model
            self.class_map = ['Image', 'Non-Image']  # Keep the class map for binary classification logic
            # Same model file as 'Elements', so it needs the same 64x64 input.
            self.image_shape = (64, 64, 3)
            # NOTE(review): the Elements model presumably has one output per
            # cfg.element_class entry; any argmax >= 2 will raise IndexError in
            # predict() against this 2-entry map — TODO confirm / remap.
        else:
            raise ValueError(
                f"Unknown classifier_type {classifier_type!r}; "
                "expected 'Text', 'Noise', 'Elements' or 'Image'")

        self.class_number = len(self.class_map)
        try:
            self.model = load_model(self.model_path)
            print('Model Loaded From', self.model_path)
        except Exception as e:
            print(f"Error loading model: {e}")
            print("A dummy model file was created, but it's not a valid Keras model.")
            print("Please replace it with the actual model file for classification to work.")
            self.model = None

    def preprocess_img(self, image):
        '''
        Resize an image to self.image_shape, scale it to [0, 1] and add a batch axis.

        :param image: image array accepted by cv2.resize
        :return: float32 array of shape (1, height, width, channels)
        '''
        height, width = self.image_shape[:2]
        # cv2.resize takes dsize as (width, height), the reverse of image_shape.
        resized = cv2.resize(image, (width, height))
        x = (resized / 255).astype('float32')
        return np.array([x])

    def predict(self, imgs, compos, load=False, show=False):
        '''
        Classify each image and store its label on the matching component.

        :param imgs: sequence of images, one per component
        :param compos: sequence of components; compos[i].category is set from imgs[i]
        :param load: reload the model via self.load(self.classifier_type) first
        :param show: print each label and display the image with cv2.imshow
        :return: None (prints a message and returns early if no model is loaded)
        '''
        if load:
            self.load(self.classifier_type)
        if self.model is None:
            print("*** No model loaded ***")
            return
        for i, img in enumerate(imgs):
            X = self.preprocess_img(img)
            Y = self.class_map[np.argmax(self.model.predict(X))]
            compos[i].category = Y
            if show:
                print(Y)
                cv2.imshow('element', img)
                cv2.waitKey()

    def evaluate(self, data, load=True):
        '''
        Print the confusion matrix and micro-averaged precision/recall on test data.

        :param data: object exposing X_test and one-hot Y_test
        :param load: reload the model via self.load(self.classifier_type) first
        :return: None (prints a message and returns early if no model is loaded)
        '''
        if load:
            self.load(self.classifier_type)
        if self.model is None:
            print("*** No model loaded ***")
            return
        X_test = data.X_test
        Y_test = [np.argmax(y) for y in data.Y_test]
        Y_pre = [np.argmax(y_pre) for y_pre in self.model.predict(X_test, verbose=1)]

        matrix = confusion_matrix(Y_test, Y_pre)
        print(matrix)

        # sklearn convention: rows are true labels, columns are predicted labels.
        # Summed over all classes (micro-average) both FP and FN equal the total
        # off-diagonal count, so precision and recall both equal accuracy here.
        diag = np.diag(matrix)
        TP = diag.sum()
        FP = (matrix.sum(axis=0) - diag).sum()  # predicted as class i, truly another class
        FN = (matrix.sum(axis=1) - diag).sum()  # truly class i, predicted as another class
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)
        print("Precision:%.3f, Recall:%.3f" % (precision, recall))