인공지능/Deep Learning

[Deep Learning] 이미지 분류 딥러닝 Transfer Learning, MobileNetV2 활용

건휘맨 2024. 4. 19. 17:48

Transfer Learning 란 한 작업에서 학습한 지식을 다른 관련 작업으로 전송하는 것
= 내 데이터로 다시 재 학습(보강해서)해서 사용

 

MobileNetV2 는 경량화된 딥러닝 아키텍처로, 모바일 및 임베디드 기기에서도 효율적으로 사용될 수 있으며

정확도 또한 많이 떨어지지 않게하여 속도와 정확도 사이의 트레이드 오프 문제를 어느정도 해결한 네트워크이다.

 

 

트랜스퍼 러닝은 학습이 잘 된 모델을 가져와서 나의 문제에 맞게 활용하는 것이므로
학습이 잘 된 모델의 헤드 모델을 뺀 베이스 모델만 가져온다.

 

 

import tensorflow as tf

# 만들려는 모델의 인풋 이미지는 (128, 128, 3)으로 한다
>>> IMG_SHAPE = (128, 128, 3)

>>> base_model =  tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                                    include_top=False,      # 헤드 모델
                                                    weights='imagenet')

 

잘 만들어진 베이스 모델 부분은 이미 1400만장으로 학습이 잘 되어있어서
이미지의 특징을 뽑아내는 역할을 하는 부분이므로
우리 데이터로는 학습되지 않도록 한다.

>>> base_model.trainable = False
# Trainable params: 0 이 된다
# .summary() 로 Trainable params 확인

 

헤드 모델을 만든다.

from keras.layers import Flatten, Dense

>>> head_model = base_model.output

# Flatten()
>>> head_model = Flatten()(head_model)

# 히든 레이어
>>> head_model = Dense(128, 'relu')(head_model)

# 아웃풋 레이어
>>> head_model = Dense(1, 'sigmoid')(head_model)

# 모두 헤드 모델 뒤에 붙힌다는 뜻

 

위의 두 모델을 하나의 모델로 합친다.

from keras.models import Model

>>> model = Model(inputs= base_model.input, outputs= head_model)

 

컴파일

from keras.optimizers import RMSprop

>>> model.compile(RMSprop(learning_rate=0.0001),loss='binary_crossentropy', metrics=['accuracy'])

>>> train_datagen = ImageDataGenerator(rescale=1/255,
                                        zoom_range=0.2,
                                        width_shift_range=0.2,
                                        height_shift_range=0.2)

>>> train_generator = train_datagen.flow_from_directory(train_dir,
                                                        target_size=(128,128),
                                                        class_mode='binary')

>>> val_datagen = ImageDataGenerator(rescale=1/255)
>>> val_generator = val_datagen.flow_from_directory(val_dir,
                                                    target_size=(128,128),
                                                    class_mode='binary')

 

학습

>>> epoch_history = model.fit(train_generator,
                              epochs=5,
                              validation_data= val_generator)