Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from data import read_input | |
| from aux import log | |
| import sys | |
| from collections import defaultdict | |
| from evaluate import load as load_metric | |
# Language names grouped into four tiers keyed 'high', 'mid', 'low' and 'xlow'
# (presumably resource-availability levels, judging by the constant's name).
# _hi_or_lo_lang() scans the tiers in this insertion order and returns the key
# of the first tier containing the language.
SMUGRI_RES = {
    'high': {
        "Estonian", "English", "Russian", "Finnish", "Hungarian",
        "Latvian", "German", "Swedish", "Norwegian", "French",
    },
    'mid': {
        "Komi", "Komi-Zyrian", "Northern Sami", "Meadow Mari",
    },
    'low': {
        "Udmurt", "Proper Karelian", "Southern Sami", "Livvi", "Veps",
        "Moksha", "Erzya", "Lule Sami", "Võro", "Hill Mari",
        "Komi-Permyak", "Inari Sami",
    },
    'xlow': {
        "Ludian", "Livonian", "Izhorian", "Votic", "Shur Khanty",
        "Skolt Sami", "Meänkieli", "Sred Khanty", "Surgut Khanty",
        "Priur Khanty", "Vakh Khanty", "Unk Khanty", "Pite Sami", "Mansi",
        "Kazym Khanty", "Kven", "Ume Sami", "Kildin Sami",
    },
}
def _gen_lang(lang):
    """Return the part of *lang* that precedes its first comma.

    A name without a comma is returned unchanged; nothing is stripped, so
    surrounding whitespace is kept as-is.
    """
    head, _sep, _tail = lang.partition(",")
    return head
def _hi_or_lo_lang(lang):
    """Return the SMUGRI_RES tier key that contains *lang*.

    The lookup uses the general form of the name (see _gen_lang). Tiers are
    checked in SMUGRI_RES order and the first match wins. When no tier lists
    the language, the miss is logged and '?' is returned.
    """
    base = _gen_lang(lang)
    tier = next(
        (name for name, members in SMUGRI_RES.items() if base in members),
        None,
    )
    if tier is not None:
        return tier

    log(f"Unrecognized language: {lang} / {base}")
    return '?'
def _collect_lp_pairs(json_inputs, str_outputs):
    """Group (hypothesis, reference) pairs by language-pair label.

    *json_inputs* and *str_outputs* are walked in lockstep with zip(), so
    iteration stops at the shorter of the two. Each input entry is read
    through its "tgt_segm", "src_lang" and "tgt_lang" keys; the matching
    output is used as the hypothesis.

    Every pair is filed under three labels:
      * 'detailed: <src> -> <tgt>'  - the language names as given,
      * 'general: <src> -> <tgt>'   - names reduced by _gen_lang,
      * 'classes: <src> -> <tgt>'   - tier keys from _hi_or_lo_lang.

    Returns a defaultdict(list) mapping label -> list of (hyp, ref) tuples.
    """
    grouped = defaultdict(list)

    for entry, hypothesis in zip(json_inputs, str_outputs):
        pair = (hypothesis, entry["tgt_segm"])
        src, tgt = entry["src_lang"], entry["tgt_lang"]

        # All three labels are built before anything is stored.
        labels = (
            'detailed: ' + " -> ".join((src, tgt)),
            'general: ' + " -> ".join((_gen_lang(src), _gen_lang(tgt))),
            'classes: ' + " -> ".join((_hi_or_lo_lang(src), _hi_or_lo_lang(tgt))),
        )
        for label in labels:
            grouped[label].append(pair)

    return grouped
def compute_metrics(json_inputs, str_outputs):
    """Score the outputs with the "chrf" metric, once per language-pair label.

    The pairs are grouped by _collect_lp_pairs, then the metric obtained from
    load_metric("chrf") is computed separately for each label.

    Returns a list of (label, metric_result, pair_count) tuples in the
    grouping's iteration order; metric_result is whatever the metric's
    compute() call returns.
    """
    grouped = _collect_lp_pairs(json_inputs, str_outputs)
    chrf = load_metric("chrf")

    def _score(pairs):
        # Transpose [(hyp, ref), ...] into parallel tuples of hyps and refs.
        hyps, refs = zip(*pairs)
        return chrf.compute(predictions=hyps, references=refs), len(hyps)

    return [(label, *_score(pairs)) for label, pairs in grouped.items()]
def avoid_global_scope():
    """Command-line entry point: score an outputs file against an inputs file.

    Expects two positional arguments: the inputs file (sys.argv[1]) and the
    outputs file (sys.argv[2]). Both are loaded with read_input(..., "json").
    One line is printed per language-pair label, in the form
    "<label>: <score> (<number of pairs>)".

    Exits with a usage message (non-zero status) when either argument is
    missing, instead of failing with an IndexError traceback.
    """
    if len(sys.argv) < 3:
        prog = sys.argv[0] if sys.argv else "script"
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"Usage: {prog} <inputs_file> <outputs_file>")

    json_inputs = read_input(sys.argv[1], "json")
    str_outputs = read_input(sys.argv[2], "json")

    for lp, metric, size in compute_metrics(json_inputs, str_outputs):
        print(f"{lp}: {metric['score']:.2f} ({size})")
# Run the evaluation only when this file is executed as a script; importing
# the module has no side effects beyond its definitions.
if __name__ == "__main__":
    avoid_global_scope()