Skip to content

Commit

Permalink
Merge pull request #12 from mutablelogic/dev
Browse files Browse the repository at this point in the history
Added system prompt
  • Loading branch information
djthorpe authored Feb 7, 2025
2 parents f0df9bb + 5b4abc6 commit f880468
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions cmd/llm/chat2.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@ import (
// TYPES

type Chat2Cmd struct {
Model string `arg:"" help:"Model name"`
Token string `env:"TELEGRAM_TOKEN" help:"Telegram token" required:""`
Model string `arg:"" help:"Model name"`
TelegramToken string `env:"TELEGRAM_TOKEN" help:"Telegram token" required:""`
System string `flag:"system" help:"Set the system prompt"`
}

type Server struct {
sync.RWMutex
*telegram.Client

// Model and toolkit
toolkit llm.ToolKit
model llm.Model
toolkit llm.ToolKit
opts []llm.Opt

// Map of active sessions
sessions map[string]llm.Context
Expand All @@ -35,11 +37,15 @@ type Server struct {
////////////////////////////////////////////////////////////////////////////////
// LIFECYCLE

func NewTelegramServer(token string, model llm.Model, toolkit llm.ToolKit, opts ...telegram.Opt) (*Server, error) {
func NewTelegramServer(token string, model llm.Model, system string, toolkit llm.ToolKit, opts ...telegram.Opt) (*Server, error) {
server := new(Server)
server.sessions = make(map[string]llm.Context)
server.model = model
server.toolkit = toolkit
server.opts = []llm.Opt{
llm.WithToolKit(toolkit),
llm.WithSystemPrompt(system),
}

// Create a new telegram client
opts = append(opts, telegram.WithCallback(server.receive))
Expand All @@ -58,12 +64,12 @@ func NewTelegramServer(token string, model llm.Model, toolkit llm.ToolKit, opts

func (cmd *Chat2Cmd) Run(globals *Globals) error {
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
server, err := NewTelegramServer(cmd.Token, model, globals.toolkit, telegram.WithDebug(globals.Debug))
server, err := NewTelegramServer(cmd.TelegramToken, model, cmd.System, globals.toolkit, telegram.WithDebug(globals.Debug))
if err != nil {
return err
}

log.Printf("Running Telegram bot %q\n", server.Client.Name())
log.Printf("Running Telegram bot %q with model %q\n", server.Client.Name(), model.Name())

var result error
var wg sync.WaitGroup
Expand Down Expand Up @@ -103,7 +109,7 @@ func (telegram *Server) Purge() {
telegram.Lock()
defer telegram.Unlock()
for user, session := range telegram.sessions {
if session.SinceLast() > 10*time.Minute {
if session.SinceLast() > 5*time.Minute {
log.Printf("Purging session for %q\n", user)
delete(telegram.sessions, user)
}
Expand All @@ -116,10 +122,7 @@ func (telegram *Server) session(user string) llm.Context {
if session, exists := telegram.sessions[user]; exists {
return session
}
session := telegram.model.Context(
llm.WithToolKit(telegram.toolkit),
llm.WithSystemPrompt("Please reply to messages in markdown format."),
)
session := telegram.model.Context(telegram.opts...)
telegram.sessions[user] = session
return session
}
Expand All @@ -130,7 +133,6 @@ func (telegram *Server) receive(ctx context.Context, msg telegram.Message) error

// Process the message
text := msg.Text()
text += "\n\nPlease reply in markdown format."
if err := session.FromUser(ctx, text); err != nil {
return err
}
Expand All @@ -144,7 +146,7 @@ func (telegram *Server) receive(ctx context.Context, msg telegram.Message) error
if text := session.Text(0); text != "" {
msg.Reply(ctx, text, false)
} else {
msg.Reply(ctx, "_Gathering information_", true)
msg.Reply(ctx, "Gathering information", true)
}

results, err := telegram.toolkit.Run(ctx, calls...)
Expand Down

0 comments on commit f880468

Please sign in to comment.