Upload FlaxTransformerLMForCausalLM
Browse files
- flax_model.msgpack +1 -1
- modeling_transformerlm_flax.py +3 -0
flax_model.msgpack
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 524522413
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f43dc830c806b64d6a77027a61d16bd2fcbe896c799d5dbba0a81b9e7f26fc8b
|
| 3 |
size 524522413
|
modeling_transformerlm_flax.py
CHANGED
|
@@ -404,6 +404,9 @@ class FlaxTransformerLMPreTrainedModel(FlaxPreTrainedModel):
|
|
| 404 |
last_logits, last_cache = last
|
| 405 |
lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
|
| 406 |
|
|
|
|
|
|
|
|
|
|
| 407 |
if not return_dict:
|
| 408 |
outputs = (lm_logits,) + (last_cache,)
|
| 409 |
else:
|
|
|
|
| 404 |
last_logits, last_cache = last
|
| 405 |
lm_logits = jnp.reshape(all_logits, (1, seq_length, vcab_size))
|
| 406 |
|
| 407 |
+
if input_ids.shape[1] > 1:
|
| 408 |
+
lm_logits = lm_logits[:, 1:, :] # Ignore leading zeros in prompts
|
| 409 |
+
|
| 410 |
if not return_dict:
|
| 411 |
outputs = (lm_logits,) + (last_cache,)
|
| 412 |
else:
|