# mill.py, Markdown interface for llama.cpp
# Copyright (C) 2024 unworriedsafari <unworriedsafari@tilde.club>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>

"""
|
|
## `llama.cpp` tutorial
|
|
|
|
This section describes the `llama.cpp` LLM-engine module of `mill.py`.
|
|
|
|
|
|
### LLM variables
|
|
|
|
`suppress eos`
|
|
|
|
Some models perform better if the EOS is part of the prompt. `llama.cpp` models
|
|
have a setting `add_eos_token` that seems to mean 'please add the EOS to the
|
|
generated text.' `mill.py` respects this setting, and adds the EOS in that case
|
|
if the model generates it, _unless_ you declare the LLM variable `suppress eos`
|
|
in the document. In that case `mill.py` will not add the EOS token if the model
|
|
generates it.
|
|
|
|
Other LLM variables are simply passed on to `llama.cpp` as command-line
|
|
arguments. A variable with an empty value is passed without a value (i.e. as a
|
|
flag). There are a couple of LLM variables that are reserved for `mill.py`, so
|
|
you cannot use them. These are:
|
|
|
|
- `--file` for the input prompt
|
|
- `--prompt-cache` for the prompt cache to use
|
|
- `--prompt-cache-all` so that generated text is also cached
|
|
|
|
Using these variables results in an error.
|
|
|
|
|
|
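For illustration, here is roughly how a set of LLM variables maps onto the
`llama.cpp/main` invocation (the model path and values below are made up, and
`--mlock` just stands in for any ordinary flag):

```python
llm_vars = {
    '--model': 'models/example.gguf',  # hypothetical path
    '--ctx-size': '2048',
    '--mlock': '',                     # empty value: passed as a bare flag
}
# mill.py itself adds the reserved options, so the resulting call looks
# roughly like:
#   main --model models/example.gguf --ctx-size 2048 --mlock \
#        --prompt-cache /tmp/... --prompt-cache-all --file /tmp/...
```

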
### Environment variables

Apart from LLM variables, there are also a few environment variables that
influence the behavior of the `llama.cpp` module.


`MILL_LLAMACPP_MAIN`

This variable is required and must be set to the path of `llama.cpp/main`. It
can be your own script too (a minimal wrapper sketch follows the list), as
long as:

1. The script can accept arguments that are passed from `mill.py`.
2. The standard output consists of the input prompt followed by the generated
   text.
3. The error output contains error output generated by `llama.cpp`. This is
   used by `mill.py` to extract some settings from the model's metadata,
   such as BOS and EOS tokens and whether or not to add them in the right
   places.

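For example, a wrapper can be as small as this sketch (the binary location is
an assumption); because it simply replaces itself with the real `main`,
standard output and error pass through unchanged and the requirements above
are met:

```python
#!/usr/bin/env python3
# Hypothetical wrapper to use as MILL_LLAMACPP_MAIN.
import os, sys

MAIN = '/opt/llama.cpp/main'  # assumed path to the real llama.cpp binary
os.execv(MAIN, [MAIN] + sys.argv[1:])
```

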
`MILL_LLAMACPP_CACHE_DIR`

Path to the directory where the prompt caches are stored. By default this is
the OS's temporary-files directory. Note: prompt caches can be large files and
`mill.py` does not automatically clean them. You can recognize the files by the
extension `.promptcache`.

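Because the caches are never cleaned automatically, you may want to clear them
out now and then. A small sketch that looks the cache directory up the same
way the module does:

```python
import os, pathlib, tempfile

cache_dir = pathlib.Path(os.environ.get('MILL_LLAMACPP_CACHE_DIR',
                                        tempfile.gettempdir()))
for cache_file in cache_dir.glob('*.promptcache'):
    cache_file.unlink()  # remove a saved prompt cache
```

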
`MILL_LLAMACPP_TIMEOUT`

The maximum number of seconds to wait for the `llama.cpp/main` process to
complete. Default is 600.


### Prompt caching

For each invocation, a prompt cache is generated and saved under a key derived
from the model, the context size, and the prompt plus the generated text.
Before running the model, `mill.py` searches for an existing cache that matches
the longest possible prefix of the current prompt and reuses it if found.
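
Cache files are named after a short SHA-256 hash of that key; the sketch below
mirrors `_promptcache_path` (all values are made up):

```python
import hashlib

cache_id = 'ctx:' + '2048' + 'model:' + 'example.gguf' + 'prompt:' + 'Once upon a time'
cache_name = hashlib.sha256(cache_id.encode('utf-8')).hexdigest()[:10] + '.promptcache'
```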
"""
|
|
|
|
import contextlib, hashlib, os, pathlib, re, shlex, shutil, subprocess, sys
|
|
import tempfile
|
|
|
|
|
|
def generate(llm_vars, prompt):
    return LLM(llm_vars, prompt)

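# A usage sketch of this engine interface (an assumption about how mill.py's
# core drives it; the model path is made up):
#
#     with generate({'--model': 'example.gguf'}, 'Once upon a time') as llm:
#         while llm.returncode == 0:
#             chunk = llm.read(512)
#             if not chunk:
#                 break
#             sys.stdout.write(chunk)

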
class LLM(contextlib.AbstractContextManager):
    def __init__(self, llm_vars, prompt):
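        # llm_vars maps llama.cpp command-line options to string values; an
        # empty value means the option is passed as a bare flag (see module
        # docstring). prompt is the input text to run the model on.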
        self.returncode = 0
        self._add_eos = False
        self._bos = ''
        self._buffer = ''
        self._eos = ''
        self._generated_text = ''
        self._llama_process = None
        self._llm_vars = llm_vars
        self._prompt = prompt
        self._promptcache_dest = None
        self._prompt_found = False
        self._tmp_prompt_cache = ''

        reserved_vars = set(['--file', '--prompt-cache', '--prompt-cache-all'])
        illegal_vars = set(self._llm_vars.keys()) & reserved_vars
        if illegal_vars:
            raise RuntimeError(f'variables not allowed: {illegal_vars}')


    def __enter__(self):
        ## Phase 1: input is processed
        # Search for prompt cache
        prompt_cache = ''
        cached_prompt = self._prompt
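        # Trim the prompt one character at a time from the end until a cache
        # file for that prefix (with the same model and context size) exists.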
        while cached_prompt and not prompt_cache:
            cache_path = self._promptcache_path(cached_prompt,
                                                self._llm_vars.get('--model'),
                                                self._llm_vars.get('--ctx-size'))
            if cache_path.is_file():
                print('[DEBUG] cache exists!', file=sys.stderr)
                prompt_cache = str(cache_path)
            else:
                cached_prompt = cached_prompt[:-1]

        # Get final variable values
        ctx_size = self._llm_vars.get('--ctx-size', '')
        model = self._llm_vars.get('--model', '')

        ## Phase 2: get model metadata
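        # Run main with --n-predict 0 so that it only loads the model; the
        # settings we need are read from the metadata it prints on its error
        # output (see the module docstring).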
        metadata_cmd = [os.environ['MILL_LLAMACPP_MAIN'], '--n-predict', '0']
        if model:
            metadata_cmd += ['--model', model]

        model_metadata = subprocess.run(
            [shlex.quote(param) for param in metadata_cmd],
            capture_output=True, text=True)

        # Return any errors
        if model_metadata.returncode != 0:
            print(model_metadata.stderr, file=sys.stderr)
            self.returncode = model_metadata.returncode
            return self

        model_metadata_lines = model_metadata.stderr.split(os.linesep)

        # We need to add BOS/EOS if model doesn't do it
        add_bos = not self._get_metadata_setting(model_metadata_lines,
                                                 'add_bos_token',
                                                 False)
        self._bos = self._get_metadata_setting(model_metadata_lines,
                                               'BOS token', '')
        self._add_eos = not self._get_metadata_setting(model_metadata_lines,
                                                       'add_eos_token',
                                                       False)
        self._eos = self._get_metadata_setting(model_metadata_lines,
                                               'EOS token', '')

        ## Phase 3: respond with nothing if no prompt
        if not self._prompt:
            return self

        if add_bos and not self._prompt.startswith(self._bos):
            self._prompt = self._bos + self._prompt

        ## Phase 4: Prepare prompt cache
        fid, self._tmp_prompt_cache = tempfile.mkstemp()
        os.close(fid)
        if prompt_cache:
            print(f'[DEBUG] Prompt cache: {prompt_cache}', file=sys.stderr)
            shutil.copy(prompt_cache, self._tmp_prompt_cache)
        else:
            os.remove(self._tmp_prompt_cache) # must be created by llama.cpp
        self._llm_vars['--prompt-cache'] = self._tmp_prompt_cache
        self._llm_vars['--prompt-cache-all'] = ''

        ## Phase 5: Execute LLM
        # Execute llama.cpp main with prompt
        with tempfile.NamedTemporaryFile(delete=False) as fp:
            # llama.cpp likes to drop the last trailing newline from the input.
            # that's a problem for us if our input ends with a newline.
            # so just add one
            llama_prompt = self._prompt + os.linesep
            print('[DEBUG] Prompt in file:', file=sys.stderr)
            print(os.linesep.join([f'[DEBUG] {line}' for line in \
                                   llama_prompt.split(os.linesep)]),
                  file=sys.stderr)
            fp.file.write(llama_prompt.encode('utf-8'))
            fp.file.flush()
            self._llm_vars['--file'] = fp.name

        # Construct command-line
        cmd = [os.environ['MILL_LLAMACPP_MAIN']]
        for name, value in self._llm_vars.items():
            if name in ['suppress eos']:
                continue
            cmd += [name]
            if value:
                cmd += [value]
        cmd = [shlex.quote(param) for param in cmd]
        print(f'[DEBUG] cmd: {cmd}', file=sys.stderr)

        self._llama_process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                               stderr=subprocess.PIPE,
                                               text=True)
        self._generated_text = ''
        self._buffer = ''

        return self


    def read(self, num_chars):
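        # Returns up to num_chars characters of generated text; an empty
        # string means generation has finished or failed (check returncode).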
        if not self._llama_process:
            return ''

        # Filter BOS and EOS from prompt (llama.cpp doesn't output those)
        # then search output for start of generated text.
        if not self._prompt_found:
            input_prompt = self._prompt
            if self._bos:
                input_prompt = ''.join(input_prompt.split(self._bos))
            if self._eos:
                input_prompt = ''.join(input_prompt.split(self._eos))

            while True:
                chars = self._llama_process.stdout.read(num_chars)
                if not chars:
                    break
                self._buffer += chars
                input_prompt_pos = self._buffer.find(input_prompt)
                if input_prompt_pos != -1:
                    self._prompt_found = True
                    self._buffer = \
                        self._buffer[input_prompt_pos+len(input_prompt):]
                    break

            if not self._prompt_found:
                # poll() fills in returncode if the process already exited
                if self._llama_process.poll():
                    print(self._llama_process.stderr.read(), file=sys.stderr)
                    self.returncode = self._llama_process.returncode
                else:
                    self.returncode = 1
                    print('Unable to find input prompt', file=sys.stderr)
                    print(os.linesep.join([f'[INPUT] {line}' for line in \
                                           input_prompt.split(os.linesep)]),
                          file=sys.stderr)
                    print(os.linesep.join([f'[OUTPUT] {line}' for line in \
                                           self._buffer.split(os.linesep)]),
                          file=sys.stderr)
                return ''

        # Extract generated text
        if len(self._buffer) < num_chars:
            chars = self._llama_process.stdout.read(num_chars-len(self._buffer))
            self._buffer += chars
            self._generated_text += chars

            # Generation is finished
            if not chars:
                self._llama_process.wait(int(
                    os.environ.get('MILL_LLAMACPP_TIMEOUT', 600)))

                err = self._llama_process.stderr.read()
                print(err, file=sys.stderr)

                if self._llama_process.returncode != 0:
                    self.returncode = self._llama_process.returncode
                    print(f'[DEBUG] returncode!: {self.returncode}',
                          file=sys.stderr)

                elif err.find('[end of text]') != -1 and self._add_eos and \
                        self._llm_vars.get('suppress eos') is None:
                    print(f'[DEBUG] adding eos {self._eos}', file=sys.stderr)
                    self._buffer += self._eos
                    self._generated_text += self._eos

                self._promptcache_dest = self._promptcache_path(
                    self._prompt + self._generated_text,
                    self._llm_vars.get('--model'),
                    self._llm_vars.get('--ctx-size'))

        chars = self._buffer[:num_chars]
        self._buffer = self._buffer[num_chars:]
        return chars


    def __exit__(self, exc_type, exc_value, traceback):
        # Remove temporary files
        if self._llm_vars.get('--file'):
            os.remove(self._llm_vars['--file'])

        # Save prompt cache
        if self._promptcache_dest:
            self._promptcache_dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(self._tmp_prompt_cache, str(self._promptcache_dest))

        return None


    def _get_metadata_setting(self, metadata_lines, name, default):
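        # Take the first metadata line mentioning `name` and return its last
        # word: either a true/false value or a token string in single quotes.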
        matches = [line for line in metadata_lines if name in line]
        if not matches:
            return default

        match = matches[0]
        words = match.split(' ')

        if words[-1] == 'false' or words[-1] == 'true':
            return words[-1] == 'true'

        if words[-1].startswith("'") and words[-1].endswith("'"):
            return words[-1][1:-1]

        raise Exception(f'Unrecognized metadata settings line: {match}')


    def _promptcache_path(self, prompt, model, ctx_size):
        model = model if model else ''
        ctx_size = ctx_size if ctx_size else ''

        cache_id = 'ctx:' + ctx_size + 'model:' + model + 'prompt:' + prompt
        cache_id_hash = hashlib.sha256(cache_id.encode('utf-8')).hexdigest()[:10]

        print('[DEBUG] promptcache_path called', file=sys.stderr)
        print(f'[DEBUG] {cache_id_hash}', file=sys.stderr)
        print(os.linesep.join([f'[DEBUG] {line}' for line in \
                               cache_id.split(os.linesep)]),
              file=sys.stderr)

        return pathlib.Path(
            os.environ.get('MILL_LLAMACPP_CACHE_DIR', tempfile.gettempdir()),
            f'{cache_id_hash}.promptcache')