Files
facefusion/facefusion/face_classifier.py
T
Henry Ruhs 8bf9170577 3.5.0 (#977)
* Mark as NEXT

* Reduce caching to avoid RAM explosion

* Reduce caching to avoid RAM explosion

* Update dependencies

* add face-detector-pad-factor

* update facefusion.ini

* fix test

* change pad to margin

* fix order

* add prepare margin

* use 50% max margin

* Minor fixes part2

* Minor fixes part3

* Minor fixes part4

* Minor fixes part1

* Downgrade onnxruntime as of BiRefNet broken on CPU

add test

update

update facefusion.ini

add birefnet

* rename models

add more models

* Fix versions

* Add .claude to gitignore

* add normalize color

add 4 channel

add colors

* workflows

* cleanup

* cleanup

* cleanup

* cleanup

* add more models (#961)

* Fix naming

* changes

* Fix style and mock Gradio

* Fix style and mock Gradio

* Fix style and mock Gradio

* apply clamp

* remove clamp

* Add normalizer test

* Introduce sanitizer for the rescue (#963)

* Introduce sanitizer for the rescue

* Introduce sanitizer for the rescue

* Introduce sanitizer for the rescue

* prepare ffmpeg for alpha support

* Some cleanup

* Some cleanup

* Fix CI

* List as TypeAlias is not allowed (#967)

* List as TypeAlias is not allowed

* List as TypeAlias is not allowed

* List as TypeAlias is not allowed

* List as TypeAlias is not allowed

* Add mpeg and mxf support (#968)

* Add mpeg support

* Add mxf support

* Adjust fix_xxx_encoder for the new formats

* Extend output pattern for batch-run (#969)

* Extend output pattern for batch-run

* Add {target_extension} to allowed mixed files

* Catch invalid output pattern keys

* alpha support

* cleanup

* cleanup

* add ProcessorOutputs type

* fix preview and streamer, support alpha for background_remover

* Refactor/open close processors (#972)

* Introduce open/close processors

* Add locales for translator

* Introduce __autoload__ for translator

* More cleanup

* Fix import issues

* Resolve the scope situation for locals

* Fix installer by not using translator

* Fixes after merge

* Fixes after merge

* Fix translator keys in ui

* Use LOCALS in installer

* Update and partial fix DirectML

* Use latest onnxruntime

* Fix performance

* Fix lint issues

* fix mask

* fix lint

* fix lint

* Remove default from translator.get()

* remove 'framerate='

* fix test

* Rename and reorder models

* Align naming

* add alpha preview

* fix frame-by-frame

* Add alpha effect via css

* preview support alpha channel

* fix preview modes

* Use official assets repositories

* Add support for u2net_cloth

* fix naming

* Add more models

* Add vendor, license and year direct to the models

* Add vendor, license and year direct to the models

* Update dependencies, Minor CSS adjustment

* Ready for 3.5.0

* Fix naming

* Update about messages

* Fix return

* Use groups to show/hide

* Update preview

* Conditional merge mask

* Conditional merge mask

* Fix import order

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
Co-authored-by: Harisreedhar <46858047+harisreedhar@users.noreply.github.com>
2025-11-03 14:05:15 +01:00

141 lines
3.9 KiB
Python

from functools import lru_cache
from typing import List, Tuple
import numpy
from facefusion import inference_manager
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.face_helper import warp_face_by_face_landmark_5
from facefusion.filesystem import resolve_relative_path
from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.types import Age, DownloadScope, FaceLandmark5, Gender, InferencePool, ModelOptions, ModelSet, Race, VisionFrame
@lru_cache()
def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
    """
    Build the model set that describes the 'fairface' face classifier.

    The result is memoized by lru_cache, so repeated calls with the same
    download_scope return the same dictionary object. The body does not
    read download_scope; the argument only forms the cache key.
    """
    hash_url = resolve_download_url('models-3.0.0', 'fairface.hash')
    hash_path = resolve_relative_path('../.assets/models/fairface.hash')
    source_url = resolve_download_url('models-3.0.0', 'fairface.onnx')
    source_path = resolve_relative_path('../.assets/models/fairface.onnx')

    fairface_options =\
    {
        '__metadata__':
        {
            'vendor': 'dchen236',
            'license': 'Non-Commercial',
            'year': 2021
        },
        'hashes':
        {
            'face_classifier':
            {
                'url': hash_url,
                'path': hash_path
            }
        },
        'sources':
        {
            'face_classifier':
            {
                'url': source_url,
                'path': source_path
            }
        },
        # warp template and crop size handed to warp_face_by_face_landmark_5 in classify_face
        'template': 'arcface_112_v2',
        'size': (224, 224),
        # per-channel values subtracted from / divided into the crop in classify_face
        'mean': [ 0.485, 0.456, 0.406 ],
        'standard_deviation': [ 0.229, 0.224, 0.225 ]
    }
    return { 'fairface': fairface_options }
def get_inference_pool() -> InferencePool:
    """
    Return the inference pool for this module's 'fairface' model.

    The pool is requested from inference_manager under this module's
    __name__ using the 'sources' entry of the model options.
    """
    return inference_manager.get_inference_pool(__name__, [ 'fairface' ], get_model_options().get('sources'))
def clear_inference_pool() -> None:
    """
    Ask inference_manager to clear the pool registered for this module's 'fairface' model.
    """
    inference_manager.clear_inference_pool(__name__, [ 'fairface' ])
def get_model_options() -> ModelOptions:
    """
    Return the 'fairface' entry of the static model set created for the 'full' download scope.
    """
    static_model_set = create_static_model_set('full')
    return static_model_set.get('fairface')
def pre_check() -> bool:
    """
    Run the conditional download of the model hashes, then of the model sources.

    The sources step is only attempted when the hashes step returns a
    truthy result; otherwise that falsy result is returned directly.
    """
    model_options = get_model_options()
    hashes_ready = conditional_download_hashes(model_options.get('hashes'))
    return hashes_ready and conditional_download_sources(model_options.get('sources'))
def classify_face(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Gender, Age, Race]:
    """
    Classify gender, age and race for the face located by face_landmark_5.

    The face is warped out of temp_vision_frame with the model's template
    and size, normalized, reshaped to a single-item batch and passed to
    forward(). The first entry of each returned id sequence is mapped to
    its label by the categorize_* helpers.

    Returns a (gender, age, race) tuple.
    """
    model_options = get_model_options()
    crop_vision_frame, _ = warp_face_by_face_landmark_5(temp_vision_frame, face_landmark_5, model_options.get('template'), model_options.get('size'))

    # cast to float32, reverse the last axis (presumably BGR to RGB - confirm against the frame source) and divide by 255
    crop_vision_frame = crop_vision_frame.astype(numpy.float32)[:, :, ::-1] / 255.0
    # in-place operators keep the array as float32 while applying the per-channel constants
    crop_vision_frame -= model_options.get('mean')
    crop_vision_frame /= model_options.get('standard_deviation')
    # move the last axis to the front and prepend a batch axis of length one
    crop_vision_frame = numpy.expand_dims(crop_vision_frame.transpose(2, 0, 1), axis = 0)

    gender_ids, age_ids, race_ids = forward(crop_vision_frame)
    return categorize_gender(gender_ids[0]), categorize_age(age_ids[0]), categorize_race(race_ids[0])
def forward(crop_vision_frame : VisionFrame) -> Tuple[List[int], List[int], List[int]]:
    """
    Run the 'face_classifier' entry of the inference pool on crop_vision_frame.

    The frame is fed under the input name 'input'. The three outputs are
    unpacked as race, gender, age and returned reordered as
    (gender_id, age_id, race_id).
    """
    face_classifier = get_inference_pool().get('face_classifier')
    model_inputs =\
    {
        'input': crop_vision_frame
    }

    # the run happens inside the context manager provided by conditional_thread_semaphore()
    with conditional_thread_semaphore():
        race_id, gender_id, age_id = face_classifier.run(None, model_inputs)

    return gender_id, age_id, race_id
def categorize_gender(gender_id : int) -> Gender:
    """
    Map a gender class index to its label: 1 is 'female', any other value is 'male'.
    """
    return 'female' if gender_id == 1 else 'male'
def categorize_age(age_id : int) -> Age:
    """
    Map an age class index to its age range.

    Indices 0 to 7 select the matching entry of the table below; every
    other value yields range(70, 100).
    """
    # NOTE(review): range() excludes its stop value, so e.g. range(0, 2) holds only 0 and 1
    # and the stop values themselves (2, 9, 19, ...) belong to no bucket - confirm whether
    # callers treat stop as an inclusive label before changing these bounds
    age_ranges =\
    (
        range(0, 2),
        range(3, 9),
        range(10, 19),
        range(20, 29),
        range(30, 39),
        range(40, 49),
        range(50, 59),
        range(60, 69)
    )

    for known_age_id, age_range in enumerate(age_ranges):
        if age_id == known_age_id:
            return age_range
    return range(70, 100)
def categorize_race(race_id : int) -> Race:
    """
    Map a race class index to its label.

    Indices 1 to 6 select a label from the table below (3 and 4 both map
    to 'asian'); every other value yields 'white'.
    """
    race_lookup =\
    (
        (1, 'black'),
        (2, 'latino'),
        (3, 'asian'),
        (4, 'asian'),
        (5, 'indian'),
        (6, 'arabic')
    )

    for known_race_id, race in race_lookup:
        if race_id == known_race_id:
            return race
    return 'white'