Commit
Showing 5 changed files with 128 additions and 203 deletions.
@@ -1,208 +1,145 @@
 import React, { useState, useEffect } from 'react';
-import { StyleSheet, View,Image,TouchableOpacity,Dimensions,ActivityIndicator } from 'react-native';
-import { Text,Button, Input,Card,Overlay } from 'react-native-elements';
-import * as Permissions from 'expo-permissions';
-import { Camera, getPermissionsAsync } from 'expo-camera';
-import * as ImageManipulator from 'expo-image-manipulator';
+import { StyleSheet, View,Image } from 'react-native';
+import { Button, Input } from 'react-native-elements';
+import Svg, {Rect} from 'react-native-svg';
 import * as tf from '@tensorflow/tfjs';
-import '@tensorflow/tfjs-react-native';
-import * as mobilenet from '@tensorflow-models/mobilenet';
-import * as knnClassifier from '@tensorflow-models/knn-classifier';
+import { fetch, bundleResourceIO } from '@tensorflow/tfjs-react-native';
+import * as blazeface from '@tensorflow-models/blazeface';
 import * as jpeg from 'jpeg-js'
 export default function App() {
-  const statusList=["Loading Model...","Classifying Image...","Predicting Image..."]
-  const [hasPermission, setHasPermission] = useState(null);
-  const [type, setType] = useState(Camera.Constants.Type.back);
-  const [isTfReady,setIsTfReady] = useState(false)
-  const [mobilenetModel,setMobilenetModel] = useState(null)
-  const [knnClassifierModel,setKnnClassifierModel] = useState(null)
-  const [prediction,setPrediction] = useState({
-    "label":"No Results",
-    "confidence":{}
-  })
-  const [status,setStatus] = useState(statusList[0])
-  const [isLoading,setIsLoading] = useState(true)
-  const [countExamples,setCountExamples] = useState(0)
-  const [countClassExamples,setCountClassExamples] = useState({
-    "Class A":0,
-    "Class B":0,
-    "Class C":0
-  })
-  const classList=[
-    {
-      id:0,
-      name:"Class A"
-    },
-    {
-      id:1,
-      name:"Class B"
-    },
-    {
-      id:2,
-      name:"Class C"
-    },
-  ]
-  //load tensorflow
-  useEffect(() => {
-    async function startup(){
-      if(!isTfReady){
-        console.log("[+] Loading TF Model")
-        setStatus(statusList[0])
-        setIsLoading(true)
-        let { status } = await Camera.requestPermissionsAsync();
-        setHasPermission(status === 'granted');
-        await tf.ready()
-        setIsTfReady(true)
-        setMobilenetModel(await mobilenet.load())
-        setKnnClassifierModel(await knnClassifier.create())
-        setIsLoading(false)
-        console.log("[+] TF Model Loaded")
+  const [imageLink,setImageLink] = useState("https://raw.githubusercontent.com/ohyicong/masksdetection/master/dataset/without_mask/142.jpg")
+  const [isEnabled,setIsEnabled] = useState(true)
+  const [faces,setFaces] = useState([])
+  const [faceDetector,setFaceDetector] = useState("")
+  const [maskDetector,setMaskDetector] = useState("")
+  useEffect(() => {
+    async function loadModel(){
+      console.log("[+] Application started")
+      //Wait for tensorflow module to be ready
+      const tfReady = await tf.ready();
+      console.log("[+] Loading custom mask detection model")
+      //Replace model.json and group1-shard.bin with your own custom model
+      const modelJson = await require("./assets/model/model.json");
+      const modelWeight = await require("./assets/model/group1-shard.bin");
+      const maskDetector = await tf.loadLayersModel(bundleResourceIO(modelJson,modelWeight));
+      console.log("[+] Loading pre-trained face detection model")
+      //Blazeface is a face detection model provided by Google
+      const faceDetector = await blazeface.load();
+      //Assign models to state
+      setMaskDetector(maskDetector)
+      setFaceDetector(faceDetector)
+    }
+    loadModel()
+  }, []);
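A note on the bundleResourceIO call above: tfjs-react-native loads the weights file as a bundled asset, so Metro must be told to treat .bin files as assets. A minimal metro.config.js sketch, assuming a recent managed Expo project (this file is not part of the commit shown here):

// metro.config.js (assumed setup, not shown in this commit)
const { getDefaultConfig } = require('expo/metro-config');

const config = getDefaultConfig(__dirname);
// Let Metro bundle the weights referenced via require("./assets/model/group1-shard.bin")
config.resolver.assetExts.push('bin');

module.exports = config;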
+  function imageToTensor(rawImageData){
+    //Function to convert jpeg image to tensors
+    const TO_UINT8ARRAY = true;
+    const { width, height, data } = jpeg.decode(rawImageData, TO_UINT8ARRAY);
+    // Drop the alpha channel info for mobilenet
+    const buffer = new Uint8Array(width * height * 3);
+    let offset = 0; // offset into original data
+    for (let i = 0; i < buffer.length; i += 3) {
+      buffer[i] = data[offset];
+      buffer[i + 1] = data[offset + 1];
+      buffer[i + 2] = data[offset + 2];
+      offset += 4;
+    }
+    return tf.tensor3d(buffer, [height, width, 3]);
+  }
-      }
-    }
-    startup()
-  },[isTfReady]);
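The new imageToTensor helper strips the alpha channel by hand: jpeg-js decodes to RGBA, so the loop copies three of every four bytes. An equivalent sketch using stock tfjs ops instead of the manual loop (the helper name is hypothetical; it reuses the same jpeg decode as above and is an illustration, not part of the commit):

function imageToTensorViaSlice(rawImageData){
  const { width, height, data } = jpeg.decode(rawImageData, true);
  // Build an RGBA tensor, then slice off the 4th (alpha) channel
  return tf.tensor3d(new Uint8Array(data), [height, width, 4])
           .slice([0, 0, 0], [-1, -1, 3]);
}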
-  //1. collect and label images from camera
-  const collectData = async(className)=>{
-    console.log(`[+] Class ${className} selected`)
-    setStatus(statusList[1])
-    setIsLoading(true)
-    if(this.camera){
-      let photo = await this.camera.takePictureAsync({
-        skipProcessing: true,
-      });
-      //2. resize images into width:224 height:224
-      image = await resizeImage(photo.uri, 224, 224);
-      let imageTensor = base64ImageToTensor(image.base64);
-      //3. get embeddings from mobilenet
-      let embeddings = await mobilenetModel.infer(imageTensor, true);
-      //4. train knn classifier
-      knnClassifierModel.addExample(embeddings,className)
-      let tempCountExamples = countExamples + 1
-      let tempCountClassExamples = countClassExamples
-      tempCountClassExamples[`${className}`] = tempCountClassExamples[`${className}`] + 1
-      setCountExamples(tempCountExamples)
-      setCountClassExamples(tempCountClassExamples)
-      console.log("[+] Class Added")
-    }
-    setIsLoading(false)
-  }
-  //5. predict new images
-  const getPredictions = async() =>{
-    if(this.camera){
-      console.log("[+] Analysing Photo")
-      setStatus(statusList[2])
-      setIsLoading(true)
-      let photo = await this.camera.takePictureAsync({
-        skipProcessing: true,
-      });
-      //resize images into width:224 height:224
-      image = await resizeImage(photo.uri, 224, 224);
-      let imageTensor = base64ImageToTensor(image.base64);
-      //get embeddings from mobilenet
-      let embeddings = await mobilenetModel.infer(imageTensor,true)
-      //predict with knn classifier
-      let prediction = await knnClassifierModel.predictClass(embeddings);
-      console.log(JSON.stringify(prediction))
-      setPrediction(prediction)
-      setIsLoading(false)
-      console.log("[+] Photo Analysed")
-    }
-  }
-  function base64ImageToTensor(base64){
-    //Function to convert jpeg image to tensors
-    const rawImageData = tf.util.encodeString(base64, 'base64');
-    const TO_UINT8ARRAY = true;
-    const { width, height, data } = jpeg.decode(rawImageData, TO_UINT8ARRAY);
-    // Drop the alpha channel info for mobilenet
-    const buffer = new Uint8Array(width * height * 3);
-    let offset = 0; // offset into original data
-    for (let i = 0; i < buffer.length; i += 3) {
-      buffer[i] = data[offset];
-      buffer[i + 1] = data[offset + 1];
-      buffer[i + 2] = data[offset + 2];
-      offset += 4;
+  const getFaces = async() => {
+    try{
+      console.log("[+] Retrieving image from link :"+imageLink)
+      const response = await fetch(imageLink, {}, { isBinary: true });
+      const rawImageData = await response.arrayBuffer();
+      const imageTensor = imageToTensor(rawImageData).resizeBilinear([224,224])
+      const faces = await faceDetector.estimateFaces(imageTensor, false);
+      var tempArray=[]
+      //Loop through the available faces, check if the person is wearing a mask.
+      for (let i=0;i<faces.length;i++){
+        let color = "red"
+        let width = parseInt((faces[i].bottomRight[1] - faces[i].topLeft[1]))
+        let height = parseInt((faces[i].bottomRight[0] - faces[i].topLeft[0]))
+        let faceTensor = imageTensor.slice([parseInt(faces[i].topLeft[1]),parseInt(faces[i].topLeft[0]),0],[width,height,3])
+        faceTensor = faceTensor.resizeBilinear([224,224]).reshape([1,224,224,3])
+        let result = await maskDetector.predict(faceTensor).data()
+        //if result[0]>result[1], the person is wearing a mask
+        if(result[0]>result[1]){
+          color="green"
+        }
+        tempArray.push({
+          id:i,
+          location:faces[i],
+          color:color
+        })
+      }
+      setFaces(tempArray)
+      console.log("[+] Prediction Completed")
+    }catch{
+      console.log("[-] Unable to load image")
+    }
+  }
-    }
-    return tf.tensor3d(buffer, [height, width, 3]);
-  }
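To restate the core of the new getFaces loop: blazeface reports topLeft and bottomRight as [x, y] pixel pairs, while tensors are indexed [row, col], which is why the slice call swaps the indices. A condensed sketch of the per-face check (the helper name is hypothetical, and it assumes the model's two outputs are ordered [mask, no-mask] as the comment above states):

async function isWearingMask(face, imageTensor, maskDetector){
  const [x1, y1] = face.topLeft;
  const [x2, y2] = face.bottomRight;
  // slice takes [y, x] offsets and [rows, cols] sizes
  const crop = imageTensor
    .slice([Math.round(y1), Math.round(x1), 0],
           [Math.round(y2 - y1), Math.round(x2 - x1), 3])
    .resizeBilinear([224, 224])
    .reshape([1, 224, 224, 3]);
  const [maskScore, noMaskScore] = await maskDetector.predict(crop).data();
  return maskScore > noMaskScore;
}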
-  async function resizeImage(imageUrl, width, height){
-    const actions = [{
-      resize: {
-        width,
-        height
-      },
-    }];
-    const saveOptions = {
-      compress: 0.75,
-      format: ImageManipulator.SaveFormat.JPEG,
-      base64: true,
-    };
-    const res = await ImageManipulator.manipulateAsync(imageUrl, actions, saveOptions);
-    return res;
-  }
   return (
     <View style={styles.container}>
-      <Overlay isVisible={isLoading} fullScreen={true} overlayStyle={{alignItems: "center", justifyContent: 'center'}}>
-        <View>
-          <Text style={{marginBottom:10}}>{status}</Text>
-          <ActivityIndicator size="large" color="lightblue" />
-        </View>
-      </Overlay>
+      <Input
+        placeholder="image link"
+        onChangeText={(inputText)=>{
+          console.log(inputText)
+          setImageLink(inputText)
+          const elements = inputText.split(".")
+          if(elements.slice(-1)[0]=="jpg" || elements.slice(-1)[0]=="jpeg"){
+            setIsEnabled(true)
+          }else{
+            setIsEnabled(false)
+          }
+        }}
+        value={imageLink}
+        containerStyle={{height:40,fontSize:10,margin:15}}
+        inputContainerStyle={{borderRadius:10,borderWidth:1,paddingHorizontal:5}}
+        inputStyle={{fontSize:15}}
+      />
-      <Card containerStyle={{width:"100%",marginBottom:10,borderRadius:5}}>
-        <Card.Title style={{fontSize:16}}>Image Classification</Card.Title>
-        <Card.Divider/>
-        <View style={{flexDirection:"row"}}>
-          {classList.map((item, key) => {
-            return (
-              <View style={{flex:1,padding:5}} key={item.id}>
-                <Button
-                  title={`${item.name} (${countClassExamples[item.name]})`}
-                  onPress={()=>{collectData(item.name)}}
-                />
+      <View style={{marginBottom:20}}>
+        <Image
+          style={{width:224,height:224,borderWidth:2,borderColor:"black",resizeMode: "contain"}}
+          source={{
+            uri: imageLink
+          }}
+          PlaceholderContent={<View>No Image Found</View>}
+        />
+        <Svg height="224" width="224" style={{marginTop:-224}}>
+          {
+            faces.map((face)=>{
+              return (
+                <Rect
+                  key={face.id}
+                  x={face.location.topLeft[0]}
+                  y={face.location.topLeft[1]}
+                  width={(face.location.bottomRight[0] - face.location.topLeft[0])}
+                  height={(face.location.bottomRight[1] - face.location.topLeft[1])}
+                  stroke={face.color}
+                  strokeWidth="3"
+                  fill=""
+                />
-              </View>
-            );
-          })}
-        </View>
-      </Card>
-      <View style={{width:224,height:224}}>
-        <Camera
-          style={{ flex: 1 }}
-          type={type}
-          ref={ref => {this.camera = ref; }}>
-        </Camera>
-      </View>
-      <View style={{flexDirection:"row",padding:5}}>
-        <View style={{flex:1,padding:5}}>
-          <Button
-            title="Predict"
-            onPress={()=>{getPredictions()}}
-            disabled={countExamples==0}
-          />
-        </View>
-        <View style={{flex:2,padding:5}}>
-          <Text style={{borderRadius:5,borderWidth:1,padding:10,borderColor:"lightgrey"}}>
-            {prediction.label}
-          </Text>
-        </View>
-      </View>
+              )
+            })
+          }
+        </Svg>
+      </View>
+      <Button
+        title="Predict"
+        onPress={()=>{getFaces()}}
+        disabled={!isEnabled}
+      />
     </View>
   );
 }
-const screenWidth = Math.round(Dimensions.get('window').width);
-const screenHeight = Math.round(Dimensions.get('window').height);
 const styles = StyleSheet.create({
   container: {
-    flexDirection:"column",
-    flex:1,
-    alignItems: "center",
+    flex: 1,
+    backgroundColor: '#fff',
+    alignItems: 'center',
+    justifyContent: 'center',
+    padding:10
   },
 });
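Finally, for readers tracing the render: with the second argument set to false, estimateFaces resolves to plain JavaScript objects rather than tensors, roughly shaped as below (field names per the blazeface README; the sample values are illustrative only). This is why the <Rect> props can read location.topLeft[0] directly as the x coordinate.

const faces = await faceDetector.estimateFaces(imageTensor, false);
// faces[0] is approximately:
// {
//   topLeft:     [57.3, 82.1],   // [x, y] of the box's upper-left corner
//   bottomRight: [401.2, 476.0], // [x, y] of the lower-right corner
//   landmarks:   [[...], ...],   // eye, ear, nose and mouth keypoints
//   probability: [0.998]         // detection confidence
// }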