Skip to content

Commit

Permalink
feat: select model in android super resolution app
Browse files Browse the repository at this point in the history
Signed-off-by: tumuyan <[email protected]>
  • Loading branch information
tumuyan committed Nov 2, 2024
1 parent fc552f6 commit 7692eca
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,4 @@ dmypy.json
# Hub exports
**/*.mlmodel
**/*.tflite
apps/android/SuperResolution/src/main/res/values/models.xml
35 changes: 32 additions & 3 deletions apps/android/SuperResolution/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ android {
}

preBuild.doFirst {
if (!file("./src/main/assets/" + project.properties['superresolution_tfLiteModelAsset']).exists()) {
throw new RuntimeException(missingModelErrorMsg)
}
generateModelList()

for (int i = 1; i <= 2; ++i) {
String filename = "./src/main/assets/images/Sample${i}.jpg"
Expand Down Expand Up @@ -61,3 +59,34 @@ dependencies {
if (System.getProperty("user.dir") != project.rootDir.path) {
throw new RuntimeException("This project should be opened from the `android` directory (parent of SuperResolution directory), NOT the SuperResolution directory.")
}


def generateModelList() {
def assetsDir = file("${projectDir}/src/main/assets")
def outputDir = file("${projectDir}/src/main/res/values")
def outputFile = file("${outputDir}/models.xml")
if (!outputDir.exists()) {
throw new GradleException("res directory not exist: ${outputDir}")
}
if (!assetsDir.exists()) {
throw new GradleException("assets directory not exist: ${assetsDir}")
}

def files = []
if (assetsDir.exists()) {
files = assetsDir.listFiles().findAll { it.name.endsWith('.tflite') || it.name.endsWith('.bin') }.collect { it.name }
}

def xmlContent = """<?xml version="1.0" encoding="utf-8"?>
<resources>
<string-array name="model_files">
"""
files.each { fileName ->
xmlContent += " <item>${fileName}</item>\n"
}
xmlContent += """ </string-array>
</resources>
"""
outputFile.text = xmlContent

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import android.os.Handler;
import android.os.Looper;
import android.provider.MediaStore;
import android.text.TextUtils;
import android.util.Log;
import android.view.View;
import android.widget.AdapterView;
import android.widget.ArrayAdapter;
Expand Down Expand Up @@ -51,7 +53,7 @@ public class MainActivity extends AppCompatActivity {
ImageView selectedImageView;
TextView inferenceTimeView;
TextView predictionTimeView;
Spinner imageSelector;
Spinner imageSelector, modelSelector;
Button predictionButton;
ActivityResultLauncher<Intent> selectImageResultLauncher;
private final String fromGalleryImageSelectorOption = "From Gallery";
Expand All @@ -62,6 +64,8 @@ public class MainActivity extends AppCompatActivity {
"Sample2.jpg",
fromGalleryImageSelectorOption};

private String[] modelSelectorOptions;

// Inference Elements
Bitmap selectedImage = null; // Raw image, not resized
private SuperResolution defaultDelegateUpscaler;
Expand Down Expand Up @@ -91,6 +95,7 @@ protected void onCreate(Bundle savedInstanceState) {
allDelegatesButton = (RadioButton)findViewById(R.id.defaultDelegateRadio);

imageSelector = (Spinner) findViewById((R.id.imageSelector));
modelSelector = (Spinner) findViewById((R.id.modelSelector));
inferenceTimeView = (TextView)findViewById(R.id.inferenceTimeResultText);
predictionTimeView = (TextView)findViewById(R.id.predictionTimeResultText);
predictionButton = (Button)findViewById(R.id.runModelButton);
Expand Down Expand Up @@ -122,6 +127,26 @@ public void onItemSelected(AdapterView<?> parent, View view, int position, long
public void onNothingSelected(AdapterView<?> parent) { }
});

// Setup Model Selector Dropdown
modelSelectorOptions = getResources().getStringArray(R.array.model_files);
ArrayAdapter modelAdapter = new ArrayAdapter(this, android.R.layout.simple_spinner_item, modelSelectorOptions);
modelAdapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item);
modelSelector.setAdapter(modelAdapter);
modelSelector.setOnItemSelectedListener(new AdapterView.OnItemSelectedListener() {
@Override
public void onItemSelected(AdapterView<?> parent, View view, int position, long id) {
// Load selected models from assets
((TextView) view).setTextColor(getResources().getColor(R.color.white));
((TextView) view).setEllipsize(TextUtils.TruncateAt.END);

// Exit the UI thread and instantiate the model in the background.
String modelName = parent.getItemAtPosition(position).toString();
createTFLiteUpscalerAsync(modelName);
}

@Override
public void onNothingSelected(AdapterView<?> parent) { }
});
// Setup Image Selection from Phone Gallery
selectImageResultLauncher = registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
Expand Down Expand Up @@ -155,9 +180,6 @@ public void onNothingSelected(AdapterView<?> parent) { }
// Setup button callback
predictionButton.setOnClickListener((view) -> updatePredictionDataAsync());

// Exit the UI thread and instantiate the model in the background.
createTFLiteUpscalerAsync();

// Enable image selection
enableImageSelector();
enableDelegateSelectionButtons();
Expand All @@ -176,12 +198,15 @@ void setInferenceUIEnabled(boolean enabled) {
predictionButton.setAlpha(0.5f);
imageSelector.setEnabled(false);
imageSelector.setAlpha(0.5f);
modelSelector.setEnabled(false);
modelSelector.setAlpha(0.5f);
cpuOnlyButton.setEnabled(false);
allDelegatesButton.setEnabled(false);
} else if (cpuOnlyUpscaler != null && defaultDelegateUpscaler != null && selectedImage != null) {
predictionButton.setEnabled(true);
predictionButton.setAlpha(1.0f);
enableImageSelector();
enableModelSelector();
enableDelegateSelectionButtons();
}
}
Expand All @@ -193,6 +218,13 @@ void enableImageSelector() {
imageSelector.setEnabled(true);
imageSelector.setAlpha(1.0f);
}
/**
* Enable the model selector UI spinner.
*/
void enableModelSelector() {
modelSelector.setEnabled(true);
modelSelector.setAlpha(1.0f);
}

/**
* Enable the image selector UI radio buttons.
Expand Down Expand Up @@ -327,17 +359,18 @@ void updatePredictionDataAsync() {
* Loading the TF Lite model takes time, so this is done asynchronously to the main UI thread.
* Disables the inference UI during load and reenables it afterwards.
*/
void createTFLiteUpscalerAsync() {
void createTFLiteUpscalerAsync(final String tfLiteModelAsset) {
if (defaultDelegateUpscaler != null || cpuOnlyUpscaler != null) {
throw new RuntimeException("Classifiers were already created");
defaultDelegateUpscaler.close();
cpuOnlyUpscaler.close();
// throw new RuntimeException("Classifiers were already created");
}
setInferenceUIEnabled(false);

// Exit the UI thread and instantiate the model in the background.
backgroundTaskExecutor.execute(() -> {
// Create two upscalers.
// One uses the default set of delegates (can access NPU, GPU, CPU), and the other uses only XNNPack (CPU).
String tfLiteModelAsset = this.getResources().getString(R.string.tfLiteModelAsset);
try {
defaultDelegateUpscaler = new SuperResolution(
this,
Expand All @@ -352,6 +385,7 @@ void createTFLiteUpscalerAsync() {
} catch (IOException | NoSuchAlgorithmException e) {
throw new RuntimeException(e.getMessage());
}
Log.i("createTFLiteUpscalerAsync","model load finish: "+tfLiteModelAsset);

mainLooperHandler.post(() -> setInferenceUIEnabled(true));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,39 @@

</androidx.cardview.widget.CardView>

<LinearLayout
android:orientation="horizontal"
android:id="@+id/modelSelectorCard"
android:layout_width="409sp"
android:layout_height="50sp"
android:background="@color/purple_qcom"
app:layout_constraintBottom_toTopOf="@id/imageSelectorCard"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent">

<TextView
android:id="@+id/modelSelectorText"
android:layout_width="wrap_content"
android:layout_height="match_parent"
android:layout_marginStart="60sp"
android:gravity="center"
android:text="Model"
android:textColor="@color/white"
android:textSize="17sp" />


<Spinner
android:id="@+id/modelSelector"
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_gravity="end"
android:layout_marginEnd="30sp"
android:layout_weight="1"
android:backgroundTint="@color/white"
android:textAlignment="textEnd"
android:theme="@style/spinnerTheme" />

</LinearLayout>
<androidx.cardview.widget.CardView
android:id="@+id/imageSelectorCard"
android:layout_width="409sp"
Expand Down Expand Up @@ -67,7 +100,7 @@
android:layout_marginBottom="8sp"
android:backgroundTint="@color/purple_qcom"
android:orientation="horizontal"
app:layout_constraintBottom_toTopOf="@+id/imageSelectorCard"
app:layout_constraintBottom_toTopOf="@+id/modelSelectorCard"
tools:layout_editor_absoluteX="2sp">

<RadioButton
Expand Down

0 comments on commit 7692eca

Please sign in to comment.