# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
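"""Interactive chat demo for CogAgent-style multimodal models in PaddleMIX.

The script loads a pretrained model and its tokenizer, then runs a simple
REPL: enter an image path (or leave it empty for a text-only chat), then
type queries at the "Human:" prompt; type "clear" to reset the history and
pick a new image.

A minimal invocation, assuming PaddleMIX and its dependencies are installed
(the checkpoint name below is just this script's default, taken from the
--model_name_or_path flag):

    python chat_demo.py --model_name_or_path THUDM/cogagent-chat
"""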
import random

import numpy as np
import paddle

# Seed every random number generator up front so runs are reproducible.
seed = 2024
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)

import argparse

from PIL import Image

from paddlemix.auto.modeling import AutoModelMIX
from paddlemix.auto.tokenizer import AutoTokenizerMIX
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_name_or_path", type=str, default="THUDM/cogagent-chat", help="pretrained checkpoint and tokenizer"
)
args = parser.parse_args()
MODEL_PATH = args.model_name_or_path
TOKENIZER_PATH = MODEL_PATH

# Load the tokenizer and the half-precision model, then switch to inference mode.
tokenizer = AutoTokenizerMIX.from_pretrained(TOKENIZER_PATH)
data_type = "float16"
model = AutoModelMIX.from_pretrained(
    MODEL_PATH,
    dtype=data_type,
    low_cpu_mem_usage=False,
)
model.eval()
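
# Chat template applied to the first turn of a text-only conversation.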
text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
while True:
    image_path = input("image path >>>>> ")
    if image_path == "":
        print("You did not enter an image path; the following will be a plain text conversation.")
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert("RGB")

    history = []

    while True:
        query = input("Human:")
        if query == "clear":
            # Reset the conversation and go back to the image path prompt.
            break

        if image is None:
            # Text-only mode: wrap the first query in the chat template, and
            # prepend the accumulated dialogue on later turns.
            if text_only_first_query:
                query = text_only_template.format(query)
                text_only_first_query = False
            else:
                old_prompt = ""
                for old_query, response in history:
                    old_prompt += old_query + " " + response + "\n"
                query = old_prompt + "USER: {} ASSISTANT:".format(query)

        if image is None:
            input_by_model = model.build_conversation_input_ids(
                tokenizer, query=query, history=history, template_version="base"
            )
        else:
            input_by_model = model.build_conversation_input_ids(
                tokenizer, query=query, history=history, images=[image]
            )

        # Add a batch dimension to each tensor and cast images to the model dtype.
        inputs = {
            "input_ids": input_by_model["input_ids"].unsqueeze(axis=0),
            "token_type_ids": input_by_model["token_type_ids"].unsqueeze(axis=0),
            "attention_mask": input_by_model["attention_mask"].unsqueeze(axis=0),
            "images": [[input_by_model["images"][0].to(data_type)]] if image is not None else None,
        }
        # Some checkpoints (e.g. CogAgent) also return high-resolution
        # "cross_images" for the cross-attention branch.
        if "cross_images" in input_by_model and input_by_model["cross_images"]:
            inputs["cross_images"] = [[input_by_model["cross_images"][0].to(data_type)]]

        gen_kwargs = {"max_new_tokens": 2048, "do_sample": False}
        with paddle.no_grad():
            outputs, _ = model.generate(**inputs, **gen_kwargs)
            response = tokenizer.decode(outputs[0])
            # Keep only the text before the end-of-sequence token.
            response = response.split("</s>")[0]
            print("\nCog:", response)
        history.append((query, response))