Update app_flash.py
Browse files- app_flash.py +2 -2
app_flash.py
CHANGED
|
@@ -116,7 +116,7 @@ def push_flashpack_model_to_hf(model, hf_repo: str):
|
|
| 116 |
# ============================================================
|
| 117 |
def train_flashpack_model(
|
| 118 |
dataset_name: str = "rahul7star/prompt-enhancer-dataset",
|
| 119 |
-
max_encode: int =
|
| 120 |
hidden_dim: int = 1024,
|
| 121 |
push_to_hub: bool = True,
|
| 122 |
hf_repo: str = "rahul7star/FlashPack"
|
|
@@ -152,7 +152,7 @@ def train_flashpack_model(
|
|
| 152 |
# Loss & optimizer
|
| 153 |
criterion = nn.CosineSimilarity(dim=1)
|
| 154 |
optimizer = optim.Adam(model.parameters(), lr=1e-3)
|
| 155 |
-
max_epochs =
|
| 156 |
batch_size = 32
|
| 157 |
n = short_embeddings.shape[0]
|
| 158 |
|
|
|
|
| 116 |
# ============================================================
|
| 117 |
def train_flashpack_model(
|
| 118 |
dataset_name: str = "rahul7star/prompt-enhancer-dataset",
|
| 119 |
+
max_encode: int = 1000,
|
| 120 |
hidden_dim: int = 1024,
|
| 121 |
push_to_hub: bool = True,
|
| 122 |
hf_repo: str = "rahul7star/FlashPack"
|
|
|
|
| 152 |
# Loss & optimizer
|
| 153 |
criterion = nn.CosineSimilarity(dim=1)
|
| 154 |
optimizer = optim.Adam(model.parameters(), lr=1e-3)
|
| 155 |
+
max_epochs = 50
|
| 156 |
batch_size = 32
|
| 157 |
n = short_embeddings.shape[0]
|
| 158 |
|