I am working on a simple Create ML project. I trained a custom model on classifying the images of US dollar bill notes. Everything seems good to me and I don't know why the classification label isn't being updated with any value.
Here's what I have so far:
UIImage+Extensions.swift
import Foundation
import UIKit
extension UIImage {
func resizeTo(size: CGSize) -> UIImage? {
UIGraphicsBeginImageContextWithOptions(size, false, 0.0)
self.draw(in: CGRect(origin: CGPoint.zero, size: size))
let resizedImage = UIGraphicsGetImageFromCurrentImageContext()!
UIGraphicsEndImageContext()
return resizedImage
}
func toBuffer() -> CVPixelBuffer? {
let attrs = [kCVPixelBufferCGImageCompatibilityKey: kCFBooleanTrue,
kCVPixelBufferCGBitmapContextCompatibilityKey: kCFBooleanTrue] as CFDictionary
var pixelBuffer : CVPixelBuffer?
let status = CVPixelBufferCreate(kCFAllocatorDefault, Int(self.size.width), Int(self.size.height),
kCVPixelFormatType_32ARGB, attrs, &pixelBuffer)
guard (status == kCVReturnSuccess) else {
return nil
}
CVPixelBufferLockBaseAddress(pixelBuffer!, CVPixelBufferLockFlags(rawValue: 0))
let pixelData = CVPixelBufferGetBaseAddress(pixelBuffer!)
let rgbColorSpace = CGColorSpaceCreateDeviceRGB()
let context = CGContext(data: pixelData, width: Int(self.size.width), height: Int(self.size.height),
bitsPerComponent: 8, bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBuffer!), space: rgbColorSpace,
bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue)
context?.translateBy(x: 0, y: self.size.height)
context?.scaleBy(x: 1.0, y: -1.0)
UIGraphicsPushContext(context!)
self.draw(in: CGRect(x: 0, y: 0, width: self.size.width, height: self.size.height))
UIGraphicsPopContext()
CVPixelBufferUnlockBaseAddress(pixelBuffer!, CVPixelBufferLockFlags(rawValue: 0))
return pixelBuffer
}
}
ImagePicker.swift
import Foundation
import UIKit
extension UIImage {
func resizeTo(size: CGSize) -> UIImage? {
UIGraphicsBeginImageContextWithOptions(size, false, 0.0)
self.draw(in: CGRect(origin: CGPoint.zero, size: size))
let resizedImage = UIGraphicsGetImageFromCurrentImageContext()!
UIGraphicsEndImageContext()
return resizedImage
}
func toBuffer() -> CVPixelBuffer? {
let attrs = [kCVPixelBufferCGImageCompatibilityKey: kCFBooleanTrue,
kCVPixelBufferCGBitmapContextCompatibilityKey: kCFBooleanTrue] as CFDictionary
var pixelBuffer : CVPixelBuffer?
let status = CVPixelBufferCreate(kCFAllocatorDefault, Int(self.size.width), Int(self.size.height),
kCVPixelFormatType_32ARGB, attrs, &pixelBuffer)
guard (status == kCVReturnSuccess) else {
return nil
}
CVPixelBufferLockBaseAddress(pixelBuffer!, CVPixelBufferLockFlags(rawValue: 0))
let pixelData = CVPixelBufferGetBaseAddress(pixelBuffer!)
let rgbColorSpace = CGColorSpaceCreateDeviceRGB()
let context = CGContext(data: pixelData, width: Int(self.size.width), height: Int(self.size.height),
bitsPerComponent: 8, bytesPerRow: CVPixelBufferGetBytesPerRow(pixelBuffer!), space: rgbColorSpace,
bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue)
context?.translateBy(x: 0, y: self.size.height)
context?.scaleBy(x: 1.0, y: -1.0)
UIGraphicsPushContext(context!)
self.draw(in: CGRect(x: 0, y: 0, width: self.size.width, height: self.size.height))
UIGraphicsPopContext()
CVPixelBufferUnlockBaseAddress(pixelBuffer!, CVPixelBufferLockFlags(rawValue: 0))
return pixelBuffer
}
}
ContentView.swift
import SwiftUI
import UIKit
struct ContentView: View {
let photos = ["dollar", "dollar", "dollar"]
@State private var currentIndex: Int = 0
@State private var classificationLabel: String = ""
let model = currency_classifier()
private func performImageClassification() {
let currentImageName = photos[currentIndex]
guard let img = UIImage(named: currentImageName),
let resizedImage = img.resizeTo(size:CGSize(width: 299,
height: 299)),
let buffer = resizedImage.toBuffer() else {
return
}
let output = try? model.prediction(image: buffer)
if let output = output {
// "Banana" : 99.9%
// self.classificationLabel = output.classLabel
let results = output.classLabelProbs.sorted { $0.1 > $1.1 }
let result = results.map { (key, value) in
return "\(key) = \(value * 100)%"
}.joined(separator: "\n")
self.classificationLabel = output.classLabel
}
}
@State private var showSheet: Bool = false
@State private var showImagePicker: Bool = false
@State private var sourceType: UIImagePickerController.SourceType = .camera
@State private var image: UIImage?
var body: some View {
VStack {
if image != nil {
Image(uiImage: image ?? UIImage(named: "placeholder")!)
.resizable()
.frame(width:299, height: 299)
}
Button("Choose Picture") {
self.showSheet = true
}.padding()
.actionSheet(isPresented: $showSheet) {
ActionSheet(title: Text("Select Photo"),
message: Text("Choose"), buttons: [
.default(Text("Photo Library")) {
self.showImagePicker = true
self.sourceType = .photoLibrary
},
.default(Text("Camera")) {
// open camera
self.showImagePicker = true
self.sourceType = .camera
},
.cancel()
])
}
HStack {
Button("Previous") {
if self.currentIndex >= self.photos.count {
self.currentIndex = self.currentIndex - 1
} else {
self.currentIndex = 0
}
}.padding()
.foregroundColor(Color.white)
.background(Color.gray)
.cornerRadius(10)
.frame(width: 100)
Button("Next") {
if self.currentIndex < self.photos.count - 1 {
self.currentIndex = self.currentIndex + 1
} else {
self.currentIndex = 0
}
}
.padding()
.foregroundColor(Color.white)
.frame(width: 100)
.background(Color.gray)
.cornerRadius(10)
}.padding()
Button("Classify") {
// Classify the image here
self.performImageClassification()
}.padding()
.foregroundColor(Color.white)
.background(Color.green)
.cornerRadius(8)
Text(classificationLabel)
.font(.largeTitle)
.padding()
}.sheet(isPresented: $showImagePicker) {
ImagePicker(image: self.$image, isShown: self.$showImagePicker, sourceType: self.sourceType)
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}
Thank you in advance!