feat: add stopword checker + iterable generate function #106

Open · wants to merge 1 commit into base: main
96 changes: 89 additions & 7 deletions rwkv_pip_package/src/rwkv/utils.py
@@ -7,15 +7,38 @@
import torch
from torch.nn import functional as F


def end_overlap(a, b):
    # Length of the longest suffix of `a` that is also a prefix of `b` (0 if none).
    for i in reversed(range(1, len(a) + 1)):
        if b.startswith(a[-i:]):
            return i
    return 0
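For context: end_overlap(a, b) returns the length of the longest suffix of a that is also a prefix of b, which is what lets the stopword checker below hold back text that might still grow into a stopword. A few hand-worked examples (illustrative, not from the PR):

    end_overlap("Hello, Us", "User:")  # -> 2  ("Us" could begin "User:")
    end_overlap("Hello!", "User:")     # -> 0  (no suffix of "Hello!" begins "User:")
    end_overlap("User:", "User:")      # -> 5  (the whole string matches)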

class PIPELINE_ARGS():
    def __init__(self,
                 temperature=1.0,
                 top_p=0.85,
                 top_k=0,
                 alpha_frequency=0.2,
                 alpha_presence=0.2,
                 token_ban=None,
                 token_stop=None,
                 stop_words=None,
                 chunk_len=256
                 ):

        # None defaults avoid sharing one mutable list across every instance
        token_ban = token_ban or []
        token_stop = token_stop or []
        stop_words = stop_words or []

        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
        self.token_ban = token_ban # ban the generation of some tokens
        self.token_stop = token_stop # stop generation whenever you see any token here
        self.stop_words = stop_words # stop generation whenever the output ends with any of these strings
        self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
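A note on the signature change: switching the list defaults from [] to None (restored with token_ban = token_ban or []) sidesteps Python's shared-mutable-default pitfall, where a single list created at definition time is shared by every call. A minimal sketch of the hazard being avoided (hypothetical names, not from this PR):

    def bad(xs=[]):       # one list, created once at definition time
        xs.append(1)
        return xs

    bad()  # [1]
    bad()  # [1, 1]  <- state leaks between calls

    def good(xs=None):
        xs = xs or []     # fresh list on every call, as in PIPELINE_ARGS above
        xs.append(1)
        return xs

    good()  # [1]
    good()  # [1]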

class PIPELINE():

@@ -77,12 +100,23 @@ def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
        probs = probs ** (1.0 / temperature)
        out = torch.multinomial(probs, num_samples=1)[0]
        return int(out)

    def generate(self, *args, callback=None, **kwargs):
        outstr = []
        for delta in self.igenerate(*args, **kwargs):
            outstr += [delta]
            if callback:
                callback(delta)
        return ''.join(outstr)
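With this split, generate keeps its original string-returning contract while igenerate can be consumed incrementally. A hypothetical usage sketch (pipeline construction elided; prompt and stop word are illustrative):

    args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, stop_words=['\nUser:'])

    # streaming: handle each delta as it is produced
    for delta in pipeline.igenerate('Hello', token_count=100, args=args):
        print(delta, end='', flush=True)

    # or the classic blocking call, with an optional per-delta callback
    out = pipeline.generate('Hello', token_count=100, args=args, callback=print)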

    def igenerate(self, ctx, token_count=100, args=PIPELINE_ARGS(), state=None):
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}

        # create the stopword coroutine and prime it so it can receive .send()
        stopword_checker = self.check_stopwords(args.stop_words)
        next(stopword_checker)
        for i in range(token_count):

            # forward & adjust prob.

@@ -108,9 +142,57 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):

            # output
            tmp = self.decode(all_tokens[out_last:])
            if len(all_tokens)==1:
                tmp = tmp[1:] # strip leading space
            if tmp == '':
                continue
            if '\ufffd' not in tmp: # is valid utf-8 string?
                try:
                    # the checker may withhold text or signal a confirmed stopword
                    tmp = stopword_checker.send(tmp)
                except StopIteration:
                    break

                if tmp is None:
                    # held back: could still be the start of a stopword
                    continue
                yield tmp
                out_str += tmp
                out_last = i + 1

    @staticmethod
    def check_stopwords(stop_words):
        # Coroutine: receives decoded text deltas via .send(), yields text that is
        # safe to emit, yields None while text is held back because it might be the
        # start of a stopword, and returns (StopIteration) on a confirmed stopword.
        longest_stopword = 0 if len(stop_words)==0 else max(map(len, stop_words))
        chunk = ""
        delta = True
        to_yield = None
        while delta:
            delta = yield to_yield
            chunk = chunk + delta

            if longest_stopword == 0:
                # nothing to check, just pass through
                to_yield = delta
                continue
            if chunk == '':
                to_yield = None
                continue
            if any(map(lambda stop_word: chunk.startswith(stop_word), stop_words)):
                return  # stopword confirmed -> StopIteration at the caller's .send()

            if start_idx := max(map(lambda stop_word: end_overlap(chunk, stop_word), stop_words)):
                # the tail of chunk could be the start of a stopword: hold it back
                if start_idx > longest_stopword:
                    start_idx = longest_stopword # can no longer be a stopword so cut it down
                good, chunk = chunk[:-start_idx], chunk[-start_idx:]
                if good:
                    to_yield = good
                    continue

                to_yield = None
                continue

            # no overlap with any stopword: flush the whole buffer
            out = chunk
            chunk = ""
            to_yield = out
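A hand-worked trace of the coroutine protocol, assuming the single stopword '\nUser:' arrives split across deltas: text passes through while it cannot be part of a stopword, is held back (None) while it might be, and the generator returns, surfacing as StopIteration at the .send() call, once a stopword is confirmed.

    checker = PIPELINE.check_stopwords(['\nUser:'])
    next(checker)               # prime the coroutine

    checker.send('Hello')       # -> 'Hello'   (no overlap, passes straight through)
    checker.send(' world\nUs')  # -> ' world'  ('\nUs' is held back: it may begin '\nUser:')
    try:
        checker.send('er: hi')  # buffer becomes '\nUser: hi' -> stopword confirmed
    except StopIteration:
        pass                    # generation should stop here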