Spaces:
Sleeping
Sleeping
| import time | |
| import copy | |
| import torch | |
| from data_loaders import load_dataset, get_dataset_sizes, get_dataloaders | |
def train_model(model, criterion, optimizer, scheduler, num_epochs, batch_size, device):
    """Run a train/validation loop and return the model with its best weights.

    The best weights are those of the epoch with the highest validation
    accuracy; they are loaded back into ``model`` before returning.

    Args:
        model: torch.nn.Module to train (modified in place).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: torch.optim optimizer over ``model``'s parameters.
        scheduler: LR scheduler; stepped once per epoch after the train phase.
        num_epochs: number of epochs to run.
        batch_size: batch size forwarded to ``get_dataloaders``.
        device: torch device that inputs and labels are moved to.

    Returns:
        ``model`` with the weights of the best-validation-accuracy epoch loaded.
    """
    since = time.time()

    # Load data
    data_set = load_dataset()
    dataset_sizes = get_dataset_sizes(data_set)
    dataloaders = get_dataloaders(data_set, batch_size)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # +inf guarantees the first validation epoch becomes the running best;
    # the previous sentinel (10000.0) silently broke tracking for losses
    # that exceeded it.
    best_loss = float("inf")
    best_acc_train = 0.0
    best_loss_train = float("inf")

    print("Training started:")
    time_elapsed = time.time() - since
    print("Data Loading Completed in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))

    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ["train", "validation"]:
            if phase == "train":
                # Set model to training mode
                model.train()
            else:
                # Set model to evaluate mode
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            n_batches = dataset_sizes[phase] // batch_size
            for it, (inputs, labels) in enumerate(dataloaders[phase]):
                since_batch = time.time()
                # The final batch may be smaller than batch_size.
                batch_size_ = len(inputs)
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Track/compute gradient and make an optimization step only when training
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == "train":
                        # zero_grad is only needed when gradients are produced,
                        # so it lives inside the train-only branch.
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by the actual batch size so the epoch
                # average stays exact even with a short final batch.
                running_loss += loss.item() * batch_size_
                running_corrects += torch.sum(preds == labels.data).item()

                # Print iteration results
                print(
                    "Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}".format(
                        phase,
                        epoch + 1,
                        num_epochs,
                        it + 1,
                        n_batches + 1,
                        time.time() - since_batch,
                    ),
                    end="\r",
                    flush=True,
                )

            # Print epoch results
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(
                "Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f} ".format(
                    "train" if phase == "train" else "validation ",
                    epoch + 1,
                    num_epochs,
                    epoch_loss,
                    epoch_acc,
                )
            )

            # Snapshot weights whenever validation accuracy improves.
            if phase == "validation" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == "validation" and epoch_loss < best_loss:
                best_loss = epoch_loss
            if phase == "train" and epoch_acc > best_acc_train:
                best_acc_train = epoch_acc
            if phase == "train" and epoch_loss < best_loss_train:
                best_loss_train = epoch_loss

            # Update learning rate once per epoch, right after the train phase
            if phase == "train":
                scheduler.step()

    # Restore the best validation weights before returning
    model.load_state_dict(best_model_wts)
    time_elapsed = time.time() - since
    print("Training Completed in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best test loss: {:.4f} | Best test accuracy: {:.4f}".format(best_loss, best_acc))
    return model