
Commit 46708cf: skip examples in legacy format

huyiwen committed May 24, 2024
1 parent 192c6e5
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions utilization/dataset/dataset.py
@@ -54,7 +54,7 @@ class Dataset(torch.utils.data.Dataset, DatasetUtilMixin):
     - `tokenizer (Tokenizer)`: The tokenizer used for the dataset.
     - `num_shots (int)`: The number of few-shot examples to construct.
     - `max_example_tokens (int)`: The maximum number of tokens allowed for the few-shot examples.
-    - `examples (str)`: The constructed demonstration text.
+    - `examples (Conversation)`: The constructed demonstration text.
     - `evaluation_data (List[Dict])`: The loaded evaluation data.
     - `example_data (List[Dict])`: The loaded example data.
     - `evaluation_instances (List[str])`: The final formatted instances for evaluation.
@@ -561,7 +561,11 @@ def construct_instance(
         if self.examples is None or self.kate or self.globale:
             self.examples = self.construct_examples(instance)

-        convers.add_(self.examples)
+        if isinstance(self.examples, Conversation):
+            convers.add_(self.examples)
+        else:
+            # FIXME new example format for quac, squad
+            logger.warning(f"{self.display_name} has legacy examples format. Skipping the examples.")
         option_num = len(instance["options"]) if instance.get("options", None) else 1
         if isinstance(instance["source"], list):
             if self.model_evaluation_method == "get_ppl":
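To illustrate the guard this hunk introduces: demonstrations are only merged into the conversation when `self.examples` is a `Conversation` object; datasets that still produce a plain-string example block (the legacy format) are skipped with a warning. The sketch below is a minimal, self-contained approximation, not the project's actual code: `Conversation`, `merge_examples`, and the sample messages are simplified stand-ins.

```python
import logging

logger = logging.getLogger(__name__)


class Conversation:
    """Simplified stand-in for the project's Conversation type (assumption)."""

    def __init__(self, messages=None):
        self.messages = list(messages or [])

    def add_(self, other: "Conversation") -> None:
        # In-place merge of another Conversation's messages.
        self.messages.extend(other.messages)


def merge_examples(convers: Conversation, examples, display_name: str) -> None:
    """Hypothetical helper mirroring the guard added in this commit."""
    if isinstance(examples, Conversation):
        convers.add_(examples)
    else:
        # Legacy string-formatted examples (e.g. quac, squad) are skipped for now.
        logger.warning(f"{display_name} has legacy examples format. Skipping the examples.")


# Usage sketch:
convers = Conversation()
merge_examples(convers, Conversation([{"role": "user", "content": "Q: ..."}]), "mmlu")
merge_examples(convers, "Q: ...\nA: ...", "squad")  # legacy str -> warning, skipped
```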
4 changes: 2 additions & 2 deletions utilization/utils/log_results.py
@@ -204,7 +204,7 @@ def wrapper():
             merge_by_option = ["option"]
             return pd.DataFrame(lines).groupby("index").apply(to_dict(merge, merge_by_option))
         except Exception as e:
-            lines = {k: len(v) for k, v in lines.items()}
+            lines = {k: getattr(v, "__len__", lambda: None)() for k, v in lines.items()}
             logger.warning(f"Failed to log_pgenerate final predictions: {e}\n{lines}")
             return None

@@ -221,7 +221,7 @@ def wrapper():
         try:
             return pd.DataFrame(lines).groupby("index").apply(to_dict())
         except Exception as e:
-            lines = {k: len(v) for k, v in lines.items()}
+            lines = {k: getattr(v, "__len__", lambda: None)() for k, v in lines.items()}
             logger.warning(f"Failed to generate final predictions: {e}\n{lines}")
             return None
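The change in both except blocks replaces `len(v)` with a guarded lookup because the values in `lines` are not guaranteed to be sized (some may be `None` or scalars) once the grouping has already failed; calling `len` there would raise a second `TypeError` and mask the original error being logged. A minimal sketch of the difference, with made-up sample data rather than the project's actual log structure:

```python
# Hypothetical snapshot of the `lines` dict when grouping fails.
lines = {"index": [0, 1, 2], "prediction": ["a", "b"], "score": None}

# len(None) would raise TypeError inside the exception handler,
# hiding the error we actually want to report.
safe_sizes = {k: getattr(v, "__len__", lambda: None)() for k, v in lines.items()}
print(safe_sizes)  # {'index': 3, 'prediction': 2, 'score': None}
```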

