saujasv commited on
Commit
9154370
·
1 Parent(s): 2a73d1c

fix deps and make cuda optional

Browse files
Files changed (2) hide show
  1. handler.py +3 -6
  2. requirements.txt +1 -1
handler.py CHANGED
@@ -4,7 +4,7 @@ from greenery.parse import NoMatch
4
  from listener import Listener, ListenerOutput
5
  import time
6
  import json
7
- import logging
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
@@ -14,14 +14,11 @@ class EndpointHandler:
14
  "top_p": 0.9,
15
  "num_return_sequences": 500,
16
  "num_beams": 1
17
- }, device="cuda")
18
- logging.info("Loaded model with custom endpoint handler")
19
 
20
  def __call__(self, data):
21
  # get inputs
22
- inp = json.loads(data.pop("inputs", None))
23
- logging.info(str(data))
24
- logging.info(str(data["inputs"]))
25
  spec = inp["spec"]
26
  true_program = inp["true_program"]
27
 
 
4
  from listener import Listener, ListenerOutput
5
  import time
6
  import json
7
+ import torch
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
 
14
  "top_p": 0.9,
15
  "num_return_sequences": 500,
16
  "num_beams": 1
17
+ }, device="cuda" if torch.cuda.is_available() else "cpu")
 
18
 
19
  def __call__(self, data):
20
  # get inputs
21
+ inp = data.pop("inputs", None)
 
 
22
  spec = inp["spec"]
23
  true_program = inp["true_program"]
24
 
requirements.txt CHANGED
@@ -21,6 +21,6 @@ sympy==1.12
21
  tokenizers==0.13.3
22
  torch==2.0.1
23
  tqdm==4.66.1
24
- transformers==4.33.0
25
  typing_extensions==4.7.1
26
  urllib3==2.0.4
 
21
  tokenizers==0.13.3
22
  torch==2.0.1
23
  tqdm==4.66.1
24
+ transformers==4.28.1
25
  typing_extensions==4.7.1
26
  urllib3==2.0.4