From 09cfffd32ef003112c197142a19380c3287fa232 Mon Sep 17 00:00:00 2001 From: Shubham Panchal Date: Tue, 7 Jan 2025 11:01:50 +0530 Subject: [PATCH] feat: add HF model explorer #17 --- app/build.gradle.kts | 8 +- .../smollmandroid/data/HFModelsAPI.kt | 68 +++++++++ .../ui/screens/chat/SelectModelsList.kt | 58 ++++---- .../model_download/DownloadModelActivity.kt | 49 ++++++- .../model_download/DownloadModelsViewModel.kt | 25 ++++ .../model_download/HFModelSearchScreen.kt | 111 +++++++++++++++ .../model_download/ViewHFModelScreen.kt | 133 ++++++++++++++++++ build.gradle.kts | 4 +- gradle/libs.versions.toml | 2 + hf-model-hub-api/.gitignore | 1 + hf-model-hub-api/build.gradle.kts | 26 ++++ .../hf_model_hub_api/CustomDateSerializer.kt | 48 +++++++ .../hf_model_hub_api/HFEndpoints.kt | 29 ++++ .../hf_model_hub_api/HFModelInfo.kt | 51 +++++++ .../hf_model_hub_api/HFModelSearch.kt | 99 +++++++++++++ .../hf_model_hub_api/HFModelTree.kt | 42 ++++++ .../shubham0204/hf_model_hub_api/HFModels.kt | 42 ++++++ .../src/test/java/HFModelTests.kt | 98 +++++++++++++ settings.gradle.kts | 3 + 19 files changed, 858 insertions(+), 39 deletions(-) create mode 100644 app/src/main/java/io/shubham0204/smollmandroid/data/HFModelsAPI.kt create mode 100644 app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/HFModelSearchScreen.kt create mode 100644 app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/ViewHFModelScreen.kt create mode 100644 hf-model-hub-api/.gitignore create mode 100644 hf-model-hub-api/build.gradle.kts create mode 100644 hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/CustomDateSerializer.kt create mode 100644 hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFEndpoints.kt create mode 100644 hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelInfo.kt create mode 100644 hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelSearch.kt create mode 100644 hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelTree.kt create mode 100644 hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModels.kt create mode 100644 hf-model-hub-api/src/test/java/HFModelTests.kt diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 4b2a31a..a072743 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -14,7 +14,7 @@ android { defaultConfig { applicationId = "io.shubham0204.smollmandroid" minSdk = 26 - targetSdk = 34 + targetSdk = 35 versionCode = 1 versionName = "1.0" @@ -85,6 +85,7 @@ dependencies { implementation(libs.androidx.compose.navigation) implementation(project(":smollm")) + implementation(project(":hf-model-hub-api")) // Koin: dependency injection implementation(libs.koin.android) @@ -107,6 +108,11 @@ dependencies { implementation("io.noties:prism4j:2.0.0") annotationProcessor("io.noties:prism4j-bundler:2.0.0") + // Jetpack Paging3: loading paged data for Compose + val pagingVersion = "3.3.5" + implementation("androidx.paging:paging-runtime:$pagingVersion") + implementation("androidx.paging:paging-compose:$pagingVersion") + testImplementation(libs.junit) androidTestImplementation(libs.androidx.junit) androidTestImplementation(libs.androidx.espresso.core) diff --git a/app/src/main/java/io/shubham0204/smollmandroid/data/HFModelsAPI.kt b/app/src/main/java/io/shubham0204/smollmandroid/data/HFModelsAPI.kt new file mode 100644 index 0000000..74be14a --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/data/HFModelsAPI.kt @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.data + +import androidx.paging.Pager +import androidx.paging.PagingConfig +import androidx.paging.PagingSource +import androidx.paging.PagingState +import io.shubham0204.hf_model_hub_api.HFModelInfo +import io.shubham0204.hf_model_hub_api.HFModelSearch +import io.shubham0204.hf_model_hub_api.HFModelTree +import io.shubham0204.hf_model_hub_api.HFModels +import org.koin.core.annotation.Single + +@Single +class HFModelsAPI { + suspend fun getModelInfo(modelId: String): HFModelInfo.ModelInfo = HFModels.info.getModelInfo(modelId) + + suspend fun getModelTree(modelId: String): List = HFModels.tree.getModelFileTree(modelId) + + fun getModelsList(query: String) = + Pager( + config = + PagingConfig( + pageSize = 20, + ), + pagingSourceFactory = { + HFModelSearchPagedDataSource(query) + }, + ).flow + + class HFModelSearchPagedDataSource( + private val query: String, + ) : PagingSource() { + private val ggufModelFilter = "gguf,conversational" + private val pageSize = 20 + + override suspend fun load(params: LoadParams): LoadResult { + val pageNumber = params.key ?: 1 + val result = HFModels.search.searchModels(query, "", limit = pageSize, filter = ggufModelFilter) + return LoadResult.Page( + data = result, + prevKey = null, + nextKey = if (result.isEmpty()) null else pageNumber + 1, + ) + } + + override fun getRefreshKey(state: PagingState): Int? = + state.anchorPosition?.let { anchorPosition -> + state.closestPageToPosition(anchorPosition)?.prevKey?.plus(1) + ?: state.closestPageToPosition(anchorPosition)?.nextKey?.minus(1) + } + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/SelectModelsList.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/SelectModelsList.kt index 12a62ed..86968e6 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/SelectModelsList.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/SelectModelsList.kt @@ -21,15 +21,11 @@ import androidx.compose.animation.AnimatedContent import androidx.compose.animation.core.tween import androidx.compose.animation.fadeIn import androidx.compose.animation.fadeOut -import androidx.compose.animation.slideInHorizontally -import androidx.compose.animation.slideOutHorizontally import androidx.compose.animation.togetherWith -import androidx.compose.animation.with import androidx.compose.foundation.Image import androidx.compose.foundation.background import androidx.compose.foundation.clickable import androidx.compose.foundation.interaction.MutableInteractionSource -import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer @@ -43,11 +39,9 @@ import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.CalendarToday import androidx.compose.material.icons.filled.Delete import androidx.compose.material.icons.filled.Title -import androidx.compose.material3.DividerDefaults import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.Icon import androidx.compose.material3.IconButton @@ -77,7 +71,7 @@ import java.io.File enum class SortOrder { NAME, - DATE_ADDED + DATE_ADDED, } @Composable @@ -89,7 +83,7 @@ fun SelectModelsList( showModelDeleteIcon: Boolean = true, ) { val context = LocalContext.current - var sortOrder by remember{ mutableStateOf(SortOrder.NAME) } + var sortOrder by remember { mutableStateOf(SortOrder.NAME) } Dialog(onDismissRequest = onDismissRequest) { Column( modifier = @@ -102,7 +96,7 @@ fun SelectModelsList( Spacer(modifier = Modifier.height(4.dp)) SmallLabelText( "Select a downloaded model from below to use as a 'default' model for this chat.", - textColor = Color.DarkGray + textColor = Color.DarkGray, ) Spacer(modifier = Modifier.height(16.dp)) @@ -112,37 +106,40 @@ fun SelectModelsList( sortOrder, transitionSpec = { fadeIn( - animationSpec = tween(100) + animationSpec = tween(100), ) togetherWith fadeOut(animationSpec = tween(100)) }, - modifier = Modifier.clickable( - interactionSource = remember { MutableInteractionSource() }, - indication = null - ) { - sortOrder = when (sortOrder) { - SortOrder.NAME -> SortOrder.DATE_ADDED - SortOrder.DATE_ADDED -> SortOrder.NAME - } - }, + modifier = + Modifier.clickable( + interactionSource = remember { MutableInteractionSource() }, + indication = null, + ) { + sortOrder = + when (sortOrder) { + SortOrder.NAME -> SortOrder.DATE_ADDED + SortOrder.DATE_ADDED -> SortOrder.NAME + } + }, label = "change-sort-order-anim", ) { targetSortOrder: SortOrder -> Row( verticalAlignment = Alignment.CenterVertically, - modifier = Modifier - .align(Alignment.End) - .clickable( - interactionSource = remember { MutableInteractionSource() }, - indication = null - ) { - sortOrder = if (sortOrder == SortOrder.NAME) SortOrder.DATE_ADDED else SortOrder.NAME - } + modifier = + Modifier + .align(Alignment.End) + .clickable( + interactionSource = remember { MutableInteractionSource() }, + indication = null, + ) { + sortOrder = if (sortOrder == SortOrder.NAME) SortOrder.DATE_ADDED else SortOrder.NAME + }, ) { when (targetSortOrder) { SortOrder.DATE_ADDED -> { Image( imageVector = Icons.Default.Title, contentDescription = "Sort by Model Name", - colorFilter = ColorFilter.tint(Color.DarkGray) + colorFilter = ColorFilter.tint(Color.DarkGray), ) Spacer(modifier = Modifier.width(4.dp)) SmallLabelText("Sort by Name") @@ -151,7 +148,7 @@ fun SelectModelsList( Image( imageVector = Icons.Default.CalendarToday, contentDescription = "Sort by Date Added", - colorFilter = ColorFilter.tint(Color.DarkGray) + colorFilter = ColorFilter.tint(Color.DarkGray), ) Spacer(modifier = Modifier.width(4.dp)) SmallLabelText("Sort by Date Added") @@ -160,11 +157,10 @@ fun SelectModelsList( } } - Spacer(modifier = Modifier.height(8.dp)) LazyColumn(modifier = Modifier.heightIn(max = 300.dp)) { if (sortOrder == SortOrder.NAME) { - items(modelsList.sortedBy{ it.name }) { + items(modelsList.sortedBy { it.name }) { ModelListItem( model = it, onModelListItemClick, diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelActivity.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelActivity.kt index 583b1ac..9fcd44b 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelActivity.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelActivity.kt @@ -24,6 +24,8 @@ import androidx.activity.ComponentActivity import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.setContent import androidx.activity.result.contract.ActivityResultContracts +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut import androidx.compose.foundation.background import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement @@ -61,21 +63,52 @@ import androidx.compose.ui.graphics.Color import androidx.compose.ui.text.TextStyle import androidx.compose.ui.unit.dp import androidx.core.net.toUri +import androidx.navigation.compose.NavHost +import androidx.navigation.compose.composable +import androidx.navigation.compose.rememberNavController import io.shubham0204.smollmandroid.llm.exampleModelsList import io.shubham0204.smollmandroid.ui.components.AppProgressDialog import io.shubham0204.smollmandroid.ui.screens.chat.ChatActivity import io.shubham0204.smollmandroid.ui.theme.AppAccentColor import io.shubham0204.smollmandroid.ui.theme.AppFontFamily import io.shubham0204.smollmandroid.ui.theme.SmolLMAndroidTheme -import org.koin.androidx.compose.koinViewModel +import org.koin.android.ext.android.inject class DownloadModelActivity : ComponentActivity() { private var openChatScreen: Boolean = true + private val viewModel: DownloadModelsViewModel by inject() override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) - setContent { DownloadModelScreen() } - + setContent { + val navController = rememberNavController() + NavHost( + navController = navController, + startDestination = "download-model", + enterTransition = { fadeIn() }, + exitTransition = { fadeOut() }, + ) { + composable("view-model") { + ViewHFModelScreen(viewModel) + } + composable("hf-model-select") { + HFModelDownloadScreen( + viewModel, + onBackClicked = { navController.navigateUp() }, + onModelClick = { modelId -> + viewModel.viewModelId = modelId + navController.navigate("view-model") + }, + ) + } + composable("download-model") { + DownloadModelScreen( + viewModel, + onHFModelSelectClick = { navController.navigate("hf-model-select") }, + ) + } + } + } openChatScreen = intent.extras?.getBoolean("openChatScreen") ?: true } @@ -91,9 +124,10 @@ class DownloadModelActivity : ComponentActivity() { } @Composable - private fun DownloadModelScreen() { - val viewModel: DownloadModelsViewModel = koinViewModel() - + private fun DownloadModelScreen( + viewModel: DownloadModelsViewModel, + onHFModelSelectClick: () -> Unit, + ) { val launcher = rememberLauncherForActivityResult(ActivityResultContracts.StartActivityForResult()) { activityResult -> activityResult.data?.let { @@ -134,6 +168,9 @@ class DownloadModelActivity : ComponentActivity() { ModelsList(viewModel) Spacer(modifier = Modifier.height(4.dp)) ModelURLInput(viewModel) + OutlinedButton(onClick = onHFModelSelectClick) { + Text("Browse from HuggingFace") + } Spacer(modifier = Modifier.height(4.dp)) OutlinedButton( enabled = diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelsViewModel.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelsViewModel.kt index 5215777..b177744 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelsViewModel.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/DownloadModelsViewModel.kt @@ -25,6 +25,13 @@ import android.widget.Toast import androidx.compose.runtime.mutableStateOf import androidx.core.net.toUri import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import androidx.paging.PagingData +import androidx.paging.cachedIn +import io.shubham0204.hf_model_hub_api.HFModelInfo +import io.shubham0204.hf_model_hub_api.HFModelSearch +import io.shubham0204.hf_model_hub_api.HFModelTree +import io.shubham0204.smollmandroid.data.HFModelsAPI import io.shubham0204.smollmandroid.data.LLMModel import io.shubham0204.smollmandroid.data.ModelsDB import io.shubham0204.smollmandroid.ui.components.hideProgressDialog @@ -33,6 +40,9 @@ import io.shubham0204.smollmandroid.ui.components.setProgressDialogTitle import io.shubham0204.smollmandroid.ui.components.showProgressDialog import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.launch import kotlinx.coroutines.withContext import org.koin.core.annotation.Single @@ -44,10 +54,16 @@ import java.nio.file.Paths class DownloadModelsViewModel( val context: Context, val modelsDB: ModelsDB, + val hfModelsAPI: HFModelsAPI, ) : ViewModel() { + private val _modelInfoAndTree = MutableStateFlow>?>(null) + val modelInfoAndTree: StateFlow>?> = _modelInfoAndTree + val selectedModelState = mutableStateOf(null) val modelUrlState = mutableStateOf("") + var viewModelId: String? = null + private val downloadManager = context.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager @@ -71,6 +87,9 @@ class DownloadModelsViewModel( downloadManager.enqueue(request) } + fun getModels(query: String): Flow> = + hfModelsAPI.getModelsList(query).cachedIn(viewModelScope) + /** * Given the model file URI, copy the model file to the app's internal directory. Once copied, * add a new LLMModel entity with modelName=fileName where fileName is the name of the model @@ -112,4 +131,10 @@ class DownloadModelsViewModel( Toast.makeText(context, "Invalid file", Toast.LENGTH_SHORT).show() } } + + fun fetchModelInfoAndTree(modelId: String) { + CoroutineScope(Dispatchers.IO).launch { + _modelInfoAndTree.value = Pair(hfModelsAPI.getModelInfo(modelId), hfModelsAPI.getModelTree(modelId)) + } + } } diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/HFModelSearchScreen.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/HFModelSearchScreen.kt new file mode 100644 index 0000000..c1c3e4b --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/HFModelSearchScreen.kt @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2025 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.ui.screens.model_download + +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.filled.ArrowBack +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Scaffold +import androidx.compose.material3.Text +import androidx.compose.material3.TextField +import androidx.compose.material3.TopAppBar +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Modifier +import androidx.paging.compose.collectAsLazyPagingItems +import io.shubham0204.hf_model_hub_api.HFModelSearch +import io.shubham0204.smollmandroid.ui.components.AppBarTitleText +import io.shubham0204.smollmandroid.ui.theme.SmolLMAndroidTheme + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun HFModelDownloadScreen( + viewModel: DownloadModelsViewModel, + onBackClicked: () -> Unit, + onModelClick: (String) -> Unit, +) { + SmolLMAndroidTheme { + Scaffold( + modifier = Modifier.fillMaxSize(), + topBar = { + TopAppBar( + title = { AppBarTitleText("Browse Models from HuggingFace") }, + navigationIcon = { + IconButton(onClick = { onBackClicked() }) { + Icon( + Icons.AutoMirrored.Filled.ArrowBack, + contentDescription = "Navigate Back", + ) + } + }, + ) + }, + ) { innerPadding -> + Column( + modifier = + Modifier + .padding(innerPadding) + .background(MaterialTheme.colorScheme.background), + ) { + var query by remember { mutableStateOf("") } + TextField( + value = query, + onValueChange = { query = it }, + ) + ModelList(query, viewModel, onModelClick) + } + } + } +} + +@Composable +private fun ModelList( + query: String, + viewModel: DownloadModelsViewModel, + onModelClick: (String) -> Unit, +) { + val models = viewModel.getModels(query).collectAsLazyPagingItems() + LazyColumn { + items(count = models.itemCount) { index -> + models[index]?.modelId?.let { modelId -> + Text( + text = modelId, + modifier = + Modifier.clickable { + onModelClick(modelId) + }, + ) + } + } + } +} + +@Composable +private fun ModelListItem(model: HFModelSearch.ModelSearchResult) { +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/ViewHFModelScreen.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/ViewHFModelScreen.kt new file mode 100644 index 0000000..f549c6a --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/model_download/ViewHFModelScreen.kt @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2025 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.ui.screens.model_download + +import android.text.format.DateUtils +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.AccessTime +import androidx.compose.material.icons.filled.Download +import androidx.compose.material.icons.filled.ThumbUp +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.getValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.vector.ImageVector +import androidx.compose.ui.unit.dp +import androidx.lifecycle.compose.LocalLifecycleOwner +import androidx.lifecycle.compose.collectAsStateWithLifecycle +import io.shubham0204.hf_model_hub_api.HFModelInfo +import io.shubham0204.smollmandroid.ui.components.SmallLabelText +import io.shubham0204.smollmandroid.ui.theme.AppAccentColor +import io.shubham0204.smollmandroid.ui.theme.AppFontFamily +import io.shubham0204.smollmandroid.ui.theme.SmolLMAndroidTheme +import java.time.ZoneId + +@Composable +fun ViewHFModelScreen(viewModel: DownloadModelsViewModel) { + viewModel.viewModelId?.let { modelId -> + SmolLMAndroidTheme { + Column(modifier = Modifier.fillMaxSize().background(Color.White)) { + LaunchedEffect(0) { + viewModel.fetchModelInfoAndTree(modelId) + } + val modelInfoAndTree by viewModel.modelInfoAndTree.collectAsStateWithLifecycle(LocalLifecycleOwner.current) + modelInfoAndTree?.let { modelInfoAndTree -> + val modelInfo = modelInfoAndTree.first + val modelFiles = modelInfoAndTree.second + ModelInfoCard(modelInfo) + } + } + } + } +} + +@Composable +private fun ModelInfoCard(modelInfo: HFModelInfo.ModelInfo) { + Column(modifier = Modifier.padding(16.dp)) { + Text( + text = modelInfo.modelId, + style = MaterialTheme.typography.titleMedium, + fontFamily = AppFontFamily, + modifier = Modifier.fillMaxWidth(), + ) + Row { + ModelInfoIconBubble( + icon = Icons.Default.Download, + contentDescription = "Number of downloads", + text = modelInfo.numDownloads.toString(), + ) + Spacer(modifier = Modifier.width(8.dp)) + ModelInfoIconBubble( + icon = Icons.Default.ThumbUp, + contentDescription = "Number of likes", + text = modelInfo.numLikes.toString(), + ) + ModelInfoIconBubble( + icon = Icons.Default.AccessTime, + contentDescription = "Last updated", + text = + DateUtils + .getRelativeTimeSpanString( + modelInfo.lastModified + .atZone(ZoneId.systemDefault()) + .toInstant() + .toEpochMilli(), + ).toString(), + ) + } + } +} + +@Composable +private fun ModelInfoIconBubble( + icon: ImageVector, + contentDescription: String, + text: String, +) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = + Modifier + .padding(4.dp) + .background(Color.White, RoundedCornerShape(4.dp)) + .padding(4.dp), + ) { + Icon( + modifier = Modifier.size(16.dp), + imageVector = icon, + contentDescription = contentDescription, + tint = AppAccentColor, + ) + Spacer(modifier = Modifier.width(2.dp)) + SmallLabelText(text = text) + } +} diff --git a/build.gradle.kts b/build.gradle.kts index c1d1e6f..4d6bf56 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -11,4 +11,6 @@ plugins { alias(libs.plugins.kotlin.compose) apply false alias(libs.plugins.android.library) apply false id("com.google.devtools.ksp") version "2.0.0-1.0.24" apply false -} \ No newline at end of file + alias(libs.plugins.jetbrains.kotlin.jvm) apply false + kotlin("plugin.serialization") version "2.1.0" apply false +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 8877666..08685fe 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -14,6 +14,7 @@ navigationComposeVersion = "2.8.3" koin = "3.5.6" koinAnnotations = "1.3.1" objectboxGradlePlugin = "4.0.3" +jetbrainsKotlinJvm = "2.0.0" [libraries] @@ -50,4 +51,5 @@ android-application = { id = "com.android.application", version.ref = "agp" } kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } android-library = { id = "com.android.library", version.ref = "agp" } +jetbrains-kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "jetbrainsKotlinJvm" } diff --git a/hf-model-hub-api/.gitignore b/hf-model-hub-api/.gitignore new file mode 100644 index 0000000..42afabf --- /dev/null +++ b/hf-model-hub-api/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/hf-model-hub-api/build.gradle.kts b/hf-model-hub-api/build.gradle.kts new file mode 100644 index 0000000..17f6eb8 --- /dev/null +++ b/hf-model-hub-api/build.gradle.kts @@ -0,0 +1,26 @@ +plugins { + id("java-library") + alias(libs.plugins.jetbrains.kotlin.jvm) + kotlin("plugin.serialization") version "2.1.0" +} + +val ktorVersion = "3.0.2" +val kotlinSerializationVersion = "1.7.3" + +dependencies { + implementation("io.ktor:ktor-client-core:$ktorVersion") + implementation("io.ktor:ktor-client-okhttp:$ktorVersion") + implementation("io.ktor:ktor-client-content-negotiation:$ktorVersion") + implementation("io.ktor:ktor-serialization-kotlinx-json:$ktorVersion") + testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.9.0") + testImplementation(kotlin("test")) +} + +tasks.test { + useJUnitPlatform() +} + +java { + sourceCompatibility = JavaVersion.VERSION_21 + targetCompatibility = JavaVersion.VERSION_21 +} diff --git a/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/CustomDateSerializer.kt b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/CustomDateSerializer.kt new file mode 100644 index 0000000..a782c64 --- /dev/null +++ b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/CustomDateSerializer.kt @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.hf_model_hub_api + +import kotlinx.serialization.KSerializer +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import java.time.LocalDateTime +import java.time.ZoneOffset +import java.time.format.DateTimeFormatter + +/** + * A custom serializer implementation for the java.time.LocalDateTime class + */ +class CustomDateSerializer : KSerializer { + override val descriptor: SerialDescriptor = + PrimitiveSerialDescriptor("LocalDateTime", PrimitiveKind.STRING) + private val utcFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") + + override fun serialize( + encoder: Encoder, + value: LocalDateTime, + ) { + encoder.encodeString(value.atOffset(ZoneOffset.UTC).format(utcFormatter)) + } + + override fun deserialize(decoder: Decoder): LocalDateTime { + val stringRepr = decoder.decodeString() + return LocalDateTime.parse(stringRepr, utcFormatter) + } +} diff --git a/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFEndpoints.kt b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFEndpoints.kt new file mode 100644 index 0000000..c2deef7 --- /dev/null +++ b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFEndpoints.kt @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.hf_model_hub_api + +class HFEndpoints { + companion object { + private const val HF_BASE_ENDPOINT = "https://huggingface.co/api/models" + + val getHFModelsListEndpoint: (() -> String) = { HF_BASE_ENDPOINT } + + val getHFModelTreeEndpoint: ((String) -> String) = { "$HF_BASE_ENDPOINT/$it/tree/main" } + + val getHFModelSpecsEndpoint: ((String) -> String) = { "$HF_BASE_ENDPOINT/$it" } + } +} diff --git a/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelInfo.kt b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelInfo.kt new file mode 100644 index 0000000..98ebfd1 --- /dev/null +++ b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelInfo.kt @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.hf_model_hub_api + +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.request.get +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import java.time.LocalDateTime + +class HFModelInfo( + private val client: HttpClient, +) { + @Serializable + data class ModelInfo( + val _id: String, + val id: String, + val modelId: String, + val author: String, + val private: Boolean, + val disabled: Boolean, + val tags: List, + @SerialName(value = "downloads") val numDownloads: Long, + @SerialName(value = "likes") val numLikes: Long, + @Serializable(with = CustomDateSerializer::class) val lastModified: LocalDateTime, + @Serializable(with = CustomDateSerializer::class) val createdAt: LocalDateTime, + ) + + suspend fun getModelInfo(modelId: String): ModelInfo { + val response = client.get(urlString = HFEndpoints.getHFModelSpecsEndpoint(modelId)) + if (response.status.value != 200) { + throw Exception("Invalid model ID") + } + return response.body() + } +} diff --git a/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelSearch.kt b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelSearch.kt new file mode 100644 index 0000000..f376d27 --- /dev/null +++ b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelSearch.kt @@ -0,0 +1,99 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.hf_model_hub_api + +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.request.get +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import java.time.LocalDateTime + +class HFModelSearch( + private val client: HttpClient, +) { + @Serializable + data class ModelSearchResult( + val _id: String, + val id: String, + @SerialName("likes") val numLikes: Int, + @SerialName("downloads") val numDownloads: Int, + val trendingScore: Int, + @SerialName("private") val isPrivate: Boolean, + val tags: List, + @Serializable(with = CustomDateSerializer::class) val createdAt: LocalDateTime, + val modelId: String, + ) + + enum class ModelSortParam( + val value: String, + ) { + NONE(""), + DOWNLOADS("downloads"), + AUTHOR("author"), + } + + enum class ModelSearchDirection( + val value: Int, + ) { + ASCENDING(1), + DESCENDING(-1), + } + + private var pageURL = HFEndpoints.getHFModelsListEndpoint() + + suspend fun searchModels( + query: String, + author: String, + filter: String, + sort: ModelSortParam = ModelSortParam.NONE, + direction: ModelSearchDirection = ModelSearchDirection.DESCENDING, + limit: Int, + full: Boolean = true, + config: Boolean = true, + ): List { + val response = + if (pageURL == HFEndpoints.getHFModelsListEndpoint()) { + client + .get(HFEndpoints.getHFModelsListEndpoint()) { + url { + parameters.append("search", query) + parameters.append("author", author) + parameters.append("filter", filter) + parameters.append("sort", sort.value) + parameters.append("direction", direction.value.toString()) + parameters.append("limit", limit.toString()) + parameters.append("full", full.toString()) + parameters.append("config", config.toString()) + } + } + } else { + client.get(pageURL) + } + val linkHeader = response.headers["Link"] ?: return emptyList() + pageURL = parseLinkHeader(linkHeader)["next"] ?: return emptyList() + return response.body() + } + + private fun parseLinkHeader(header: String): Map { + val regex = """<(https?:\/\/[^>]+)>;\s+rel="([^"]+)"""".toRegex() + return regex.findAll(header).associate { matchResult -> + val (url, rel) = matchResult.destructured + rel to url + } + } +} diff --git a/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelTree.kt b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelTree.kt new file mode 100644 index 0000000..c3b7857 --- /dev/null +++ b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModelTree.kt @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.hf_model_hub_api + +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.request.get +import kotlinx.serialization.Serializable + +class HFModelTree( + private val client: HttpClient, +) { + @Serializable + data class HFModelFile( + val type: String, + val oid: String, + val size: Long, + val path: String, + ) + + suspend fun getModelFileTree(modelId: String): List { + val response = client.get(urlString = HFEndpoints.getHFModelTreeEndpoint(modelId)) + if (response.status.value != 200) { + throw Exception("Invalid model ID") + } + return response.body() + } +} diff --git a/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModels.kt b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModels.kt new file mode 100644 index 0000000..a7e66ed --- /dev/null +++ b/hf-model-hub-api/src/main/java/io/shubham0204/hf_model_hub_api/HFModels.kt @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.hf_model_hub_api + +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.serialization.kotlinx.json.json +import kotlinx.serialization.json.Json + +class HFModels { + companion object { + private val client: HttpClient = + HttpClient(OkHttp) { + install(ContentNegotiation) { + json( + Json { + ignoreUnknownKeys = true + }, + ) + } + } + + val info = HFModelInfo(client) + val tree = HFModelTree(client) + val search = HFModelSearch(client) + } +} diff --git a/hf-model-hub-api/src/test/java/HFModelTests.kt b/hf-model-hub-api/src/test/java/HFModelTests.kt new file mode 100644 index 0000000..c5d5708 --- /dev/null +++ b/hf-model-hub-api/src/test/java/HFModelTests.kt @@ -0,0 +1,98 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import io.ktor.client.HttpClient +import io.ktor.client.engine.okhttp.OkHttp +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.serialization.kotlinx.json.json +import io.shubham0204.hf_model_hub_api.HFModelInfo +import io.shubham0204.hf_model_hub_api.HFModelSearch +import io.shubham0204.hf_model_hub_api.HFModelTree +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.Json +import org.junit.jupiter.api.assertThrows +import kotlin.test.Test +import kotlin.test.assertEquals + +class HFModelTests { + private val huggingFaceModelId = "QuantFactory/BharatGPT-3B-Indic-GGUF" + private val huggingFaceModelOrg = "QuantFactory" + private val invalidHuggingFaceModelId = "shubham0204/BharatGPT-3B-Indic-GGUF" + private val client: HttpClient = + HttpClient(OkHttp) { + install(ContentNegotiation) { + json( + Json { + // ignore unknown keys from the response + isLenient = true + ignoreUnknownKeys = true + }, + ) + } + } + + @Test + fun testModelInfo_works() = + runTest { + val modelInfo = HFModelInfo(client).getModelInfo(huggingFaceModelId) + assertEquals(huggingFaceModelId, modelInfo.modelId) + assertEquals(huggingFaceModelOrg, modelInfo.author) + assert(modelInfo.tags.isNotEmpty()) + assert(modelInfo.numDownloads > 0) + assert(modelInfo.numLikes > 0) + } + + @Test + fun testModelInfoInvalidID_throws() = + runTest { + val exception = + assertThrows { + HFModelInfo(client).getModelInfo(invalidHuggingFaceModelId) + } + assertEquals("Invalid model ID", exception.message) + } + + @Test + fun testModelTree_works() = + runTest { + val modelTree = HFModelTree(client).getModelFileTree(huggingFaceModelId) + assert(modelTree.isNotEmpty()) + } + + @Test + fun testModelTreeInvalidID_throws() = + runTest { + val exception = + assertThrows { + HFModelTree(client).getModelFileTree(invalidHuggingFaceModelId) + } + assertEquals("Invalid model ID", exception.message) + } + + @Test + fun testModelSearch_works() = + runTest { + val results = + HFModelSearch(client).searchModels( + query = "gguf", + author = "QuantFactory", + filter = "conversational", + limit = 5, + ) + assert(results.isNotEmpty()) + assert(results.size <= 5) + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index ba1424a..5fecc55 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,5 +1,8 @@ import java.net.URI +include(":hf-model-hub-api") + + pluginManagement { repositories { google {