update clip_eval to accept dmx model
clip_eval.py  (+28, -29)
@@ -29,7 +29,6 @@ _CITATION = """
 }
 """
 
-
 @add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class DmxClipEval(evaluate.Metric):
     def _info(self):
@@ -38,19 +37,17 @@ class DmxClipEval(evaluate.Metric):
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
-            features=
-
-
-
-
-
-
-            ),
-        ],
+            features=datasets.Features(
+                {
+                    "model": datasets.Value("string"),
+                    "dataset_names": datasets.Value("string"),
+                    "n_examples": datasets.Value("int32"),
+                }
+            ),
         )
 
     def clip_dataset_evaluator(
-        self, model, device,
+        self, model, device, dataset_name="mscoco", n_examples=-1
     ):
         processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
         if dataset_name == "mscoco":
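With the new signature, clip_dataset_evaluator takes the dataset name and the example budget explicitly (the -1 default presumably means no cap), and it resolves the processor from model.config._name_or_path, so any model object passed in has to keep that attribute populated. A minimal sketch of a direct call, assuming metric is an instantiated DmxClipEval and using a stock CLIP checkpoint as a placeholder:

import torch
from transformers import CLIPModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

# Defaults: dataset_name="mscoco", n_examples=-1 (presumably the full split).
mscoco_metrics = metric.clip_dataset_evaluator(model, device)
# Explicit arguments: cap the flickr run at 200 examples.
flickr_metrics = metric.clip_dataset_evaluator(model, device, "flickr", 200)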
@@ -116,34 +113,36 @@ class DmxClipEval(evaluate.Metric):
         }
         return metrics
 
-    def clip_evaluator(self, model, device,
+    def clip_evaluator(self, model, device, n_examples=-1):
         metrics = {}
-        for
+        for dataset_name in ["mscoco", "flickr"]:
             metrics.update(
-                self.clip_dataset_evaluator(model, device,
+                self.clip_dataset_evaluator(model, device, dataset_name, n_examples)
             )
         return metrics
 
-    def _compute(self,
-
-
-
-        actual_n_examples = n_examples[0]
+    def _compute(self, model, dataset_names, n_examples, **kwargs):
+        dataset = dataset_names[0]
+        num_examples = n_examples[0]
+        model_input = model[0]
 
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-
-
+
+        if isinstance(model_input, str):
+            actual_model = CLIPModel.from_pretrained(model_input).to(device)
+        else:
+            actual_model = model_input
+
+        datasets_to_evaluate = [dataset]
+
         metrics = {}
-        for
+        for ds_name in datasets_to_evaluate:
             dataset_metrics = self.clip_dataset_evaluator(
-                model=
+                model=actual_model,
                 device=device,
-
-
-                n_examples=actual_n_examples,
+                dataset_name=ds_name,
+                n_examples=num_examples,
             )
             metrics.update(dataset_metrics)
 
-        return metrics
+        return metrics
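The isinstance check on model_input is the core of the change: a string is treated as a checkpoint id and loaded with CLIPModel.from_pretrained, while anything else, such as a dmx-transformed model, is passed through untouched. A hedged sketch of that pass-through path, again assuming an instantiated metric and calling clip_evaluator directly so the string-typed feature schema never has to encode the model object; the dmx transformation itself is not shown in this diff, so a plain CLIPModel stands in for it:

import torch
from transformers import CLIPModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a dmx-transformed model. _compute only requires that it not be
# a str, and clip_dataset_evaluator additionally reads model.config._name_or_path
# to build the matching CLIPProcessor.
dmx_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

# Runs both mscoco and flickr per the loop in clip_evaluator, 100 examples each.
metrics = metric.clip_evaluator(dmx_model, device, n_examples=100)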
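Going through the public compute entry point keeps the inputs compatible with the declared features, since model, dataset_names, and n_examples are all plain strings and ints there. Note that _compute only evaluates dataset_names[0], so covering several datasets means one call per name. A sketch under the assumption that the script loads as a local evaluate module from the hypothetical path below:

import evaluate

clip_eval = evaluate.load("./clip_eval.py")  # hypothetical local path

all_metrics = {}
for name in ["mscoco", "flickr"]:
    all_metrics.update(
        clip_eval.compute(
            model=["openai/clip-vit-base-patch32"],  # string id: loaded in _compute
            dataset_names=[name],
            n_examples=[500],
        )
    )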