Skip to content

Commit

Permalink
add TencentSearch
Browse files Browse the repository at this point in the history
  • Loading branch information
braisedpork1964 committed Oct 16, 2024
1 parent f55b01f commit e189ec5
Showing 1 changed file with 198 additions and 4 deletions.
202 changes: 198 additions & 4 deletions lagent/actions/web_browser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
import hashlib
import hmac
import json
import logging
import random
import re
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from http.client import HTTPSConnection
from typing import List, Optional, Tuple, Type, Union

import aiohttp
Expand Down Expand Up @@ -87,7 +91,7 @@ async def asearch(self, query: str, max_retry: int = 3) -> dict:
logging.exception(str(e))
warnings.warn(
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
time.sleep(random.randint(2, 5))
await asyncio.sleep(random.randint(2, 5))
raise Exception(
'Failed to get search results from DuckDuckGo after retries.')

Expand Down Expand Up @@ -163,7 +167,7 @@ async def asearch(self, query: str, max_retry: int = 3) -> dict:
logging.exception(str(e))
warnings.warn(
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
time.sleep(random.randint(2, 5))
await asyncio.sleep(random.randint(2, 5))
raise Exception(
'Failed to get search results from Bing Search after retries.')

Expand Down Expand Up @@ -277,7 +281,7 @@ async def asearch(self, query: str, max_retry: int = 3) -> dict:
logging.exception(str(e))
warnings.warn(
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
time.sleep(random.randint(2, 5))
await asyncio.sleep(random.randint(2, 5))
raise Exception(
'Failed to get search results from Brave Search after retries.')

Expand Down Expand Up @@ -413,7 +417,7 @@ async def asearch(self, query: str, max_retry: int = 3) -> dict:
logging.exception(str(e))
warnings.warn(
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
time.sleep(random.randint(2, 5))
await asyncio.sleep(random.randint(2, 5))
raise Exception(
'Failed to get search results from Google Serper Search after retries.'
)
Expand Down Expand Up @@ -499,6 +503,196 @@ def _parse_response(self, response: dict) -> dict:
return self._filter_results(raw_results)


class TencentSearch(BaseSearch):
    """Wrapper around the Tencent Cloud Search API.

    To use, you should pass your secret_id and secret_key to the constructor.

    Args:
        secret_id (str): Your Tencent Cloud secret ID for accessing the API.
            For more details, refer to the documentation: https://cloud.tencent.com/document/product/598/40488.
        secret_key (str): Your Tencent Cloud secret key for accessing the API.
        api_key (str, optional): Additional API key, if required.
        action (str): The action for this interface, use `SearchCommon`.
        version (str): The API version, use `2020-12-29`.
        service (str): The service name, use `tms`.
        host (str): The API host, use `tms.tencentcloudapi.com`.
        topk (int): The maximum number of search results to return.
        tsn (int): Time filter for search results. Valid values:
            1 (within 1 day), 2 (within 1 week), 3 (within 1 month),
            4 (within 1 year), 5 (within 6 months), 6 (within 3 years).
        insite (str): Specify a site to search within (supports only a single site).
            If not specified, the entire web is searched. Example: `zhihu.com`.
        category (str): Vertical category for filtering results. Optional values include:
            `baike` (encyclopedia), `weather`, `calendar`, `medical`, `news`, `train`, `star` (horoscope).
        vrid (str): Result card type(s). Different `vrid` values represent different types of result cards.
            Supports multiple values separated by commas. Example: `30010255`.
        black_list (List[str]): Domains filtered out of the results by
            ``BaseSearch._filter_results``.
    """

    def __init__(self,
                 secret_id: str = 'Your SecretId',
                 secret_key: str = 'Your SecretKey',
                 api_key: str = '',
                 action: str = 'SearchCommon',
                 version: str = '2020-12-29',
                 service: str = 'tms',
                 host: str = 'tms.tencentcloudapi.com',
                 topk: int = 3,
                 tsn: int = None,
                 insite: str = None,
                 category: str = None,
                 vrid: str = None,
                 black_list: List[str] = [
                     'enoN',
                     'youtube.com',
                     'bilibili.com',
                     'researchgate.net',
                 ]):
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.api_key = api_key
        self.action = action
        self.version = version
        self.service = service
        self.host = host
        self.tsn = tsn
        self.insite = insite
        self.category = category
        self.vrid = vrid
        super().__init__(topk, black_list=black_list)

    @cached(cache=TTLCache(maxsize=100, ttl=600))
    def search(self, query: str, max_retry: int = 3) -> dict:
        """Search synchronously, retrying up to ``max_retry`` times.

        Args:
            query (str): The search query.
            max_retry (int): Number of attempts before giving up.

        Returns:
            dict: Filtered results keyed by rank, as produced by
                ``BaseSearch._filter_results``.

        Raises:
            Exception: If all retries fail.
        """
        for attempt in range(max_retry):
            try:
                response = self._call_tencent_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                # Randomized backoff to avoid hammering the API.
                time.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Tencent Search after retries.')

    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def asearch(self, query: str, max_retry: int = 3) -> dict:
        """Asynchronous counterpart of :meth:`search`."""
        for attempt in range(max_retry):
            try:
                response = await self._async_call_tencent_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                await asyncio.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Tencent Search after retries.')

    def _get_headers_and_payload(self, query: str) -> tuple:
        """Build the signed TC3-HMAC-SHA256 headers and JSON payload.

        Implements the signing procedure documented at
        https://cloud.tencent.com/document/product/213/30654.

        Args:
            query (str): The search query.

        Returns:
            tuple: ``(headers, payload)`` where ``headers`` is a dict of
                HTTP headers including the ``Authorization`` signature and
                ``payload`` is the JSON-encoded request body string.
        """

        def sign(key, msg):
            return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()

        params = dict(Query=query)
        # if self.topk:
        #     params['Cnt'] = self.topk
        if self.tsn:
            params['Tsn'] = self.tsn
        if self.insite:
            params['Insite'] = self.insite
        if self.category:
            params['Category'] = self.category
        if self.vrid:
            params['Vrid'] = self.vrid
        payload = json.dumps(params)
        algorithm = 'TC3-HMAC-SHA256'
        timestamp = int(time.time())
        # The credential scope date must be the UTC date of the timestamp.
        # `datetime.utcfromtimestamp` is deprecated since Python 3.12, so use
        # the tz-aware equivalent, which yields the same '%Y-%m-%d' string.
        date = datetime.fromtimestamp(
            timestamp, tz=timezone.utc).strftime('%Y-%m-%d')

        # ************* Step 1: build the canonical request string *************
        http_request_method = 'POST'
        canonical_uri = '/'
        canonical_querystring = ''
        ct = 'application/json; charset=utf-8'
        canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
        signed_headers = 'content-type;host;x-tc-action'
        hashed_request_payload = hashlib.sha256(
            payload.encode('utf-8')).hexdigest()
        canonical_request = (
            http_request_method + '\n' + canonical_uri + '\n' +
            canonical_querystring + '\n' + canonical_headers + '\n' +
            signed_headers + '\n' + hashed_request_payload)

        # ************* Step 2: build the string to sign *************
        credential_scope = date + '/' + self.service + '/' + 'tc3_request'
        hashed_canonical_request = hashlib.sha256(
            canonical_request.encode('utf-8')).hexdigest()
        string_to_sign = (
            algorithm + '\n' + str(timestamp) + '\n' + credential_scope +
            '\n' + hashed_canonical_request)

        # ************* Step 3: compute the signature *************
        secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
        secret_service = sign(secret_date, self.service)
        secret_signing = sign(secret_service, 'tc3_request')
        signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
                             hashlib.sha256).hexdigest()

        # ************* Step 4: assemble the Authorization header *************
        authorization = (
            algorithm + ' ' + 'Credential=' + self.secret_id + '/' +
            credential_scope + ', ' + 'SignedHeaders=' + signed_headers +
            ', ' + 'Signature=' + signature)

        # ************* Step 5: build the request headers *************
        headers = {
            'Authorization': authorization,
            'Content-Type': 'application/json; charset=utf-8',
            'Host': self.host,
            'X-TC-Action': self.action,
            'X-TC-Timestamp': str(timestamp),
            'X-TC-Version': self.version
        }
        # if self.region:
        #     headers["X-TC-Region"] = self.region
        if self.api_key:
            headers['X-TC-Token'] = self.api_key
        return headers, payload

    def _call_tencent_api(self, query: str) -> dict:
        """POST the signed request and return the API's ``Response`` payload.

        Args:
            query (str): The search query.

        Returns:
            dict: The ``Response`` field of the decoded body, or an empty
                dict if absent.
        """
        headers, payload = self._get_headers_and_payload(query)
        conn = HTTPSConnection(self.host)
        try:
            conn.request(
                'POST', '/', headers=headers, body=payload.encode('utf-8'))
            # Read the body exactly once; the stream cannot be re-read, so
            # keep the decoded text for the fallback parser below.
            text = conn.getresponse().read().decode('utf-8')
        finally:
            # Always release the connection (the original leaked it).
            conn.close()
        try:
            resp = json.loads(text)
        except Exception as e:
            logging.warning(str(e))
            import ast
            # Fallback for bodies that are Python-literal-like rather than
            # strict JSON. Parse the saved text, not the response object.
            resp = ast.literal_eval(text)
        return resp.get('Response', dict())

    async def _async_call_tencent_api(self, query: str):
        """Asynchronous counterpart of :meth:`_call_tencent_api`."""
        headers, payload = self._get_headers_and_payload(query)
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            async with session.post(
                    'https://' + self.host.lstrip('/'),
                    headers=headers,
                    data=payload) as resp:
                return (await resp.json()).get('Response', {})

    def _parse_response(self, response: dict) -> dict:
        """Convert the raw API response into filtered search results.

        Args:
            response (dict): The ``Response`` payload from the API, whose
                ``Pages`` items each carry a JSON-encoded ``Display`` field.

        Returns:
            dict: Filtered results from ``BaseSearch._filter_results``.
        """
        raw_results = []
        for item in response.get('Pages', []):
            display = json.loads(item['Display'])
            # Skip entries without a URL; use .get so a missing key does not
            # abort the whole parse with a KeyError.
            if not display.get('url'):
                continue
            raw_results.append((display['url'], display.get('content')
                                or display.get('abstract_info', ''),
                                display.get('title', '')))
        return self._filter_results(raw_results)


class ContentFetcher:

def __init__(self, timeout: int = 5):
Expand Down

0 comments on commit e189ec5

Please sign in to comment.