[fix] log_final_results #258 (#259)
* [fix] log_final_results #258

* [fix] add conversation
huyiwen authored Jun 11, 2024
1 parent a539d64 commit aefdee9
Showing 10 changed files with 237 additions and 35 deletions.
38 changes: 37 additions & 1 deletion .github/.test_durations
@@ -100,5 +100,41 @@
"tests/utilization/utils/test_batch_sampler.py::test_dcbs": 0.8340403449999485,
"tests/utilization/utils/test_batch_sampler.py::test_dcbs_auto_batching": 0.06662718100005804,
"tests/utilization/utils/test_batch_sampler.py::test_dcbs_few_shot": 0.0838686949998646,
"tests/utilization/utils/test_batch_sampler.py::test_dcbs_few_shot_prefix_caching": 3.93439717199999
"tests/utilization/utils/test_batch_sampler.py::test_dcbs_few_shot_prefix_caching": 3.93439717199999,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:api-False-no_split]": 0.002944755367934704,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:api-False-split]": 0.0031265132129192352,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:api-True-no_split]": 0.002978229895234108,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:api-True-split]": 0.003244483843445778,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:local-False-no_split]": 0.00300593301653862,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:local-False-split]": 0.003051883541047573,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:local-True-no_split]": 0.00298389233648777,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample1:local-True-split]": 0.011334518902003765,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:api-False-no_split]": 0.00297668669372797,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:api-False-split]": 0.003215758129954338,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:api-True-no_split]": 0.0030325893312692642,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:api-True-split]": 0.003189575858414173,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:local-False-no_split]": 0.0030073318630456924,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:local-False-split]": 0.003115130588412285,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:local-True-no_split]": 0.002978302538394928,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:conv:sample2:local-True-split]": 0.0031450027599930763,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample1:local-False-no_split]": 0.002832619473338127,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample1:local-False-split]": 0.0028070490807294846,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample1:local-True-no_split]": 0.0029166433960199356,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample1:local-True-split]": 0.005571841262280941,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample2:local-False-no_split]": 0.0028258198872208595,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample2:local-False-split]": 0.0027440031990408897,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample2:local-True-no_split]": 0.0027814237400889397,
"tests/utilization/utils/test_log_results.py::test_log_final_results[generation:no_norm:legacy:sample2:local-True-split]": 0.0028041303157806396,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:acc_norm:legacy:sample1:local-False-no_split]": 0.0029122764244675636,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:acc_norm:legacy:sample1:local-False-split]": 0.0029731355607509613,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:acc_norm:legacy:sample1:local-True-no_split]": 0.002978702075779438,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:acc_norm:legacy:sample1:local-True-split]": 0.0030855443328619003,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:no_norm:legacy:sample1:local-False-no_split]": 0.002977837808430195,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:no_norm:legacy:sample1:local-False-split]": 0.003017907030880451,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:no_norm:legacy:sample1:local-True-no_split]": 0.002957606688141823,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_ppl:no_norm:legacy:sample1:local-True-split]": 0.0033153872936964035,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_prob:no_norm:legacy:sample1:local-False-no_split]": 0.002739163115620613,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_prob:no_norm:legacy:sample1:local-False-split]": 0.002798774279654026,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_prob:no_norm:legacy:sample1:local-True-no_split]": 0.0028306040912866592,
"tests/utilization/utils/test_log_results.py::test_log_final_results[get_prob:no_norm:legacy:sample1:local-True-split]": 0.002850011922419071
}
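(The timings above are most likely consumed by pytest-split, which balances tests across CI shards from a .test_durations file; the new test_log_results entries keep that cache in sync with the tests added below.)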
16 changes: 13 additions & 3 deletions tests/dry_test/test_datasets.py
@@ -1,7 +1,9 @@
import nltk
import pytest

-from .fixtures import run_evaluate
+from utilization.utils.logging import list_datasets
+
+from .fixtures import *

nltk.download('punkt')

@@ -20,7 +22,7 @@
"commonsenseqa": [],
"copa": [],
"coqa": "skip",
# "crows_pairs": "does not support api model",
"crows_pairs": "does not support api model",
"drop": [],
"gaokao": [],
"gsm8k": [],
@@ -57,7 +59,7 @@
"webq": [],
"wic": [],
"winogender": [],
# "winograd": "does not support api model",
"winograd": "does not support api model",
"winogrande": [],
"wmt16:de-en": [],
"wsc": [],
@@ -82,6 +84,14 @@
)
}

datasets_to_test = set(list_datasets()) - {
"wmt10", "wmt13", "wmt14", "wmt15", "wmt16", "wmt17", "wmt18", "wmt19", "wmt21", "agieval_cot",
"agieval_single_choice"
}
for dataset in datasets_to_test:
if dataset not in datasets:
datasets[dataset] = []


@pytest.mark.parametrize("dataset, extra_args", datasets.items())
def test_datasets_dry_run(run_evaluate, dataset, extra_args):
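With the registration loop above, every dataset returned by list_datasets() now gets a default dry-run entry. A quick usage sketch (pytest's -k expression filters on the parametrized test id; gsm8k is one of the entries above):

import pytest

# Run only the gsm8k dry-run case from this module.
pytest.main(["tests/dry_test/test_datasets.py", "-k", "gsm8k"])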
2 changes: 1 addition & 1 deletion tests/dry_test/test_models.py
@@ -1,7 +1,7 @@
import pytest
import torch

-from .fixtures import run_evaluate
+from .fixtures import *

models = {
"gpt-3.5-turbo": ["--openai_api_key", "fake_key"],
2 changes: 1 addition & 1 deletion tests/utilization/model/test_apply_prompt_template.py
@@ -1,7 +1,7 @@
from utilization.chat_templates import DEFAULT_CHAT_TEMPLATE
from utilization.model.model_utils.conversation import Conversation, ConversationFormatter

-from ..fixtures import conversation
+from ..fixtures import *


def test_base(conversation: Conversation):
2 changes: 1 addition & 1 deletion tests/utilization/model/test_to_model_prompt.py
@@ -2,7 +2,7 @@

from utilization.model.model_utils.conversation import Conversation, ConversationFormatter

-from ..fixtures import conversation
+from ..fixtures import *

model_evaluation_methods = {
("generation", False): (
2 changes: 1 addition & 1 deletion tests/utilization/utils/test_batch_sampler.py
@@ -1,6 +1,6 @@
import sys

-from ..fixtures import get_dataset_collection
+from ..fixtures import *

sys.path.append('.')
from utilization.model.huggingface_model import HuggingFaceModel
141 changes: 141 additions & 0 deletions tests/utilization/utils/test_log_results.py
@@ -0,0 +1,141 @@
import json
import sys
from copy import deepcopy

import pandas as pd
import pytest

from ..fixtures import *

sys.path.append('.')
from utilization.model.model_utils.conversation import Conversation, ConversationFormatter
from utilization.utils.log_results import log_final_results


def get_conv(split):
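    # Build a Conversation test instance; split=True prepends an example
    # user/assistant turn before the actual input.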
if split:
msg = [{
"role": "user",
"content": "This is the example input of the model."
}, {
"role": "assistant",
"content": "This is the sample output of the model."
}, {
"role": "user",
"content": "This is the input of the model."
}]
else:
msg = [{"role": "user", "content": "This is the input of the model."}]
return Conversation(
msg,
formatter=ConversationFormatter.from_chat_template("base"),
model_evaluation_method="generation",
split=split,
)


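# Evaluation instances keyed by (model evaluation method, split mode, instance format).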
data = {
("generation", "no_split", "legacy"): [("This is the input of the model.",)],
("generation", "split", "legacy"): [("This is", " a splitted sentence.")],
("generation", "split", "conv"): [get_conv(True)],
("generation", "no_split", "conv"): [get_conv(False)],
("get_ppl", "no_split", "legacy"): [("Source parts of get_ppl", " target parts 1 of get_ppl"),
("Source parts of get_ppl", " target parts 2 of get_ppl")],
("get_ppl", "split", "legacy"):
[("Source parts of get_ppl", " can be splitted, but not", " target parts 1 of get_ppl"),
("Source parts of get_ppl", " can be splitted, but not", " target parts 2 of get_ppl")],
("get_prob", "no_split", "legacy"): [("This is the get_prob input of the model", 2)],
("get_prob", "split", "legacy"): [("The get_prob input of the model", " can also be splitted.", 2)],
}
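# Method ids encode "<evaluation method>:<normalization>:<instance format>:<sample count>:<model type>".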
methods = [
"generation:no_norm:legacy:sample1:local",
"generation:no_norm:conv:sample1:api",
"generation:no_norm:conv:sample1:local",
"generation:no_norm:legacy:sample2:local",
"generation:no_norm:conv:sample2:api",
"generation:no_norm:conv:sample2:local",
"get_prob:no_norm:legacy:sample1:local",
"get_ppl:no_norm:legacy:sample1:local",
"get_ppl:acc_norm:legacy:sample1:local",
]


@pytest.mark.parametrize("split", ["split", "no_split"])
@pytest.mark.parametrize("multiple_source", [True, False])
@pytest.mark.parametrize("methods", methods)
def test_log_final_results(split, multiple_source, methods):

eval_method, use_normalization, use_conversation, sample_num, local = methods.split(":")
use_normalization = use_normalization == "acc_norm"
sample_num = int(sample_num[-1])

    def set_subset(record: dict):
        record["subset"] = "subset_name"

eval_data = data[eval_method, split, use_conversation]
if eval_method == "get_ppl":
raw = [(0.5, 10), (1.0, 10)] # probabilities, length
processed = [1] # index 0
op_num = 2
elif eval_method == "get_prob":
raw = [[0.1, 0.9]] # probabilities
processed = [1] # index 0
op_num = 2
elif eval_method == "generation":
raw = ["This is the model's raw prediction."]
processed = ["prediction"]
op_num = 1

    # acc_norm doubles the number of raw predictions and evaluation instances
    no_num = 2 if use_normalization else 1

series = log_final_results(
raw_predictions=raw * sample_num * no_num,
processed_predictions=processed * sample_num,
evaluation_instances=deepcopy(eval_data) * sample_num * no_num,
score_lists={"Metric": [True]}, # score_lists have already been aggregated along self-concsistency
multiple_source=multiple_source,
model_evaluation_method=eval_method,
use_normalization=use_normalization,
option_nums=[op_num] * sample_num,
len_evaluation_data=1,
sample_num=sample_num,
references=["reference"],
local_model=local == "local",
)
series.apply(set_subset)
print(series)
json_str = pd.concat([series]).to_json(orient="records", indent=4, force_ascii=False)

    unmarshaled = json.loads(json_str)
    print(json_str)
    print(unmarshaled)
    assert len(unmarshaled) == 1
    assert unmarshaled[0]["index"] == 0
    assert unmarshaled[0]["subset"] == "subset_name"
    if eval_method == "get_ppl" and not multiple_source:
        source = "".join(eval_data[0][:-1])
        assert unmarshaled[0]["source"] == source
        assert unmarshaled[0]["option"] == [" target parts 1 of get_ppl", " target parts 2 of get_ppl"]
        assert unmarshaled[0]["perplexity"] == [0.5, 1.0]
    elif eval_method == "get_ppl" and multiple_source:
        source = "".join(eval_data[0][:-1])
        assert unmarshaled[0]["source"] == [source, source]
        assert unmarshaled[0]["option"] == " target parts 1 of get_ppl"
        assert unmarshaled[0]["perplexity"] == [0.5, 1.0]
    elif eval_method == "get_prob":
        source = "".join(eval_data[0][:-1])
        assert unmarshaled[0]["source"] == source
        # note: "probabilites" mirrors the key emitted by log_final_results
        assert unmarshaled[0]["probabilites"] == [0.1, 0.9]
    elif eval_method == "generation":
        if use_conversation == "conv" and local == "local":
            source = eval_data[0].apply_prompt_template()
        elif use_conversation == "conv" and local == "api":
            source = eval_data[0].messages
        else:
            source = "".join(eval_data[0])
        assert unmarshaled[0]["source"] == source
        assert unmarshaled[0]["raw_prediction"] == ["This is the model's raw prediction."] * sample_num
    assert unmarshaled[0]["reference"] == "reference"
    assert unmarshaled[0]["metric"]["Metric"] is True
2 changes: 1 addition & 1 deletion utilization/metric/gpteval.py
@@ -49,7 +49,7 @@ def __init__(self, multi_turn=False, type: Literal["single", "pairwise"] = "single"
def __call__(self, predictions, references):

# load gpteval model after the predictions of dataset are generated
-        from ..model import load_model
+        from ..load_model import load_model

self.model = load_model(self.model_args)
self.model.set_generation_args()
23 changes: 18 additions & 5 deletions utilization/model/model_utils/conversation.py
@@ -228,13 +228,26 @@ def to_model_prompts(
class Conversation(_HFConversation):

    def __init__(
-        self, messages: Union[str, List[Dict[str, str]], None] = None, conversation_id=None, **deprecated_kwargs
+        self,
+        messages: Union[str, List[Dict[str, str]], None] = None,
+        conversation_id=None,
+        num_turns: int = 1,
+        num_shots: int = 0,
+        num_options: int = 1,
+        multi_turn_users: Optional[List[str]] = None,
+        formatter: Optional[ConversationFormatter] = None,
+        model_evaluation_method: Optional[Literal["get_ppl", "get_prob", "generation", "user_defined"]] = None,
+        split: Optional[bool] = None,
+        **deprecated_kwargs
    ):
        super().__init__(messages, conversation_id, **deprecated_kwargs)
-        self.num_turns = 1
-        self.num_shots = 0
-        self.num_options = 1
-        self.mt_users = []
+        self.num_turns = num_turns
+        self.num_shots = num_shots
+        self.num_options = num_options
+        self.mt_users = [] if multi_turn_users is None else multi_turn_users
+        self.formatter = formatter
+        self.model_evaluation_method = model_evaluation_method
+        self.split = split

@classmethod
def from_chat(cls, *, user: Optional[str] = None, assistant: Optional[str] = None) -> "Conversation":
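Downstream callers can now build a fully configured instance in one call. A minimal usage sketch, mirroring the signature above and the "base" chat template used in the tests:

from utilization.model.model_utils.conversation import Conversation, ConversationFormatter

# Construct a single-turn generation instance via the new keyword arguments.
conv = Conversation(
    [{"role": "user", "content": "This is the input of the model."}],
    formatter=ConversationFormatter.from_chat_template("base"),
    model_evaluation_method="generation",
    split=False,
)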