import argparse

from transformers import RobertaForMaskedLM


def convert_flax_model_to_torch(flax_model_path: str, torch_model_path: str = "./"):
    """
    Converts Flax model weights to PyTorch weights.
    """
    # Load the Flax checkpoint and convert it to a PyTorch model in one step.
    model = RobertaForMaskedLM.from_pretrained(flax_model_path, from_flax=True)
    # Save the converted PyTorch weights (and config) to the target directory.
    model.save_pretrained(torch_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flax to PyTorch model conversion")
    parser.add_argument(
        "--flax_model_path", type=str, default="flax-community/roberta-base-mr", help="Flax model path"
    )
    parser.add_argument("--torch_model_path", type=str, default="./", help="PyTorch model path")
    args = parser.parse_args()
    convert_flax_model_to_torch(args.flax_model_path, args.torch_model_path)
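
# Example invocation (a sketch only: the script filename and output directory are
# placeholders; the default Hub repo "flax-community/roberta-base-mr" comes from the
# argument defaults above):
#
#   python convert_flax_model_to_torch.py \
#       --flax_model_path flax-community/roberta-base-mr \
#       --torch_model_path ./roberta-base-mr-pt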