-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d3a610c
commit c87c1b1
Showing
7 changed files
with
164 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 0 additions & 106 deletions
106
src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.java
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 0 additions & 40 deletions
40
src/main/java/ee/carlrobert/codegpt/util/DownloadingUtil.java
This file was deleted.
Oops, something went wrong.
100 changes: 100 additions & 0 deletions
100
src/main/kotlin/ee/carlrobert/codegpt/settings/service/llama/form/DownloadModelAction.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package ee.carlrobert.codegpt.settings.service.llama.form | ||
|
||
import com.intellij.openapi.actionSystem.AnAction | ||
import com.intellij.openapi.actionSystem.AnActionEvent | ||
import com.intellij.openapi.diagnostic.Logger | ||
import com.intellij.openapi.progress.ProgressIndicator | ||
import com.intellij.openapi.progress.ProgressManager | ||
import com.intellij.openapi.progress.Task | ||
import com.intellij.openapi.project.Project | ||
import ee.carlrobert.codegpt.CodeGPTBundle | ||
import ee.carlrobert.codegpt.completions.HuggingFaceModel | ||
import ee.carlrobert.codegpt.util.DownloadingUtil | ||
import ee.carlrobert.codegpt.util.file.FileUtil.copyFileWithProgress | ||
import java.io.IOException | ||
import java.util.concurrent.Executors | ||
import java.util.concurrent.ScheduledFuture | ||
import java.util.concurrent.TimeUnit | ||
import java.util.function.Consumer | ||
import javax.swing.DefaultComboBoxModel | ||
|
||
class DownloadModelAction( | ||
private val onDownload: Consumer<ProgressIndicator>, | ||
private val onDownloaded: Runnable, | ||
private val onFailed: Consumer<Exception>, | ||
private val onUpdateProgress: Consumer<String>, | ||
private val comboBoxModel: DefaultComboBoxModel<HuggingFaceModel> | ||
) : AnAction() { | ||
|
||
override fun actionPerformed(e: AnActionEvent) { | ||
ProgressManager.getInstance().run(DownloadBackgroundTask(e.project)) | ||
} | ||
|
||
internal inner class DownloadBackgroundTask(project: Project?) : Task.Backgroundable( | ||
project, | ||
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModel.title"), | ||
true | ||
) { | ||
override fun run(indicator: ProgressIndicator) { | ||
val model = comboBoxModel.selectedItem as HuggingFaceModel | ||
val urls = model.fileURLs | ||
val numberOfFiles = urls.size | ||
var errorOccured = false | ||
for (i in 1..numberOfFiles + 1) { | ||
if (errorOccured || indicator.isCanceled) { | ||
break | ||
} | ||
val executorService = Executors.newSingleThreadScheduledExecutor() | ||
var progressUpdateScheduler: ScheduledFuture<*>? = null | ||
val url = urls[i - 1] | ||
|
||
try { | ||
onDownload.accept(indicator) | ||
|
||
indicator.isIndeterminate = false | ||
indicator.text = String.format( | ||
CodeGPTBundle.get( | ||
"settingsConfigurable.service.llama.progress.downloadingModelIndicator.text" | ||
), | ||
model.fileNames[i - 1] | ||
) | ||
|
||
val fileSize = url.openConnection().contentLengthLong | ||
val bytesRead = longArrayOf(0) | ||
val startTime = System.currentTimeMillis() | ||
|
||
progressUpdateScheduler = executorService.scheduleAtFixedRate( | ||
{ | ||
onUpdateProgress.accept( | ||
DownloadingUtil.getFormattedDownloadProgress( | ||
i, | ||
numberOfFiles, | ||
startTime, | ||
fileSize, | ||
bytesRead[0] | ||
) | ||
) | ||
}, | ||
0, 1, TimeUnit.SECONDS | ||
) | ||
copyFileWithProgress(model.fileNames[i - 1], url, bytesRead, fileSize, indicator) | ||
} catch (ex: IOException) { | ||
LOG.error("Unable to download", ex, url.toString()) | ||
onFailed.accept(ex) | ||
errorOccured = true | ||
} finally { | ||
progressUpdateScheduler?.cancel(true) | ||
executorService.shutdown() | ||
} | ||
} | ||
} | ||
|
||
override fun onSuccess() { | ||
onDownloaded.run() | ||
} | ||
} | ||
|
||
companion object { | ||
private val LOG = Logger.getInstance(DownloadModelAction::class.java) | ||
} | ||
} |
40 changes: 40 additions & 0 deletions
40
src/main/kotlin/ee/carlrobert/codegpt/util/DownloadingUtil.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package ee.carlrobert.codegpt.util | ||
|
||
import ee.carlrobert.codegpt.util.file.FileUtil.convertFileSize | ||
|
||
object DownloadingUtil { | ||
private const val BYTES_IN_MB = 1024 * 1024 | ||
|
||
fun getFormattedDownloadProgress( | ||
fileNumber: Int, fileCount: Int, startTime: Long, | ||
fileSize: Long, bytesRead: Long | ||
): String { | ||
val timeElapsed = System.currentTimeMillis() - startTime | ||
|
||
val speed = (bytesRead.toDouble() / timeElapsed) * 1000 / BYTES_IN_MB | ||
val percent = bytesRead.toDouble() / fileSize * 100 | ||
val downloadedMB = bytesRead.toDouble() / BYTES_IN_MB | ||
val totalMB = fileSize.toDouble() / BYTES_IN_MB | ||
val remainingMB = totalMB - downloadedMB | ||
|
||
return String.format( | ||
"File %d/%d: %s of %s (%.2f%%), Speed: %.2f MB/sec, Time left: %s", | ||
fileNumber, | ||
fileCount, | ||
convertFileSize(downloadedMB.toLong() * BYTES_IN_MB), | ||
convertFileSize(totalMB.toLong() * BYTES_IN_MB), | ||
percent, | ||
speed, | ||
getTimeLeftFormattedString(speed, remainingMB) | ||
) | ||
} | ||
|
||
private fun getTimeLeftFormattedString(speed: Double, remainingMB: Double): String { | ||
val timeLeftSec = if (speed > 0) remainingMB / speed else 0.0 | ||
val hours = (timeLeftSec / 3600).toLong() | ||
val minutes = ((timeLeftSec % 3600) / 60).toLong() | ||
val seconds = (timeLeftSec % 60).toLong() | ||
|
||
return String.format("%02d:%02d:%02d", hours, minutes, seconds) | ||
} | ||
} |