| from transformers import PreTrainedModel | |
| from .ReconResNetBase import ReconResNetBase | |
| from .ReconResNetConfig import ReconResNetConfig | |
class ReconResNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ReconResNetBase``.

    The constructor copies a fixed set of hyper-parameters from a
    ``ReconResNetConfig`` onto ``ReconResNetBase`` (each passed as a keyword
    argument of the same name). ``forward`` delegates directly to that
    wrapped network.
    """

    config_class = ReconResNetConfig

    def __init__(self, config):
        """Build the wrapped ``ReconResNetBase`` from ``config``.

        Args:
            config: A ``ReconResNetConfig`` (or any object exposing the
                attributes listed below). A missing attribute raises
                ``AttributeError``.
        """
        # PreTrainedModel must be initialised before submodules are attached.
        super().__init__(config)

        # Config attribute names forwarded verbatim (same keyword name, same
        # order) to ReconResNetBase.
        forwarded_fields = (
            "in_channels",
            "out_channels",
            "res_blocks",
            "starting_nfeatures",
            "updown_blocks",
            "is_relu_leaky",
            "do_batchnorm",
            "res_drop_prob",
            "is_replicatepad",
            "out_act",
            "forwardV",
            "upinterp_algo",
            "post_interp_convtrans",
            "is3D",
        )
        base_kwargs = {field: getattr(config, field) for field in forwarded_fields}

        # NOTE(review): the attribute name ``model`` presumably fixes the
        # parameter-key prefix in saved checkpoints, so keep it stable.
        self.model = ReconResNetBase(**base_kwargs)

    def forward(self, x):
        """Run the wrapped ``ReconResNetBase`` on ``x`` and return its output as-is.

        The expected tensor layout of ``x`` is whatever ``ReconResNetBase``
        accepts (presumably channel-first image/volume batches; TODO confirm).
        """
        prediction = self.model(x)
        return prediction