
Commit

Compatible with lmdeploy (#258)
* compatible with lmdeploy

* update

* update depends

* update

---------

Co-authored-by: liukuikun <[email protected]>
lvhan028 and Harold-lkk authored Oct 21, 2024
1 parent 238aacf commit f63e6ef
Showing 2 changed files with 47 additions and 6 deletions.
51 changes: 46 additions & 5 deletions lagent/llms/lmdeploy_wrapper.py
@@ -1,3 +1,5 @@
+import copy
+import logging
 from typing import List, Optional, Union
 
 from lagent.llms.base_llm import BaseModel
@@ -23,7 +25,12 @@ def __init__(self,
                  log_level: str = 'WARNING',
                  **kwargs):
         super().__init__(path=None, **kwargs)
-        from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
+        try:
+            from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
+        except Exception as e:
+            logging.error(f'{e}')
+            raise RuntimeError('DO NOT use turbomind.chatbot since it has '
+                               'been removed by lmdeploy since v0.5.2')
         self.state_map = {
             StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
             StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
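
On any lmdeploy release where this module has been removed, the guarded import above will always take the error path, so a caller may prefer to pick a code path up front. A minimal, illustrative sketch of such a check, assuming the packaging package is available (the 0.5.2 threshold comes from the error message above; everything else is hypothetical caller code):

    from packaging.version import parse as parse_version

    import lmdeploy

    # lmdeploy dropped lmdeploy.serve.turbomind.chatbot in v0.5.2, so the
    # legacy Chatbot-based client only works on older releases; newer ones
    # should use the pipeline-based wrapper defined later in this file.
    use_legacy_chatbot = parse_version(lmdeploy.__version__) < parse_version('0.5.2')
    print('legacy turbomind.chatbot available:', use_legacy_chatbot)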
@@ -226,11 +233,32 @@ def __init__(self,
                  tp: int = 1,
                  pipeline_cfg=dict(),
                  **kwargs):
 
+        import lmdeploy
+        from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline, version_info
+
+        self.str_version = lmdeploy.__version__
+        self.version = version_info
+        self.do_sample = kwargs.pop('do_sample', None)
+        if self.do_sample is not None and self.version < (0, 6, 0):
+            raise RuntimeError(
+                '`do_sample` parameter is not supported by lmdeploy until '
+                f'v0.6.0, but currently using lmdeploy {self.str_version}')
         super().__init__(path=path, **kwargs)
-        from lmdeploy import pipeline
+        backend_config = copy.deepcopy(pipeline_cfg)
+        backend_config.update(tp=tp)
+        backend_config = {
+            k: v
+            for k, v in backend_config.items()
+            if hasattr(TurbomindEngineConfig, k)
+        }
+        backend_config = TurbomindEngineConfig(**backend_config)
+        chat_template_config = ChatTemplateConfig(
+            model_name=model_name) if model_name else None
         self.model = pipeline(
-            model_path=self.path, model_name=model_name, tp=tp, **pipeline_cfg)
+            model_path=self.path,
+            backend_config=backend_config,
+            chat_template_config=chat_template_config,
+            log_level='WARNING')

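The dict comprehension in the new __init__ keeps only keys that TurbomindEngineConfig actually defines before constructing it. A small standalone sketch of the same filtering idea (the config values here are illustrative, not lagent defaults):

    from lmdeploy import TurbomindEngineConfig

    raw_cfg = dict(tp=2, session_len=8192, not_an_engine_option=True)
    # Keep only attributes TurbomindEngineConfig defines; unknown keys such as
    # 'not_an_engine_option' are silently dropped before the dataclass is built.
    engine_cfg = TurbomindEngineConfig(
        **{k: v for k, v in raw_cfg.items() if hasattr(TurbomindEngineConfig, k)})
    print(engine_cfg.tp, engine_cfg.session_len)
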
     def generate(self,
                  inputs: Union[str, List[str]],
@@ -249,13 +277,26 @@ def generate(self,
             (a list of/batched) text/chat completion
         """
         from lmdeploy.messages import GenerationConfig
 
         batched = True
         if isinstance(inputs, str):
             inputs = [inputs]
             batched = False
         prompt = inputs
+        do_sample = kwargs.pop('do_sample', None)
         gen_params = self.update_gen_params(**kwargs)
+
+        if do_sample is None:
+            do_sample = self.do_sample
+        if do_sample is not None and self.version < (0, 6, 0):
+            raise RuntimeError(
+                '`do_sample` parameter is not supported by lmdeploy until '
+                f'v0.6.0, but currently using lmdeploy {self.str_version}')
+        if self.version >= (0, 6, 0):
+            if do_sample is None:
+                do_sample = gen_params['top_k'] > 1 or gen_params[
+                    'temperature'] > 0
+            gen_params.update(do_sample=do_sample)
+
         gen_config = GenerationConfig(
             skip_special_tokens=skip_special_tokens, **gen_params)
         response = self.model.batch_infer(
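
Taken together, the new generate logic only forwards do_sample to lmdeploy's GenerationConfig when the installed version understands it. A condensed sketch of that decision outside the wrapper, assuming lmdeploy >= 0.6.0 accepts do_sample as the diff implies (model path and sampling values are illustrative):

    from lmdeploy import pipeline, version_info
    from lmdeploy.messages import GenerationConfig

    gen_params = dict(top_k=40, temperature=0.8, max_new_tokens=64)
    if version_info >= (0, 6, 0):
        # Mirror the wrapper: absent an explicit choice, treat top_k > 1 or
        # temperature > 0 as a request for sampling.
        gen_params['do_sample'] = (gen_params['top_k'] > 1
                                   or gen_params['temperature'] > 0)
    gen_config = GenerationConfig(skip_special_tokens=True, **gen_params)

    pipe = pipeline('internlm/internlm2_5-7b-chat')  # illustrative model path
    print(pipe.batch_infer(['Hello'], gen_config=gen_config))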
2 changes: 1 addition & 1 deletion requirements/optional.txt
@@ -1,5 +1,5 @@
 google-search-results
-lmdeploy<=0.5.3
+lmdeploy>=0.2.5
 pillow
 python-pptx
 timeout_decorator
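
With the upper bound removed, compatibility is handled at runtime rather than at install time. A minimal sketch of the version gate the wrapper relies on (mirroring the checks added in lmdeploy_wrapper.py above):

    import lmdeploy
    from lmdeploy import version_info  # a tuple such as (0, 6, 1)

    # do_sample is only forwarded to GenerationConfig on lmdeploy >= 0.6.0 and is
    # rejected on older releases, so no install-time cap on lmdeploy is needed.
    print(lmdeploy.__version__, 'supports do_sample:', version_info >= (0, 6, 0))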
