UI updates #1

Open · wants to merge 12 commits into master
Binary file added android/app/.DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion android/app/build.gradle
@@ -39,6 +39,7 @@ android {
kotlinOptions {
jvmTarget = rootProject.ext.java_version
}
+ namespace 'org.tensorflow.lite.examples.audio'
}

// import DownloadModels task
@@ -63,7 +64,6 @@ dependencies {
implementation 'androidx.localbroadcastmanager:localbroadcastmanager:1.0.0'
implementation 'com.google.android.material:material:1.4.0'

// Navigation library
def nav_version = "2.3.5"
implementation "androidx.navigation:navigation-fragment-ktx:$nav_version"
Binary file added android/app/src/main/.DS_Store
Binary file not shown.
7 changes: 4 additions & 3 deletions android/app/src/main/AndroidManifest.xml
@@ -14,8 +14,7 @@
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
- <manifest xmlns:android="http://schemas.android.com/apk/res/android"
-     package="org.tensorflow.lite.examples.audio">
+ <manifest xmlns:android="http://schemas.android.com/apk/res/android">

<uses-permission android:name="android.permission.RECORD_AUDIO" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
@@ -24,7 +23,6 @@
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
- android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:taskAffinity=""
android:theme="@style/AppTheme">
@@ -38,6 +36,9 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
+ <meta-data
+     android:name="preloaded_fonts"
+     android:resource="@array/preloaded_fonts" />
</application>

</manifest>
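
Two side notes on this manifest hunk. Dropping the package attribute pairs with the namespace entry added in build.gradle above: from AGP 7.3 the application namespace is declared in Gradle, and AGP 8 removes the manifest attribute entirely. The new preloaded_fonts meta-data must point at a resource array, normally generated by Android Studio's downloadable-fonts flow; a minimal sketch of res/values/preloaded_fonts.xml, with a hypothetical font name since the real entries are not part of this diff:

<?xml version="1.0" encoding="utf-8"?>
<resources>
    <!-- Hypothetical entry; the actual font resources are not shown in this PR. -->
    <array name="preloaded_fonts" translatable="false">
        <item>@font/roboto_medium</item>
    </array>
</resources>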
android/app/src/main/java/org/tensorflow/lite/examples/audio/AudioClassificationHelper.kt
@@ -16,6 +16,7 @@

package org.tensorflow.lite.examples.audio

+ import android.annotation.SuppressLint
import android.content.Context
import android.media.AudioRecord
import android.os.SystemClock
@@ -53,14 +54,14 @@ import org.tensorflow.lite.task.core.BaseOptions


class AudioClassificationHelper(
val context: Context,
val listener: AudioClassificationListener,
var currentModel: String = YAMNET_MODEL,
var classificationThreshold: Float = DISPLAY_THRESHOLD,
var overlap: Float = DEFAULT_OVERLAP_VALUE,
var numOfResults: Int = DEFAULT_NUM_OF_RESULTS,
var currentDelegate: Int = 0,
var numThreads: Int = 2
) {
private var interpreter: Interpreter? = null
private lateinit var classifier: AudioClassifier
@@ -86,9 +87,10 @@ class AudioClassificationHelper(
7 to "Knocking",
8 to "Siren",
9 to "Water Running"
)

private var rmsThreshold = 0.01f
+ private var isTraining = false

private val classifyRunnable = Runnable {
classifyAudio()
@@ -123,6 +125,7 @@
return rms
}

+ @SuppressLint("MissingPermission")
fun startAudioClassification() {
val format = TensorAudio.TensorAudioFormat.builder()
.setChannels(1)
@@ -146,9 +149,9 @@
// For example, YAMNET expects 0.975 second length recordings.
// The value must be in milliseconds so that converting to the required Long does not drop decimals.
// val lengthInMilliSeconds = ((classifier.requiredInputBufferSize * 1.0f) /
// classifier.requiredTensorAudioFormat.sampleRate) * 1000

val lengthInMilliSeconds = 1000 // one second

// val interval = (lengthInMilliSeconds * (1 - overlap)).toLong()
val interval = (1000).toLong()
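
As a worked example of the commented-out interval math above, under YAMNet's numbers (a 15600-sample input buffer at a 16 kHz sample rate, both inferred from the 0.975-second comment rather than stated in this diff):

// Hedged sketch: recovers the interval the commented-out lines would compute.
val requiredSamples = 15600
val sampleRate = 16000
val lengthMs = (requiredSamples * 1.0f / sampleRate) * 1000       // 975.0 ms
val interval = (lengthMs * (1 - DEFAULT_OVERLAP_VALUE)).toLong()  // 487 ms at 0.5 overlap

The hard-coded 1000 ms length and 1000 ms interval in this hunk simply round that up to whole seconds.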
@@ -163,13 +166,22 @@
private fun classifyAudio() {
tensorAudio.load(recorder) // buffer shape: [1, 15600] samples (0.975 s × 16 kHz)

+ // Log the raw audio samples pulled from the recorder, for debugging.
+ val audioData = tensorAudio.tensorBuffer.floatArray
+ for (data in audioData) {
+     Log.d("Data", "Audio Data: $data")
+ }
synchronized(lock) {
val rms = calculateRMS(tensorAudio.getTensorBuffer().getFloatArray())
// Log.d("AudioClassificationHelper", "rms: " + rms)
if (rms > rmsThreshold) { // TODO: define a principled threshold for detecting that sound is present
val sr = recorder.getSampleRate()
var inferenceTime = SystemClock.uptimeMillis()

val inputs: MutableMap<String, Any> = HashMap()
inputs["x"] = tensorAudio.getTensorBuffer().buffer

@@ -196,12 +208,12 @@

inferenceTime = SystemClock.uptimeMillis() - inferenceTime
listener.onResult(tensorAudio.getTensorBuffer().getFloatArray(), floatArrayOf(lbl[0].toFloat()), id2lblMap[lbl[0].toInt()].toString(), class_probs, inferenceTime)
}
else { // no sound
listener.onResult(tensorAudio.getTensorBuffer().getFloatArray(), floatArrayOf(1f), "silence", floatArrayOf(0f), 0)
}
}

}

fun stopAudioClassification() {
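
classifyAudio gates inference on calculateRMS, whose body is collapsed out of this diff. A minimal sketch of a conventional root-mean-square over the tensor's float samples, assuming the hidden hunk does nothing more:

import kotlin.math.sqrt

// Hedged reconstruction: the real calculateRMS body is not visible in this PR.
private fun calculateRMS(samples: FloatArray): Float {
    if (samples.isEmpty()) return 0f
    var sumOfSquares = 0.0
    for (s in samples) sumOfSquares += (s * s).toDouble()
    return sqrt(sumOfSquares / samples.size).toFloat()
}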
@@ -216,7 +228,7 @@
categoricalLabel[label.get(0).toInt()] = 1f
return categoricalLabel
}
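
The hunk above shows only the tail of the one-hot helper; a sketch of the complete function, where the signature and the FloatArray(10) allocation are assumptions (only the last two statements appear in the diff), with the class count matching the ten-entry id2lblMap:

// Hedged reconstruction of toCategorical; label is expected to carry the
// class index in its first element.
private fun toCategorical(label: FloatArray): FloatArray {
    val categoricalLabel = FloatArray(10)        // one slot per class, all 0f
    categoricalLabel[label.get(0).toInt()] = 1f  // mark the labeled class
    return categoricalLabel
}

For example, a label of floatArrayOf(3f) yields [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].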


// Add data to the data buffer
fun collectSample(audio: FloatArray, label: FloatArray) {
@@ -239,6 +251,11 @@
}
}

+ // Report whether the model is currently fine-tuning.
+ fun isModelTraining(): Boolean {
+     return isTraining
+ }

// Running the interpreter's signature function
private fun trainOneStep(
x: MutableList<FloatArray>, y: MutableList<FloatArray>
@@ -265,46 +282,13 @@
)
)
}

Log.d("AudioClassificationHelper","Start fine-tuning")

- // trainingExecutor = Executors.newSingleThreadExecutor()
-
- // trainingExecutor?.execute {
- // synchronized(lock) {
- // var avgLoss: Float
- // var numIterations = 0
- // while (trainingExecutor!!.isShutdown == false) {
- // var totalLoss = 0f
- // // training
- // dataBuffer.shuffle() // might not need to do this
-
- // val trainingBatchAudios =
- // MutableList(BATCH_SIZE) { FloatArray(44100) }
-
- // val trainingBatchLabels =
- // MutableList(BATCH_SIZE) { FloatArray(10) }
-
- // dataBuffer.forEachIndexed { index, sample ->
- // trainingBatchAudios[index] = sample.audio
- // trainingBatchLabels[index] = sample.label
- // }
-
- // val loss = trainOneStep(trainingBatchAudios,trainingBatchLabels)
- // numIterations++
-
- // totalLoss += loss
-
- // avgLoss = totalLoss / numIterations
- // handler.post {
- // listener.onTrainResult(avgLoss, numIterations)
- // }
- // }
- // }
- // }
- var avgLoss: Float
+ isTraining = true
+ var meanLoss: Float = 1000f
var numIterations = 0
- while (numIterations < 5) {
+ while (numIterations < 10 && meanLoss > 1) {
var totalLoss = 0f
// training
dataBuffer.shuffle() // might not need to do this
Expand All @@ -325,15 +309,16 @@ class AudioClassificationHelper(

val loss = trainOneStep(trainingBatchAudios,trainingBatchLabels)
numIterations++

- totalLoss += loss
- avgLoss = totalLoss / numIterations

+ meanLoss = loss
handler.post {
listener.onTrainResult(loss, numIterations)
}
}
}
+ dataBuffer.clear()
+ isTraining = false
val outfile: File = File(context.getFilesDir(), "sc_model.tflite")
Log.d("AudioClassificationHelper",outfile.getAbsolutePath())
val inputs: MutableMap<String, Any> = HashMap()
@@ -366,7 +351,7 @@
// Log.d("AudioClassificationHelper",outfile.lastModified().toString())
}



/* End of on-edge training */

@@ -378,8 +363,8 @@
const val DEFAULT_OVERLAP_VALUE = 0.5f
const val YAMNET_MODEL = "yamnet.tflite"
const val SPEECH_COMMAND_MODEL = "speech.tflite"
- const val BATCH_SIZE = 20
+ const val BATCH_SIZE = 10
}

data class TrainingSample(val audio: FloatArray, val label: FloatArray)
}
}
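
The fine-tuning path drives the TFLite interpreter through named SavedModel signatures (the inputs["x"] map visible in classifyAudio hints at the same mechanism). A minimal sketch of one training step via Interpreter.runSignature, where the "train", "y", and "loss" names are assumptions not confirmed by this diff:

import java.nio.FloatBuffer
import org.tensorflow.lite.Interpreter

// Hedged sketch of what trainOneStep plausibly wraps; signature and tensor
// names ("train", "x", "y", "loss") are assumed, not taken from this PR.
private fun runTrainSignature(
    interpreter: Interpreter,
    batchAudio: Array<FloatArray>,
    batchLabels: Array<FloatArray>
): Float {
    val inputs: MutableMap<String, Any> = HashMap()
    inputs["x"] = batchAudio
    inputs["y"] = batchLabels

    val loss = FloatBuffer.allocate(1)
    val outputs: MutableMap<String, Any> = HashMap()
    outputs["loss"] = loss

    // Executes the named training signature exported in the .tflite model.
    interpreter.runSignature(inputs, outputs, "train")
    return loss.get(0)
}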