# mill_llm_llama_cpp.py
# mill.py, Markdown interface for llama.cpp
# Copyright (C) 2024 unworriedsafari <unworriedsafari@tilde.club>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
"""
## `llama.cpp` tutorial
This section describes the `llama.cpp` LLM-engine module of `mill.py`.
### LLM variables
`suppress eos`
Some models perform better if the EOS is part of the prompt. `llama.cpp` models
have a setting `add_eos_token` that seems to mean 'please add the EOS to the
generated text.' `mill.py` respects this setting, and adds the EOS in that case
if the model generates it, _unless_ you declare the LLM variable `suppress eos`
in the document. In that case `mill.py` will not add the EOS token if the model
generates it.

Other LLM variables are simply passed on to `llama.cpp` as command-line
arguments. A variable with an empty value is passed without a value (i.e. as a
flag). A few LLM variables are reserved for `mill.py`'s own use and cannot be
set in the document:

- `--file` for the input prompt
- `--prompt-cache` for the prompt cache to use
- `--prompt-cache-all` so that generated text is also cached

Using any of these variables results in an error; see the sketch below.
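
At the Python level the LLM variables arrive as a plain dict of `llama.cpp`
argument names to values. A sketch (the module name is taken from this file's
name; the variable values are illustrative, not defaults):

```python
import mill_llm_llama_cpp as engine

llm_vars = {
    '--model': 'models/example.gguf',  # assumed path, adjust to your setup
    '--ctx-size': '4096',
    '--n-predict': '256',
    'suppress eos': '',   # handled by mill.py itself, never passed to llama.cpp
}

# Reserved variables are rejected up front, e.g.:
#   engine.generate({'--file': 'x.txt'}, 'Hi')  ->  RuntimeError
llm = engine.generate(llm_vars, 'Once upon a time')
```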

### Environment variables

Apart from LLM variables, there are also a few environment variables that
influence the behavior of the `llama.cpp` module.

`MILL_LLAMACPP_MAIN`

This variable is required and must be set to the path of `llama.cpp/main`. It
can also point to a script of your own (see the sketch after this list), as
long as:

1. The script accepts the arguments that `mill.py` passes to it.
2. Its standard output consists of the input prompt followed by the generated
   text.
3. Its error output contains the error output generated by `llama.cpp`.
   `mill.py` uses this to extract settings from the model's metadata, such as
   the BOS and EOS tokens and whether or not to add them in the right places.
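
A minimal pass-through wrapper might look like this (the binary path is an
assumption; point it at your own `llama.cpp` build):

```python
#!/usr/bin/env python3
# Hypothetical wrapper to use as MILL_LLAMACPP_MAIN.
import subprocess
import sys

LLAMA_MAIN = '/opt/llama.cpp/main'  # assumed location of the real binary

# Forward all arguments unchanged. llama.cpp already writes the prompt plus
# the generated text to stdout and its diagnostics (including the model
# metadata that mill.py reads) to stderr, so inheriting both streams
# satisfies the three requirements above.
sys.exit(subprocess.run([LLAMA_MAIN] + sys.argv[1:]).returncode)
```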

`MILL_LLAMACPP_CACHE_DIR`

Path to the directory where the prompt caches are stored. By default this is
the OS's temporary-files directory. Note: prompt caches can be large files and
`mill.py` does not clean them up automatically; you can recognize them by the
extension `.promptcache`.
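
Because the caches are never removed automatically, you may want to prune them
yourself from time to time. A sketch (the seven-day cutoff is an arbitrary
choice):

```python
import os
import pathlib
import tempfile
import time

cache_dir = pathlib.Path(os.environ.get('MILL_LLAMACPP_CACHE_DIR',
                                        tempfile.gettempdir()))
cutoff = time.time() - 7 * 24 * 3600  # keep caches touched within the last week
for cache_file in cache_dir.glob('*.promptcache'):
    if cache_file.stat().st_mtime < cutoff:
        cache_file.unlink()
```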

`MILL_LLAMACPP_TIMEOUT`

The maximum number of seconds to wait for the `llama.cpp/main` process to
complete. The default is 600.
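
Putting the environment together might look like this (every value shown is an
example, not a shipped default):

```python
import os

os.environ['MILL_LLAMACPP_MAIN'] = os.path.expanduser('~/llama.cpp/main')    # assumed build path
os.environ['MILL_LLAMACPP_CACHE_DIR'] = os.path.expanduser('~/.cache/mill')  # assumed cache dir
os.environ['MILL_LLAMACPP_TIMEOUT'] = '1800'  # give slow hardware more time
```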

### Prompt caching

For each invocation, a prompt cache is generated. After parsing, `mill.py`
searches for a matching prompt cache, i.e. the cache of the longest
already-cached prefix of the prompt.
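
Cache files are named after a hash of the context size, model, and prompt
(this mirrors `_promptcache_path` below):

```python
import hashlib
import os
import pathlib
import tempfile

# Illustrative values; the real ones come from the document's LLM variables.
ctx_size, model, prompt = '4096', 'models/example.gguf', 'Hello'
cache_id = 'ctx:' + ctx_size + 'model:' + model + 'prompt:' + prompt
cache_name = hashlib.sha256(cache_id.encode('utf-8')).hexdigest()[:10]
cache_path = pathlib.Path(os.environ.get('MILL_LLAMACPP_CACHE_DIR',
                                         tempfile.gettempdir()),
                          f'{cache_name}.promptcache')
```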
"""

import contextlib, hashlib, os, pathlib, re, shlex, shutil, subprocess, sys
import tempfile


def generate(llm_vars, prompt):
    return LLM(llm_vars, prompt)
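

# Example usage (a sketch): the object returned by generate() is a context
# manager; its read() method streams the generated text in chunks and returns
# an empty string once generation is finished.
#
#     with generate({'--model': 'models/example.gguf'}, 'Hello') as llm:
#         while True:
#             chunk = llm.read(512)
#             if not chunk:
#                 break
#             print(chunk, end='')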


class LLM(contextlib.AbstractContextManager):

    def __init__(self, llm_vars, prompt):
        self.returncode = 0
        self._add_eos = False
        self._bos = ''
        self._buffer = ''
        self._eos = ''
        self._generated_text = ''
        self._llama_process = None
        self._llm_vars = llm_vars
        self._prompt = prompt
        self._promptcache_dest = None
        self._prompt_found = False
        self._tmp_prompt_cache = ''

        reserved_vars = set(['--file', '--prompt-cache', '--prompt-cache-all'])
        illegal_vars = set(self._llm_vars.keys()) & reserved_vars
        if illegal_vars:
            raise RuntimeError(f'variables not allowed: {illegal_vars}')

    def __enter__(self):
        ## Phase 1: input is processed

        # Search for prompt cache
        prompt_cache = ''
        cached_prompt = self._prompt
        while cached_prompt and not prompt_cache:
            cache_path = self._promptcache_path(
                cached_prompt,
                self._llm_vars.get('--model'),
                self._llm_vars.get('--ctx-size'))
            if cache_path.is_file():
                print('[DEBUG] cache exists!', file=sys.stderr)
                prompt_cache = str(cache_path)
            else:
                cached_prompt = cached_prompt[:-1]

        # Get final variable values
        ctx_size = self._llm_vars.get('--ctx-size', '')
        model = self._llm_vars.get('--model', '')

        ## Phase 2: get model metadata
        metadata_cmd = [os.environ['MILL_LLAMACPP_MAIN'], '--n-predict', '0']
        if model:
            metadata_cmd += ['--model', model]
        model_metadata = subprocess.run(
            [shlex.quote(param) for param in metadata_cmd],
            capture_output=True, text=True)

        # Return any errors
        if model_metadata.returncode != 0:
            print(model_metadata.stderr, file=sys.stderr)
            self.returncode = model_metadata.returncode
            return self

        model_metadata_lines = model_metadata.stderr.split(os.linesep)

        # We need to add the BOS/EOS ourselves if the model doesn't do it
        add_bos = not self._get_metadata_setting(
            model_metadata_lines, 'add_bos_token', False)
        self._bos = self._get_metadata_setting(
            model_metadata_lines, 'BOS token', '')
        self._add_eos = not self._get_metadata_setting(
            model_metadata_lines, 'add_eos_token', False)
        self._eos = self._get_metadata_setting(
            model_metadata_lines, 'EOS token', '')

        ## Phase 3: respond with nothing if no prompt
        if not self._prompt:
            return self

        # Prepend the BOS token ourselves when the model does not add it
        if add_bos and not self._prompt.startswith(self._bos):
            self._prompt = self._bos + self._prompt

        ## Phase 4: Prepare prompt cache
        fid, self._tmp_prompt_cache = tempfile.mkstemp()
        os.close(fid)
        if prompt_cache:
            print(f'[DEBUG] Prompt cache: {prompt_cache}', file=sys.stderr)
            shutil.copy(prompt_cache, self._tmp_prompt_cache)
        else:
            os.remove(self._tmp_prompt_cache)  # must be created by llama.cpp
        self._llm_vars['--prompt-cache'] = self._tmp_prompt_cache
        self._llm_vars['--prompt-cache-all'] = ''

        ## Phase 5: Execute LLM

        # Execute llama.cpp main with the prompt
        with tempfile.NamedTemporaryFile(delete=False) as fp:
            # llama.cpp likes to drop the last trailing newline from the
            # input. That's a problem for us if our input ends with a
            # newline, so just add one.
            llama_prompt = self._prompt + os.linesep
            print('[DEBUG] Prompt in file:', file=sys.stderr)
            print(os.linesep.join(
                [f'[DEBUG] {line}'
                 for line in llama_prompt.split(os.linesep)]),
                file=sys.stderr)
            fp.file.write(llama_prompt.encode('utf-8'))
            fp.file.flush()
            self._llm_vars['--file'] = fp.name

        # Construct the command line
        cmd = [os.environ['MILL_LLAMACPP_MAIN']]
        for name, value in self._llm_vars.items():
            if name in ['suppress eos']:
                continue
            cmd += [name]
            if value:
                cmd += [value]
        cmd = [shlex.quote(param) for param in cmd]
        print(f'[DEBUG] cmd: {cmd}', file=sys.stderr)

        self._llama_process = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        self._generated_text = ''
        self._buffer = ''
        return self

    def read(self, num_chars):
        if not self._llama_process:
            return ''

        # Filter BOS and EOS from the prompt (llama.cpp doesn't output those),
        # then search the output for the start of the generated text.
        if not self._prompt_found:
            input_prompt = self._prompt
            if self._bos:
                input_prompt = ''.join(input_prompt.split(self._bos))
            if self._eos:
                input_prompt = ''.join(input_prompt.split(self._eos))

            while True:
                chars = self._llama_process.stdout.read(num_chars)
                if not chars:
                    break
                self._buffer += chars
                input_prompt_pos = self._buffer.find(input_prompt)
                if input_prompt_pos != -1:
                    self._prompt_found = True
                    self._buffer = \
                        self._buffer[input_prompt_pos+len(input_prompt):]
                    break

            if not self._prompt_found:
                if self._llama_process.returncode != 0:
                    print(self._llama_process.stderr.read(), file=sys.stderr)
                    self.returncode = self._llama_process.returncode
                else:
                    self.returncode = 1
                    print('Unable to find input prompt', file=sys.stderr)
                    print(os.linesep.join(
                        [f'[INPUT] {line}'
                         for line in input_prompt.split(os.linesep)]),
                        file=sys.stderr)
                    print(os.linesep.join(
                        [f'[OUTPUT] {line}'
                         for line in self._buffer.split(os.linesep)]),
                        file=sys.stderr)
                return ''

        # Extract generated text
        if len(self._buffer) < num_chars:
            chars = self._llama_process.stdout.read(
                num_chars - len(self._buffer))
            self._buffer += chars
            self._generated_text += chars

            # Generation is finished
            if not chars:
                self._llama_process.wait(
                    int(os.environ.get('MILL_LLAMACPP_TIMEOUT', 600)))
                err = self._llama_process.stderr.read()
                print(err, file=sys.stderr)
                if self._llama_process.returncode != 0:
                    self.returncode = self._llama_process.returncode
                    print(f'[DEBUG] returncode!: {self.returncode}',
                          file=sys.stderr)
                elif err.find('[end of text]') != -1 and self._add_eos and \
                        self._llm_vars.get('suppress eos') is None:
                    print(f'[DEBUG] adding eos {self._eos}', file=sys.stderr)
                    self._buffer += self._eos
                    self._generated_text += self._eos

                self._promptcache_dest = self._promptcache_path(
                    self._prompt + self._generated_text,
                    self._llm_vars.get('--model'),
                    self._llm_vars.get('--ctx-size'))

        chars = self._buffer[:num_chars]
        self._buffer = self._buffer[num_chars:]
        return chars

    def __exit__(self, exc_type, exc_value, traceback):
        # Remove temporary files
        if self._llm_vars.get('--file'):
            os.remove(self._llm_vars['--file'])

        # Save prompt cache
        if self._promptcache_dest:
            self._promptcache_dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(self._tmp_prompt_cache, str(self._promptcache_dest))
        return None

    def _get_metadata_setting(self, metadata_lines, name, default):
        matches = [line for line in metadata_lines if name in line]
        if not matches:
            return default
        match = matches[0]
        words = match.split(' ')
        if words[-1] == 'false' or words[-1] == 'true':
            return words[-1] == 'true'
        if words[-1].startswith("'") and words[-1].endswith("'"):
            return words[-1][1:-1]
        raise Exception(f'Unrecognized metadata settings line: {match}')

    def _promptcache_path(self, prompt, model, ctx_size):
        model = model if model else ''
        ctx_size = ctx_size if ctx_size else ''
        cache_id = 'ctx:' + ctx_size + 'model:' + model + 'prompt:' + prompt
        cache_id_hash = hashlib.sha256(cache_id.encode('utf-8')).hexdigest()[:10]
        print('[DEBUG] promptcache_path called', file=sys.stderr)
        print(f'[DEBUG] {cache_id_hash}', file=sys.stderr)
        print(os.linesep.join(
            [f'[DEBUG] {line}'
             for line in cache_id.split(os.linesep)]),
            file=sys.stderr)
        return pathlib.Path(
            os.environ.get('MILL_LLAMACPP_CACHE_DIR', tempfile.gettempdir()),
            f'{cache_id_hash}.promptcache')