#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""A minimal ZhipuAI (GLM) wrapper built on LangChain's base LLM class."""
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from zhipuai import ZhipuAI


# Custom LLM that inherits from langchain_core.language_models.llms.LLM
class ZhipuAILLM(LLM):
    # Model name; defaults to glm-4
    model: str = "glm-4"
    # Sampling temperature
    temperature: float = 0.1
    # ZhipuAI API key
    api_key: Optional[str] = None

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        client = ZhipuAI(api_key=self.api_key)

        def gen_glm_params(prompt: str) -> List[Dict[str, str]]:
            """
            Build the `messages` request parameter for the GLM model.

            Args:
                prompt: the user prompt to send.
            """
            messages = [{"role": "user", "content": prompt}]
            return messages

        messages = gen_glm_params(prompt)
        response = client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        if len(response.choices) > 0:
            return response.choices[0].message.content
        return "generate answer error"

    # Default parameters used when calling the API
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return the default parameters for calling the API."""
        normal_params = {
            "temperature": self.temperature,
        }
        return {**normal_params}

    @property
    def _llm_type(self) -> str:
        return "Zhipu"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}
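

# Usage sketch: a minimal example of calling the wrapper, assuming a valid
# ZhipuAI key is exported in the ZHIPUAI_API_KEY environment variable
# (the variable name and the prompt text below are illustrative assumptions).
if __name__ == "__main__":
    import os

    llm = ZhipuAILLM(
        model="glm-4",
        temperature=0.1,
        api_key=os.environ.get("ZHIPUAI_API_KEY"),  # assumed env var holding the key
    )
    # LLM subclasses are invoked through the Runnable interface in current
    # LangChain versions, so .invoke() routes the prompt into _call above.
    print(llm.invoke("Hello, please introduce yourself."))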