Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HuggingFace Model Explorer #27

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ android {
defaultConfig {
applicationId = "io.shubham0204.smollmandroid"
minSdk = 26
targetSdk = 34
targetSdk = 35
versionCode = 1
versionName = "1.0"

Expand Down Expand Up @@ -85,6 +85,7 @@ dependencies {
implementation(libs.androidx.compose.navigation)

implementation(project(":smollm"))
implementation(project(":hf-model-hub-api"))

// Koin: dependency injection
implementation(libs.koin.android)
Expand All @@ -107,6 +108,11 @@ dependencies {
implementation("io.noties:prism4j:2.0.0")
annotationProcessor("io.noties:prism4j-bundler:2.0.0")

// Jetpack Paging3: loading paged data for Compose
val pagingVersion = "3.3.5"
implementation("androidx.paging:paging-runtime:$pagingVersion")
implementation("androidx.paging:paging-compose:$pagingVersion")

testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
Expand Down
68 changes: 68 additions & 0 deletions app/src/main/java/io/shubham0204/smollmandroid/data/HFModelsAPI.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (C) 2024 Shubham Panchal
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.shubham0204.smollmandroid.data

import androidx.paging.Pager
import androidx.paging.PagingConfig
import androidx.paging.PagingSource
import androidx.paging.PagingState
import io.shubham0204.hf_model_hub_api.HFModelInfo
import io.shubham0204.hf_model_hub_api.HFModelSearch
import io.shubham0204.hf_model_hub_api.HFModelTree
import io.shubham0204.hf_model_hub_api.HFModels
import org.koin.core.annotation.Single

@Single
class HFModelsAPI {
suspend fun getModelInfo(modelId: String): HFModelInfo.ModelInfo = HFModels.info.getModelInfo(modelId)

suspend fun getModelTree(modelId: String): List<HFModelTree.HFModelFile> = HFModels.tree.getModelFileTree(modelId)

fun getModelsList(query: String) =
Pager(
config =
PagingConfig(
pageSize = 20,
),
pagingSourceFactory = {
HFModelSearchPagedDataSource(query)
},
).flow

class HFModelSearchPagedDataSource(
private val query: String,
) : PagingSource<Int, HFModelSearch.ModelSearchResult>() {
private val ggufModelFilter = "gguf,conversational"
private val pageSize = 20

override suspend fun load(params: LoadParams<Int>): LoadResult<Int, HFModelSearch.ModelSearchResult> {
val pageNumber = params.key ?: 1
val result = HFModels.search.searchModels(query, "", limit = pageSize, filter = ggufModelFilter)
return LoadResult.Page(
data = result,
prevKey = null,
nextKey = if (result.isEmpty()) null else pageNumber + 1,
)
}

override fun getRefreshKey(state: PagingState<Int, HFModelSearch.ModelSearchResult>): Int? =
state.anchorPosition?.let { anchorPosition ->
state.closestPageToPosition(anchorPosition)?.prevKey?.plus(1)
?: state.closestPageToPosition(anchorPosition)?.nextKey?.minus(1)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,11 @@ import androidx.compose.animation.AnimatedContent
import androidx.compose.animation.core.tween
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.slideInHorizontally
import androidx.compose.animation.slideOutHorizontally
import androidx.compose.animation.togetherWith
import androidx.compose.animation.with
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
Expand All @@ -43,11 +39,9 @@ import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.filled.CalendarToday
import androidx.compose.material.icons.filled.Delete
import androidx.compose.material.icons.filled.Title
import androidx.compose.material3.DividerDefaults
import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
Expand Down Expand Up @@ -77,7 +71,7 @@ import java.io.File

enum class SortOrder {
NAME,
DATE_ADDED
DATE_ADDED,
}

@Composable
Expand All @@ -89,7 +83,7 @@ fun SelectModelsList(
showModelDeleteIcon: Boolean = true,
) {
val context = LocalContext.current
var sortOrder by remember{ mutableStateOf(SortOrder.NAME) }
var sortOrder by remember { mutableStateOf(SortOrder.NAME) }
Dialog(onDismissRequest = onDismissRequest) {
Column(
modifier =
Expand All @@ -102,7 +96,7 @@ fun SelectModelsList(
Spacer(modifier = Modifier.height(4.dp))
SmallLabelText(
"Select a downloaded model from below to use as a 'default' model for this chat.",
textColor = Color.DarkGray
textColor = Color.DarkGray,
)
Spacer(modifier = Modifier.height(16.dp))

Expand All @@ -112,37 +106,40 @@ fun SelectModelsList(
sortOrder,
transitionSpec = {
fadeIn(
animationSpec = tween(100)
animationSpec = tween(100),
) togetherWith fadeOut(animationSpec = tween(100))
},
modifier = Modifier.clickable(
interactionSource = remember { MutableInteractionSource() },
indication = null
) {
sortOrder = when (sortOrder) {
SortOrder.NAME -> SortOrder.DATE_ADDED
SortOrder.DATE_ADDED -> SortOrder.NAME
}
},
modifier =
Modifier.clickable(
interactionSource = remember { MutableInteractionSource() },
indication = null,
) {
sortOrder =
when (sortOrder) {
SortOrder.NAME -> SortOrder.DATE_ADDED
SortOrder.DATE_ADDED -> SortOrder.NAME
}
},
label = "change-sort-order-anim",
) { targetSortOrder: SortOrder ->
Row(
verticalAlignment = Alignment.CenterVertically,
modifier = Modifier
.align(Alignment.End)
.clickable(
interactionSource = remember { MutableInteractionSource() },
indication = null
) {
sortOrder = if (sortOrder == SortOrder.NAME) SortOrder.DATE_ADDED else SortOrder.NAME
}
modifier =
Modifier
.align(Alignment.End)
.clickable(
interactionSource = remember { MutableInteractionSource() },
indication = null,
) {
sortOrder = if (sortOrder == SortOrder.NAME) SortOrder.DATE_ADDED else SortOrder.NAME
},
) {
when (targetSortOrder) {
SortOrder.DATE_ADDED -> {
Image(
imageVector = Icons.Default.Title,
contentDescription = "Sort by Model Name",
colorFilter = ColorFilter.tint(Color.DarkGray)
colorFilter = ColorFilter.tint(Color.DarkGray),
)
Spacer(modifier = Modifier.width(4.dp))
SmallLabelText("Sort by Name")
Expand All @@ -151,7 +148,7 @@ fun SelectModelsList(
Image(
imageVector = Icons.Default.CalendarToday,
contentDescription = "Sort by Date Added",
colorFilter = ColorFilter.tint(Color.DarkGray)
colorFilter = ColorFilter.tint(Color.DarkGray),
)
Spacer(modifier = Modifier.width(4.dp))
SmallLabelText("Sort by Date Added")
Expand All @@ -160,11 +157,10 @@ fun SelectModelsList(
}
}


Spacer(modifier = Modifier.height(8.dp))
LazyColumn(modifier = Modifier.heightIn(max = 300.dp)) {
if (sortOrder == SortOrder.NAME) {
items(modelsList.sortedBy{ it.name }) {
items(modelsList.sortedBy { it.name }) {
ModelListItem(
model = it,
onModelListItemClick,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import androidx.activity.ComponentActivity
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.compose.setContent
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.Arrangement
Expand Down Expand Up @@ -61,21 +63,52 @@ import androidx.compose.ui.graphics.Color
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.unit.dp
import androidx.core.net.toUri
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.compose.rememberNavController
import io.shubham0204.smollmandroid.llm.exampleModelsList
import io.shubham0204.smollmandroid.ui.components.AppProgressDialog
import io.shubham0204.smollmandroid.ui.screens.chat.ChatActivity
import io.shubham0204.smollmandroid.ui.theme.AppAccentColor
import io.shubham0204.smollmandroid.ui.theme.AppFontFamily
import io.shubham0204.smollmandroid.ui.theme.SmolLMAndroidTheme
import org.koin.androidx.compose.koinViewModel
import org.koin.android.ext.android.inject

class DownloadModelActivity : ComponentActivity() {
private var openChatScreen: Boolean = true
private val viewModel: DownloadModelsViewModel by inject()

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContent { DownloadModelScreen() }

setContent {
val navController = rememberNavController()
NavHost(
navController = navController,
startDestination = "download-model",
enterTransition = { fadeIn() },
exitTransition = { fadeOut() },
) {
composable("view-model") {
ViewHFModelScreen(viewModel)
}
composable("hf-model-select") {
HFModelDownloadScreen(
viewModel,
onBackClicked = { navController.navigateUp() },
onModelClick = { modelId ->
viewModel.viewModelId = modelId
navController.navigate("view-model")
},
)
}
composable("download-model") {
DownloadModelScreen(
viewModel,
onHFModelSelectClick = { navController.navigate("hf-model-select") },
)
}
}
}
openChatScreen = intent.extras?.getBoolean("openChatScreen") ?: true
}

Expand All @@ -91,9 +124,10 @@ class DownloadModelActivity : ComponentActivity() {
}

@Composable
private fun DownloadModelScreen() {
val viewModel: DownloadModelsViewModel = koinViewModel()

private fun DownloadModelScreen(
viewModel: DownloadModelsViewModel,
onHFModelSelectClick: () -> Unit,
) {
val launcher =
rememberLauncherForActivityResult(ActivityResultContracts.StartActivityForResult()) { activityResult ->
activityResult.data?.let {
Expand Down Expand Up @@ -134,6 +168,9 @@ class DownloadModelActivity : ComponentActivity() {
ModelsList(viewModel)
Spacer(modifier = Modifier.height(4.dp))
ModelURLInput(viewModel)
OutlinedButton(onClick = onHFModelSelectClick) {
Text("Browse from HuggingFace")
}
Spacer(modifier = Modifier.height(4.dp))
OutlinedButton(
enabled =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ import android.widget.Toast
import androidx.compose.runtime.mutableStateOf
import androidx.core.net.toUri
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import androidx.paging.PagingData
import androidx.paging.cachedIn
import io.shubham0204.hf_model_hub_api.HFModelInfo
import io.shubham0204.hf_model_hub_api.HFModelSearch
import io.shubham0204.hf_model_hub_api.HFModelTree
import io.shubham0204.smollmandroid.data.HFModelsAPI
import io.shubham0204.smollmandroid.data.LLMModel
import io.shubham0204.smollmandroid.data.ModelsDB
import io.shubham0204.smollmandroid.ui.components.hideProgressDialog
Expand All @@ -33,6 +40,9 @@ import io.shubham0204.smollmandroid.ui.components.setProgressDialogTitle
import io.shubham0204.smollmandroid.ui.components.showProgressDialog
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.koin.core.annotation.Single
Expand All @@ -44,10 +54,16 @@ import java.nio.file.Paths
class DownloadModelsViewModel(
val context: Context,
val modelsDB: ModelsDB,
val hfModelsAPI: HFModelsAPI,
) : ViewModel() {
private val _modelInfoAndTree = MutableStateFlow<Pair<HFModelInfo.ModelInfo, List<HFModelTree.HFModelFile>>?>(null)
val modelInfoAndTree: StateFlow<Pair<HFModelInfo.ModelInfo, List<HFModelTree.HFModelFile>>?> = _modelInfoAndTree

val selectedModelState = mutableStateOf<LLMModel?>(null)
val modelUrlState = mutableStateOf("")

var viewModelId: String? = null

private val downloadManager =
context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager

Expand All @@ -71,6 +87,9 @@ class DownloadModelsViewModel(
downloadManager.enqueue(request)
}

fun getModels(query: String): Flow<PagingData<HFModelSearch.ModelSearchResult>> =
hfModelsAPI.getModelsList(query).cachedIn(viewModelScope)

/**
* Given the model file URI, copy the model file to the app's internal directory. Once copied,
* add a new LLMModel entity with modelName=fileName where fileName is the name of the model
Expand Down Expand Up @@ -112,4 +131,10 @@ class DownloadModelsViewModel(
Toast.makeText(context, "Invalid file", Toast.LENGTH_SHORT).show()
}
}

fun fetchModelInfoAndTree(modelId: String) {
CoroutineScope(Dispatchers.IO).launch {
_modelInfoAndTree.value = Pair(hfModelsAPI.getModelInfo(modelId), hfModelsAPI.getModelTree(modelId))
}
}
}
Loading