| from transformers import PretrainedConfig | |
| from typing import List | |
class UNet3DConfig(PretrainedConfig):
    """Configuration object registered under the ``"UNet"`` model type.

    Holds three model-specific hyperparameters as instance attributes and
    hands every other keyword argument to ``PretrainedConfig.__init__``.

    Args:
        in_ch: Presumably the number of input channels (default ``1``).
        out_ch: Presumably the number of output channels (default ``1``).
        init_features: Presumably the feature width of the first network
            stage (default ``64``) -- TODO confirm against the model code.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "UNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ):
        # Store the model-specific fields first, then let the base class
        # process the remaining generic configuration options.
        self.in_ch, self.out_ch, self.init_features = in_ch, out_ch, init_features
        super().__init__(**kwargs)
class UNetMSS3DConfig(PretrainedConfig):
    """Configuration object registered under the ``"UNetMSS"`` model type.

    Carries the same three hyperparameters as the plain UNet configuration
    but is a distinct class with its own ``model_type``. Any keyword
    argument not named below is passed to ``PretrainedConfig.__init__``.

    Args:
        in_ch: Presumably the number of input channels (default ``1``).
        out_ch: Presumably the number of output channels (default ``1``).
        init_features: Presumably the feature width of the first network
            stage (default ``64``) -- TODO confirm against the model code.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "UNetMSS"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ):
        # Store the model-specific fields first, then let the base class
        # process the remaining generic configuration options.
        self.in_ch, self.out_ch, self.init_features = in_ch, out_ch, init_features
        super().__init__(**kwargs)