Android ✖︎ 画像認識
🤖 TensorFlow・TensorFlow Lite
歴史
1.0:2017/2
Lite:2017/5
2.0:2019/9
1.0と2.0の中身は結構違います
違い:
TensorFlow 1.0と2.0:MLモデルの構築と訓練用
TensorFlow Lite:エンドデバイス用
TensorFlow Liteは複数のプラットフォームをサポートしている
Linux
📲 Androidでの機械学習
2017/5 Android 8(Oreo)からサポートされる
端末で.tfliteのMLモデルを直接利用できる
MLモデルついて
オフィシャル・サードパーティから提供された訓練済みMLモデルを使う
自分でTensorFlow LiteのMLモデルを訓練する
訓練したTensorFlow MLモデル → TensorFlow Liteへ変換する
🐋 詳細
TensorFlow Lite用のMLモデル:
Kaggleから訓練済みMLモデルをダウンロードする
アフリカ
アジア
ユーロッパ
北米
南アメリカ
オセアニア
モデルをPJに追加する
ライブラリ(CameraX・TensorFlow)を build.gradle.kts に追加する
// build.gradle.kts
val cameraXVersion = "1.3.2"
// CameraX
implementation("androidx.camera:camera-core:$cameraXVersion")
implementation("androidx.camera:camera-camera2:$cameraXVersion")
implementation("androidx.camera:camera-lifecycle:$cameraXVersion")
implementation("androidx.camera:camera-video:$cameraXVersion")
implementation("androidx.camera:camera-view:$cameraXVersion")
implementation("androidx.camera:camera-extensions:$cameraXVersion")
...
// TensorFlow
implementation("org.tensorflow:tensorflow-lite-task-vision:0.4.0")
implementation("org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0")
implementation("org.tensorflow:tensorflow-lite-gpu:2.9.0")
Manifestにカメラ権限を追加する
// AndroidManifest.xml
<uses-feature
android:name="android.hardware.camera"
android:required="false" />
<uses-permission android:name="android.permission.CAMERA" />
追加したMLモデルを使う
認識結果のモデルを定義する
// Classification.kt
data class Classification(
val name: String,
val score: Float
)
認識ロジック
// Classifier.kt
interface Classifier {
fun classify(bitmap: Bitmap, rotation: Int): List<Classification>
}
// ImageClassifier.kt
class ImageClassifier(
private val context: Context,
private val threshold: Float = 0.5f,
private val maxResults: Int = 3
): Classifier {
private var classifier: ImageClassifier? = null
private fun setupClassifier() {
val baseOptions = BaseOptions
.builder()
.setNumThreads(2)
.build()
val options = ImageClassifier.ImageClassifierOptions
.builder()
.setBaseOptions(baseOptions)
.setMaxResults(maxResults)
.setScoreThreshold(threshold)
.build()
// モデルを使う
try {
classifier = ImageClassifier.createFromFileAndOptions(
context,
"landmark_asia.tflite",
options
)
} catch (e: IllegalStateException) {
e.printStackTrace()
}
}
override fun classify(bitmap: Bitmap, rotation: Int): List<Classification> {
if(classifier == null) {
setupClassifier()
}
val imageProcessor = ImageProcessor.Builder().build()
val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
val imageProcessingOptions = ImageProcessingOptions
.builder()
.setOrientation(getOrientationFromRotation(rotation))
.build()
val results = classifier?.classify(tensorImage, imageProcessingOptions)
return results?.flatMap { classification ->
classification.categories.map { category ->
Classification(
name = category.displayName,
score = category.score
)
}
}?.distinctBy { it.name } ?: emptyList()
}
// インプットの角度処理
private fun getOrientationFromRotation(rotation: Int): ImageProcessingOptions.Orientation {
return when(rotation) {
Surface.ROTATION_270 -> ImageProcessingOptions.Orientation.BOTTOM_RIGHT
Surface.ROTATION_90 -> ImageProcessingOptions.Orientation.TOP_LEFT
Surface.ROTATION_180 -> ImageProcessingOptions.Orientation.RIGHT_BOTTOM
else -> ImageProcessingOptions.Orientation.RIGHT_TOP
}
}
}
CameraXの分析ロジック
// ImageAnalyzer.kt
class ImageAnalyzer(
private val classifier: Classifier,
private val onResults: (List<Classification>) -> Unit
): ImageAnalysis.Analyzer {
// 認識結果が変動しすぎないように、counterを追加する
private var frameSkipCounter = 0
override fun analyze(image: ImageProxy) {
if(frameSkipCounter % 60 == 0) {
val rotationDegrees = image.imageInfo.rotationDegrees
// インプット画像のフォマット
val bitmap = image
.toBitmap()
.centerCrop(
desiredWidth = 321,
desiredHeight = 321
)
val results = classifier.classify(bitmap, rotationDegrees)
onResults(results)
}
frameSkipCounter++
image.close()
}
}
// BitmapExtension.kt
fun Bitmap.centerCrop(desiredWidth: Int, desiredHeight: Int): Bitmap {
val xStart = (width - desiredWidth) / 2
val yStart = (height - desiredHeight) / 2
if(xStart < 0 || yStart < 0 || desiredWidth > width || desiredHeight > height) {
throw IllegalArgumentException("Invalid arguments for center cropping")
}
return Bitmap.createBitmap(this, xStart, yStart, desiredWidth, desiredHeight)
}
UI画面
// MainActivity.kt
class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
if(!hasCameraPermission()) {
ActivityCompat.requestPermissions(
this, arrayOf(Manifest.permission.CAMERA), 0
)
}
setContent {
LandmarkRecognitionTensorflowTheme {
// 認識結果リスト
var classifications by remember {
mutableStateOf(emptyList<Classification>())
}
// CameraXのアナライザ
val analyzer = remember {
ImageAnalyzer(
classifier = ImageClassifier(
context = applicationContext
),
onResults = { resultList ->
classifications = resultList
}
)
}
// CameraXのコントローラーを設置
val controller = remember {
LifecycleCameraController(applicationContext).apply {
setEnabledUseCases(CameraController.IMAGE_ANALYSIS)
setImageAnalysisAnalyzer(
ContextCompat.getMainExecutor(applicationContext),
analyzer
)
}
}
Box(
modifier = Modifier
.fillMaxSize()
) {
CameraScreen(controller, Modifier.fillMaxSize())
Column(
modifier = Modifier
.fillMaxWidth()
.align(Alignment.TopCenter)
) {
classifications.forEach { classification ->
val percentage = "%.1f".format(classification.score * 100)
// テキストで認識結果を表示
Text(
text = "${classification.name}\n($percentage%)",
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.primaryContainer)
.padding(8.dp),
textAlign = TextAlign.Center,
fontSize = 20.sp,
color = MaterialTheme.colorScheme.primary
)
}
}
}
}
}
}
// 権限チェック
private fun hasCameraPermission() = ContextCompat.checkSelfPermission(
this, Manifest.permission.CAMERA
) == PackageManager.PERMISSION_GRANTED
}
// CameraScreen.kt
@Composable
fun CameraScreen(
controller: LifecycleCameraController,
modifier: Modifier = Modifier
) {
val lifecycleOwner = LocalLifecycleOwner.current
AndroidView(
factory = { context ->
// CameraXのPreviewViewを使う
PreviewView(context).apply {
this.controller = controller
// ライフサイクルにbind
controller.bindToLifecycle(lifecycleOwner)
}
},
modifier = modifier
)
}
実機で確認する
💭 その他
認識結果の精度を上げたい場合は、自分で訓練したモデルを利用することをおすすめ
世界名勝データセット