Skip to content

Commit

Permalink
Merge pull request #3201 from Juude/master
Browse files Browse the repository at this point in the history
optimized for deepseek r1
  • Loading branch information
jxt1234 authored Feb 6, 2025
2 parents b100f52 + 4d5cf7e commit 6a5abf5
Show file tree
Hide file tree
Showing 25 changed files with 505 additions and 47 deletions.
6 changes: 6 additions & 0 deletions project/android/apps/MnnLlmApp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ This is our full multimodal language model (LLM) Android app
```

# Releases
## Version 0.2
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_llm_app_debug_0_2_0.apk)
+ Optimized for DeepSeek R1 1.5B
+ Added support for Markdown
+ Resolved several bugs and improved stability

## Version 0.1
+ Click here to [download](https://meta.alicdn.com/data/mnn/mnn_llm_app_debug_0_1.apk)
+ This is our first publicly released version; you can:
Expand Down
6 changes: 6 additions & 0 deletions project/android/apps/MnnLlmApp/README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@
```

# Releases

## Version 0.2
+ 点击这里[下载](https://meta.alicdn.com/data/mnn/mnn_llm_app_debug_0_2_0.apk)
+ 针对 DeepSeek R1 1.5B 进行了优化
+ 新增支持 Markdown 格式
+ 修复了一些已知问题

## Version 0.1
+ 点击这里[下载](https://meta.alicdn.com/data/mnn/mnn_llm_app_debug_0_1.apk)
+ 这是我们的首个公开发布版本,您可以:
Expand Down
5 changes: 3 additions & 2 deletions project/android/apps/MnnLlmApp/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ android {
applicationId "com.alibaba.mnnllm.android"
minSdk 26
targetSdk 34
versionCode 1
versionName "0.1"
versionCode 2
versionName "0.2"

testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
externalNativeBuild {
Expand Down Expand Up @@ -82,4 +82,5 @@ dependencies {
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.2.1'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.6.1'
implementation "io.noties.markwon:core:4.6.2"
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@
android:foregroundServiceType="dataSync"
/>

<activity android:name=".chat.SelectTextActivity"
android:theme="@style/Theme.MnnLlmApp"
android:exported="false">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.DEFAULT" />
</intent-filter>
</activity>

<activity android:name=".MainActivity"
android:exported="true"
android:configChanges="orientation|screenSize">
Expand Down
27 changes: 20 additions & 7 deletions project/android/apps/MnnLlmApp/app/src/main/cpp/llm_mnn_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MNN_PUBLIC LlmStreamBuffer : public std::streambuf {
using PromptItem = std::pair<std::string, std::string>;
static std::vector<PromptItem> history{};
static bool stop_requested = false;
static bool is_r1 = false;

int utf8CharLength(unsigned char byte) {
if ((byte & 0x80) == 0) return 1;
Expand Down Expand Up @@ -81,6 +82,13 @@ class Utf8StreamProcessor {
std::function<void(const std::string&)> callback;
};

/**
 * Formats a user message for submission to the model.
 *
 * When the file-level flag `is_r1` is set, the message is wrapped in the R1 chat
 * template: "<|User|>" + message + "<|Assistant|><think>\n". Otherwise the input
 * pointer is returned unchanged.
 *
 * @param user_content NUL-terminated user message. In R1 mode a null pointer is
 *                     treated as an empty message.
 * @return In R1 mode, a pointer to a thread-local buffer that stays valid until the
 *         next call to getUserString on the same thread; callers must copy the
 *         string before calling again. In non-R1 mode, `user_content` itself.
 */
const char* getUserString(const char* user_content) {
    if (!is_r1) {
        return user_content;
    }
    // Returning c_str() of a temporary std::string would hand the caller a dangling
    // pointer, because the temporary is destroyed when the return statement finishes.
    // Thread-local storage keeps the formatted prompt alive after the function returns
    // while leaving the const char* signature intact.
    thread_local std::string formatted_prompt;
    formatted_prompt.assign("<|User|>");
    if (user_content != nullptr) {
        formatted_prompt.append(user_content);
    }
    formatted_prompt.append("<|Assistant|><think>\n");
    return formatted_prompt.c_str();
}

extern "C" {

Expand All @@ -96,7 +104,9 @@ JNIEXPORT void JNI_OnUnload(JavaVM* vm, void* reserved) {
JNIEXPORT jlong JNICALL Java_com_alibaba_mnnllm_android_ChatSession_initNative(JNIEnv* env, jobject thiz, jstring modelDir,
jboolean use_tmp_path,
jobject chat_history,
jboolean is_diffusion) {
jboolean is_diffusion,
jboolean r1) {
is_r1 = r1;
const char* model_dir = env->GetStringUTFChars(modelDir, 0);
MNN_DEBUG("createLLM BeginLoad %s", model_dir);
if (is_diffusion) {
Expand All @@ -108,12 +118,12 @@ JNIEXPORT jlong JNICALL Java_com_alibaba_mnnllm_android_ChatSession_initNative(J
auto model_dir_str = std::string(model_dir);
std::string model_dir_parent = model_dir_str.substr(0, model_dir_str.find_last_of('/'));
std::string temp_dir = model_dir_parent + R"(/tmp")";
auto extra_config = R"({"tmp_path":")" + temp_dir + R"(,"reuse_kv":true, "backend_type":"opencl"})";
auto extra_config = R"({"tmp_path":")" + temp_dir + R"(,"reuse_kv":true, "backend_type":"cpu"})";
MNN_DEBUG("extra_config: %s", extra_config.c_str());
llm->set_config(temp_dir);
}
history.clear();
history.emplace_back("system", "You are a helpful assistant.");
history.emplace_back("system", is_r1 ? "<|begin_of_sentence|>You are a helpful assistant." : "You are a helpful assistant.");
if (chat_history != nullptr) {
jclass listClass = env->GetObjectClass(chat_history);
jmethodID sizeMethod = env->GetMethodID(listClass, "size", "()I");
Expand All @@ -122,7 +132,7 @@ JNIEXPORT jlong JNICALL Java_com_alibaba_mnnllm_android_ChatSession_initNative(J
for (jint i = 0; i < listSize; i++) {
jobject element = env->CallObjectMethod(chat_history, getMethod, i);
const char *elementCStr = env->GetStringUTFChars((jstring)element, nullptr);
history.emplace_back(i == 0 ? "user" : "assistant",elementCStr);
history.emplace_back(i == 0 ? "user" : "assistant",i == 0 ? getUserString(elementCStr) : elementCStr);
env->ReleaseStringUTFChars((jstring)element, elementCStr);
env->DeleteLocalRef(element);
}
Expand All @@ -132,6 +142,8 @@ JNIEXPORT jlong JNICALL Java_com_alibaba_mnnllm_android_ChatSession_initNative(J
return reinterpret_cast<jlong>(llm);
}



JNIEXPORT jobject JNICALL Java_com_alibaba_mnnllm_android_ChatSession_submitNative(JNIEnv* env, jobject thiz,
jlong llmPtr, jstring inputStr,jboolean keepHistory,
jobject progressListener) {
Expand Down Expand Up @@ -161,18 +173,19 @@ JNIEXPORT jobject JNICALL Java_com_alibaba_mnnllm_android_ChatSession_submitNati
}
if (progressListener && onProgressMethod) {
jstring javaString = is_eop ? nullptr : env->NewStringUTF(utf8Char.c_str());
stop_requested = is_eop || env->CallBooleanMethod(progressListener, onProgressMethod, javaString);
jboolean user_stop_requested = env->CallBooleanMethod(progressListener, onProgressMethod, javaString);
stop_requested = is_eop || user_stop_requested;
env->DeleteLocalRef(javaString);
}
});
LlmStreamBuffer stream_buffer{[&processor](const char* str, size_t len){
processor.processStream(str, len);
}};
std::ostream output_ostream(&stream_buffer);
history.emplace_back("user", input_str);
history.emplace_back("user", getUserString(input_str));
MNN_DEBUG("submitNative history count %zu", history.size());
llm->response(history, &output_ostream, "<eop>", 1);
while (!stop_requested && llm->getState().gen_seq_len_ < 512) {
while (!stop_requested) {
llm->generate(1);
}
auto& state = llm->getState();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,28 @@ public class ChatService{

private static ChatService instance;

public synchronized ChatSession createSession(String modelDir,
public synchronized ChatSession createSession(String modelName,
String modelDir,
boolean useTmpPath,
String sessionId,
List<ChatDataItem> chatDataItemList) {
if (TextUtils.isEmpty(sessionId)) {
sessionId = String.valueOf(System.currentTimeMillis());
}
ChatSession session = new ChatSession(sessionId, modelDir, useTmpPath, chatDataItemList);
ChatSession session = new ChatSession(modelName, sessionId, modelDir, useTmpPath, chatDataItemList);
sessionMap.put(sessionId, session);
return session;
}

public synchronized ChatSession createDiffusionSession(
String modelName,
String modelDir,
String sessionId,
List<ChatDataItem> chatDataItemList) {
if (TextUtils.isEmpty(sessionId)) {
sessionId = String.valueOf(System.currentTimeMillis());
}
ChatSession session = new ChatSession(sessionId, modelDir, false, chatDataItemList, true);
ChatSession session = new ChatSession(modelName, sessionId, modelDir, false, chatDataItemList, true);
sessionMap.put(sessionId, session);
return session;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ public class ChatSession implements Serializable {

private boolean useTmpPath;
private boolean keepHistory;
private String modelName;

public ChatSession(String sessionId, String configPath, boolean useTmpPath, List<ChatDataItem> history) {
this(sessionId, configPath, useTmpPath, history, false);
public ChatSession(String modelName, String sessionId, String configPath, boolean useTmpPath, List<ChatDataItem> history) {
this(modelName, sessionId, configPath, useTmpPath, history, false);
}

public ChatSession(String sessionId, String configPath, boolean useTmpPath, List<ChatDataItem> history, boolean isDiffusion) {
public ChatSession(String modelname, String sessionId, String configPath, boolean useTmpPath, List<ChatDataItem> history, boolean isDiffusion) {
this.modelName = modelname;
this.sessionId = sessionId;
this.configPath = configPath;
this.savedHistory = history;
Expand All @@ -49,7 +51,7 @@ public void load() {
if (this.savedHistory != null && !this.savedHistory.isEmpty()) {
historyStringList = this.savedHistory.stream().map(ChatDataItem::getText).collect(Collectors.toList());
}
nativePtr = initNative(configPath, useTmpPath, historyStringList, isDiffusion);
nativePtr = initNative(configPath, useTmpPath, historyStringList, isDiffusion, ModelUtils.isR1Model(modelName));
}

public List<ChatDataItem> getSavedHistory() {
Expand Down Expand Up @@ -122,7 +124,7 @@ protected void finalize() throws Throwable {
release();
}

public native long initNative(String configPath, boolean useTmpPath, List<String> history, boolean isDiffusion);
public native long initNative(String configPath, boolean useTmpPath, List<String> history, boolean isDiffusion, boolean isR1);
private native HashMap<String, Object> submitNative(long instanceId, String input, boolean keepHistory, GenerateProgressListener listener);

private native HashMap<String, Object> submitDiffusionNative(long instanceId, String input, String outputPath, GenerateProgressListener progressListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.alibaba.mnnllm.android.R;
import com.alibaba.mnnllm.android.history.ChatHistoryFragment;
import com.alibaba.mnnllm.android.modelist.ModelListFragment;
import com.alibaba.mnnllm.android.utils.GithubUtils;
import com.alibaba.mnnllm.android.utils.ModelUtils;
import com.google.android.material.navigation.NavigationView;
import com.techiness.progressdialoglibrary.ProgressDialog;
Expand All @@ -33,7 +34,6 @@ public class MainActivity extends AppCompatActivity {

public static final String TAG = "MainActivity";
private ProgressDialog progressDialog;
private final String repoGithubUrl = "https://github.com/alibaba/MNN";
private DrawerLayout drawerLayout;
private ActionBarDrawerToggle toggle;
private ModelListFragment modelListFragment;
Expand Down Expand Up @@ -99,10 +99,6 @@ public boolean onOptionsItemSelected(@NonNull MenuItem item) {
return super.onOptionsItemSelected(item);
}

private void openInBrowser(String url) {
Intent intent = new Intent(Intent.ACTION_VIEW, Uri.parse(url));
startActivity(intent);
}

public void runModel(String destModelDir, String modelName, String sessionId) {
ModelDownloadManager.getInstance(this).pauseAllDownloads();
Expand Down Expand Up @@ -138,11 +134,11 @@ public void runModel(String destModelDir, String modelName, String sessionId) {
}

public void onStarProject(View view) {
openInBrowser(this.repoGithubUrl);
GithubUtils.starProject(this);
}

public void onReportIssue(View view) {
openInBrowser(this.repoGithubUrl + "/issues");
GithubUtils.reportIssue(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ private void setupSession() {
}
if (ModelUtils.isDiffusionModel(modelName)) {
String diffusionDir = getIntent().getStringExtra("diffusionDir");
chatSession = chatService.createDiffusionSession(diffusionDir, chatSessionId, chatDataItemList);
chatSession = chatService.createDiffusionSession(modelName, diffusionDir, chatSessionId, chatDataItemList);
} else {
String configFilePath = getIntent().getStringExtra("configFilePath");
chatSession = chatService.createSession(configFilePath, true, chatSessionId, chatDataItemList);
chatSession = chatService.createSession(modelName, configFilePath, true, chatSessionId, chatDataItemList);
}
chatSessionId = chatSession.getSessionId();
chatSession.setKeepHistory(!ModelUtils.isVisualModel(modelName) && !ModelUtils.isAudioModel(modelName));
Expand Down Expand Up @@ -497,7 +497,6 @@ private void addResponsePlaceholder() {
private void submitRequest(String input) {
isUserScrolling = false;
stopGenerating = false;
StringBuilder stringBuilder = new StringBuilder();
ChatDataItem chatDataItem = adapter.getRecentItem();
HashMap<String, Object> benchMarkResult;
if (ModelUtils.isDiffusionModel(this.modelName)) {
Expand All @@ -513,12 +512,16 @@ private void submitRequest(String input) {
return false;
});
} else {
GenerateResultProcessor generateResultProcessor = ModelUtils.isR1Model(this.modelName) ?
new GenerateResultProcessor.R1GenerateResultProcessor(getString(R.string.r1_thinking_message),
getString(R.string.r1_think_complete_template)) :
new GenerateResultProcessor.NormalGenerateResultProcessor();
generateResultProcessor.generateBegin();
benchMarkResult = chatSession.generate(input, progress -> {
if (progress != null) {
stringBuilder.append(progress);
chatDataItem.setText(stringBuilder.toString());
runOnUiThread(() -> updateAssistantResponse(chatDataItem));
}
generateResultProcessor.process(progress);
chatDataItem.setDisplayText(generateResultProcessor.getDisplayResult());
chatDataItem.setText(generateResultProcessor.getRawResult());
runOnUiThread(() -> updateAssistantResponse(chatDataItem));
if (stopGenerating) {
Log.d(TAG, "stopGenerating requeted");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.alibaba.mnnllm.android.chat;

import android.net.Uri;
import android.text.TextUtils;

import java.io.File;

Expand All @@ -18,6 +19,8 @@ public class ChatDataItem {

private String benchmarkInfo;

private String displayText;

private float audioDuration;

public ChatDataItem(String time, int type, String text) {
Expand Down Expand Up @@ -101,5 +104,13 @@ public String getAudioPath() {
}
return null;
}

/**
 * Returns the text to show for this item: {@code displayText} when
 * {@link TextUtils#isEmpty} reports it as non-empty, otherwise {@code text}.
 */
public String getDisplayText() {
    if (TextUtils.isEmpty(this.displayText)) {
        return this.text;
    }
    return this.displayText;
}

/**
 * Stores the text returned by {@code getDisplayText()} in place of {@code text};
 * a null or empty value makes {@code getDisplayText()} fall back to {@code text}.
 */
public void setDisplayText(String displayText) {
this.displayText = displayText;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import android.database.Cursor;
import android.database.sqlite.SQLiteDatabase;
import android.net.Uri;
import android.text.TextUtils;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -75,6 +76,7 @@ public void addChatData(String sessionId, ChatDataItem chatDataItem) {
values.put(ChatDatabaseHelper.COLUMN_AUDIO_URI, (String) null);
}
values.put(ChatDatabaseHelper.COLUMN_AUDIO_DURATION, chatDataItem.getAudioDuration());
values.put(ChatDatabaseHelper.COLUMN_DISPLAY_TEXT, chatDataItem.getDisplayText());
db.insert(ChatDatabaseHelper.TABLE_CHAT, null, values);
db.close();
}
Expand All @@ -97,10 +99,14 @@ public List<ChatDataItem> getChatDataBySession(String sessionId) {
String imageUriStr = cursor.getString(cursor.getColumnIndex(ChatDatabaseHelper.COLUMN_IMAGE_URI));
String audioUriStr = cursor.getString(cursor.getColumnIndex(ChatDatabaseHelper.COLUMN_AUDIO_URI));
float audioDuration = cursor.getFloat(cursor.getColumnIndex(ChatDatabaseHelper.COLUMN_AUDIO_DURATION));
String displayText = cursor.getString(cursor.getColumnIndex(ChatDatabaseHelper.COLUMN_DISPLAY_TEXT));
ChatDataItem chatDataItem = new ChatDataItem(time, type, text);
if (imageUriStr != null) {
chatDataItem.setImageUri(Uri.parse(imageUriStr));
}
if (!TextUtils.isEmpty(displayText)) {
chatDataItem.setDisplayText(displayText);
}
if (audioUriStr != null) {
chatDataItem.setAudioUri(Uri.parse(audioUriStr));
chatDataItem.setAudioDuration(audioDuration);
Expand Down
Loading

0 comments on commit 6a5abf5

Please sign in to comment.