| from transformers import PretrainedConfig | |
| from typing import List | |
class ReconResNetConfig(PretrainedConfig):
    """Configuration container for the ``ReconResNet`` model type.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name; any additional keyword arguments are handed to
    ``PretrainedConfig.__init__``.

    NOTE(review): the model code that consumes these options is not part of
    this file, so the descriptions below are inferred from the option names
    and defaults only -- confirm against the model implementation.

    Args:
        in_channels: presumably the number of input channels.
        out_channels: presumably the number of output channels.
        res_blocks: presumably the number of residual blocks.
        starting_nfeatures: presumably the feature count of the first layer.
        updown_blocks: presumably the number of down-/up-sampling stages.
        is_relu_leaky: presumably selects a leaky ReLU over a plain ReLU.
        do_batchnorm: presumably enables batch normalisation.
        res_drop_prob: presumably the dropout probability in residual blocks.
        is_replicatepad: padding-mode selector; defaults to the integer ``0``
            (exact meaning of the values to be confirmed).
        out_act: name of the output activation; defaults to ``"sigmoid"``.
        forwardV: selector for a forward-pass variant; defaults to ``0``
            (semantics to be confirmed).
        upinterp_algo: name of the upsampling method; defaults to
            ``"convtrans"``.
        post_interp_convtrans: flag related to upsampling (semantics to be
            confirmed).
        is3D: presumably switches between 2D and 3D operation.
        **kwargs: forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    model_type = "ReconResNet"

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        res_blocks: int = 14,
        starting_nfeatures: int = 64,
        updown_blocks: int = 2,
        is_relu_leaky: bool = True,
        do_batchnorm: bool = False,
        res_drop_prob: float = 0.2,
        is_replicatepad: int = 0,
        out_act: str = "sigmoid",
        forwardV: int = 0,
        upinterp_algo: str = "convtrans",
        post_interp_convtrans: bool = False,
        is3D: bool = False,
        **kwargs,
    ) -> None:
        # Channel counts.
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Size-related options.
        self.res_blocks = res_blocks
        self.starting_nfeatures = starting_nfeatures
        self.updown_blocks = updown_blocks

        # Per-layer behaviour switches.
        self.is_relu_leaky = is_relu_leaky
        self.do_batchnorm = do_batchnorm
        self.res_drop_prob = res_drop_prob
        self.is_replicatepad = is_replicatepad

        # Output activation and forward-variant selectors.
        self.out_act = out_act
        self.forwardV = forwardV

        # Upsampling options.
        self.upinterp_algo = upinterp_algo
        self.post_interp_convtrans = post_interp_convtrans

        # Dimensionality switch.
        self.is3D = is3D

        # The base class receives every keyword argument not consumed above;
        # it is invoked only after all model-specific attributes are set.
        super().__init__(**kwargs)