[Automated Commit] Format Codebase
pgmpablo157321 committed Nov 22, 2024
1 parent 76142d0 commit 3fbdbb0
Showing 3 changed files with 33 additions and 18 deletions.
language/llama2-70b/evaluate-accuracy.py: 17 changes (12 additions, 5 deletions)
@@ -47,7 +47,8 @@ def rouge(label, pred):
 
 
 def niah_em(label, pred):
-    label_uuids = re.findall(r'[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}', label)
+    label_uuids = re.findall(
+        r'[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}', label)
     pred_uuids = re.findall(r'[\w]{8}-[\w]{4}-[\w]{4}-[\w]{4}-[\w]{12}', pred)
 
     # https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py#L28
@@ -73,7 +74,8 @@ def qa_em(label, pred):
         return {'exact_match': 100.0}
 
     normalized_answer = re.sub(r'\s+', '', answer_substring).lower()
-    label_entries = [re.sub(r'\s+', '', entry).lower() for entry in label.split('|')]
+    label_entries = [re.sub(r'\s+', '', entry).lower()
+                     for entry in label.split('|')]
 
     match_found = any(entry in normalized_answer for entry in label_entries)
     return {'exact_match': 100.0 if match_found else 0.0}
@@ -85,7 +87,7 @@ def qa_em(label, pred):
 }
 
 
-def get_groundtruth(processed_dataset_file, return_metrics = True):
+def get_groundtruth(processed_dataset_file, return_metrics=True):
     data = pd.read_pickle(processed_dataset_file)
     ground_truths = data["gt_output"]
     if return_metrics:
@@ -104,17 +106,22 @@ def postprocess_text(preds, targets):
 
     return preds, targets
 
 
 def process_item(item):
     pred, target, metric = item
     metric_fn = metrics[metric]
     metric_eval = metric_fn(target, pred)
     return metric_eval
 
-def run_evaluation(preds, targets, metrics, n_process = None):
+
+def run_evaluation(preds, targets, metrics, n_process=None):
     n_process = cpu_count() if n_process is None else n_process
     with Pool(n_process) as pool:
-        accuracies = list(tqdm(pool.imap(process_item, zip(preds, targets, metrics)), total=len(preds)))
+        accuracies = list(
+            tqdm(
+                pool.imap(
+                    process_item, zip(
+                        preds, targets, metrics)), total=len(preds)))
     df = pd.DataFrame({"accuracy": accuracies, "metric": metrics})
     return df.groupby(by="metric", axis=1).mean()
 
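Note: a minimal usage sketch of the evaluation entry points touched above. The pickle path and the prediction list are placeholders, and get_groundtruth returning per-sample metric names alongside the ground truths (when return_metrics=True) is an assumption based on the visible context; run_evaluation then fans each (pred, target, metric) triple out to process_item, which resolves the metric name through the module-level metrics dict.

    # Hypothetical driver; names marked as placeholders do not appear in this diff.
    targets, metric_names = get_groundtruth("processed-data.pkl")  # assumed return shape
    preds = ["model output 1", "model output 2"]                   # placeholder predictions
    preds, targets = postprocess_text(preds, list(targets))
    results = run_evaluation(preds, targets, list(metric_names))   # e.g. ['rouge', 'niah_em']
    print(results)
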
language/llama3-405b/SUT_VLLM.py: 33 changes (21 additions, 12 deletions)
@@ -25,7 +25,6 @@
 log = logging.getLogger("Llama-405B-SUT")
 
 
-
 class SUT:
     def __init__(
         self,
@@ -113,11 +112,12 @@ def process_queries(self):
 
             tik1 = time.time()
 
-            input_ids_tensor = [self.data_object.input_ids[q.index] for q in qitem]
-
+            input_ids_tensor = [
+                self.data_object.input_ids[q.index] for q in qitem]
+
             tik2 = time.time()
             outputs = self.model.generate(
-                prompt_token_ids=input_ids_tensor, sampling_params = self.sampling_params
+                prompt_token_ids=input_ids_tensor, sampling_params=self.sampling_params
             )
             pred_output_tokens = []
             for output in outputs:
@@ -155,7 +155,11 @@ def process_queries(self):
 
     def load_model(self):
         log.info("Loading model...")
-        self.model = LLM(self.model_path, dtype=self.dtype, tensor_parallel_size=self.tensor_parallel_size,)
+        self.model = LLM(
+            self.model_path,
+            dtype=self.dtype,
+            tensor_parallel_size=self.tensor_parallel_size,
+        )
         log.info("Loaded model")
 
     def get_sut(self):
@@ -218,14 +222,15 @@ def start(self):
             worker = threading.Thread(target=self.process_queries)
             worker.start()
             self.worker_threads[j] = worker
 
     async def stream_output(self, qitem, results_generator):
         first = True
         async for request_output in results_generator:
             output_response = request_output
             if first:
                 first_tokens = list(output_response.outputs[0].token_ids)
-                response_data = array.array("B", np.array(first_tokens, np.int32).tobytes())
+                response_data = array.array(
+                    "B", np.array(first_tokens, np.int32).tobytes())
                 bi = response_data.buffer_info()
                 response = [lg.QuerySampleResponse(qitem.id, bi[0], bi[1])]
                 lg.FirstTokenComplete(response)
@@ -246,7 +251,6 @@ async def stream_output(self, qitem, results_generator):
                     n_tokens)]
             lg.QuerySamplesComplete(response)
 
-
     def process_queries(self):
         """Processor of the queued queries. User may choose to add batching logic"""
         while True:
@@ -255,12 +259,14 @@ async def stream_output(self, qitem, results_generator):
             if qitem is None:
                 break
 
-            input_ids_tensor = TokensPrompt(prompt_token_ids=self.data_object.input_ids[qitem.index])
+            input_ids_tensor = TokensPrompt(
+                prompt_token_ids=self.data_object.input_ids[qitem.index])
 
             # TODO: This PoC is super slow with significant overhead. Best to
             # create a patch to `generate`
             results_generator = self.model.generate(
-                prompt=input_ids_tensor, sampling_params = self.sampling_params, request_id = str(self.request_id)
+                prompt=input_ids_tensor, sampling_params=self.sampling_params, request_id=str(
+                    self.request_id)
             )
             self.request_id += 1
             asyncio.run(self.stream_output(qitem, results_generator))
@@ -277,9 +283,12 @@ def stop(self):
 
         self.first_token_queue.put(None)
         self.ft_response_thread.join()
 
     def load_model(self):
         log.info("Loading model")
-        self.engine_args = AsyncEngineArgs(self.model_path, dtype=self.dtype, tensor_parallel_size=self.tensor_parallel_size)
+        self.engine_args = AsyncEngineArgs(
+            self.model_path,
+            dtype=self.dtype,
+            tensor_parallel_size=self.tensor_parallel_size)
         self.model = AsyncLLMEngine.from_engine_args(self.engine_args)
         log.info("Loaded model")
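Note: for orientation, a self-contained sketch of the async streaming pattern this SUT's server path uses. The model path, token IDs, and sampling settings are placeholder assumptions; the AsyncEngineArgs/AsyncLLMEngine/TokensPrompt flow mirrors load_model, process_queries, and stream_output above.

    import asyncio
    from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
    from vllm.inputs import TokensPrompt

    async def stream_once(engine, token_ids, request_id):
        prompt = TokensPrompt(prompt_token_ids=token_ids)
        params = SamplingParams(max_tokens=32)  # placeholder settings
        first = True
        async for out in engine.generate(prompt, params, request_id=request_id):
            if first:
                first = False  # first-token point: lg.FirstTokenComplete fires here above
        return out.outputs[0].token_ids  # completed sequence after the final yield

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs("meta-llama/Llama-3.1-405B-Instruct"))  # placeholder model path
    tokens = asyncio.run(stream_once(engine, [1, 2, 3], request_id="0"))
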
language/llama3-405b/dataset.py: 1 change (0 additions, 1 deletion)
@@ -38,7 +38,6 @@ def __init__(
         self.total_sample_count = min(len(self.input_ids), total_sample_count)
         self.perf_count = perf_count_override or self.total_sample_count
 
-
     def load_processed_dataset(self):
         if not os.path.isfile(self.dataset_path):
             log.warn(