
Mistral7b custom inference with LMI not working: java.lang.IllegalStateException: Read chunk timeout. #2362

jeremite commented Sep 5, 2024

Description

I deployed mistral-7b v0.1 to a SageMaker endpoint with the `763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124` container. I customized the `handle` function to return an `Output()` object, but every request fails with `java.lang.IllegalStateException: Read chunk timeout.`
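For context, the endpoint is deployed roughly like this with the SageMaker Python SDK (the role ARN, S3 path to the packaged code, instance type, and endpoint name below are placeholders, and the exact call I used may have differed slightly):

```python
# Rough deployment sketch; all names/paths are placeholders.
from sagemaker.model import Model

role = "arn:aws:iam::123456789012:role/MySageMakerRole"  # placeholder
image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"

model = Model(
    image_uri=image_uri,
    model_data="s3://my-bucket/mistral7b-sentiment/model.tar.gz",  # placeholder: tarball with inference.py + serving.properties
    role=role,
)
model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",        # placeholder; actual instance type may differ
    endpoint_name="mistral7b-sentiment",  # placeholder
)
```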

Expected Behavior

I want to get my customized result back, e.g. `{"predictions": {"positive": 0.8}}`.
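Concretely, given the `add_as_json` call in the handler below, the decoded response body would look like this (the wrapping list and the capitalized key/stringified score come from that call):

```python
import json

# Example of the body I expect back from the endpoint once it works
expected = [{"predictions": {"Positive": "0.80"}}]
print(json.dumps(expected))
```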

Error Message

```
[WARN ] InferenceRequestHandler - Chunk reading interrupted

java.lang.IllegalStateException: Read chunk timeout.
    at ai.djl.inference.streaming.ChunkedBytesSupplier.next(ChunkedBytesSupplier.java:79) ~[api-0.29.0.jar:?]
    at ai.djl.inference.streaming.ChunkedBytesSupplier.nextChunk(ChunkedBytesSupplier.java:93) ~[api-0.29.0.jar:?]
    at ai.djl.serving.http.InferenceRequestHandler.sendOutput(InferenceRequestHandler.java:414) ~[serving-0.29.0.jar:?]
    at ai.djl.serving.http.InferenceRequestHandler.lambda$runJob$5(InferenceRequestHandler.java:309) ~[serving-0.29.0.jar:?]
    at java.base/java.util.concurrent.CompletableFuture.uniWhenComplete(CompletableFuture.java:863) [?:?]
    at java.base/java.util.concurrent.CompletableFuture$UniWhenComplete.tryFire(CompletableFuture.java:841) [?:?]
    at java.base/java.util.concurrent.CompletableFuture$Completion.exec(CompletableFuture.java:483) [?:?]
    at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:373) [?:?]
    at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1182) [?:?]
    at java.base/java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1655) [?:?]
    at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1622) [?:?]
    at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:165) [?:?]

```

How to Reproduce?

My inference code (sorry, I put in many print statements for debugging):

```python
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from djl_python import Input, Output

# Lazily-initialized globals, populated on the first request
model = None
tokenizer = None


def get_model(model_name):
    # model_name = properties['model_id']
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def get_test_prompt(sample):
    """Build the sentiment-classification prompt for a chat transcript."""
    text_row = f"""You are given a chat transcript between customer service and a customer: {sample}.
Can you help analyze the customer's sentiment from the transcript? Please only answer with Positive or Negative! Answer:"""
    return text_row


def inference(text):
    text_data = get_test_prompt(text)
    input_ids = tokenizer.encode(text_data, return_tensors='pt')
    output = model.generate(input_ids, do_sample=True, top_k=2, max_new_tokens=1, pad_token_id=11)
    next_word_logits = model(input_ids).logits[:, -1, :]

    # Probabilities of the top-2 candidate next tokens
    probs = torch.softmax(next_word_logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, 2)

    mp = {"pos": "positive", "positive": "positive", "posit": "positive",
          "neg": "negative", "negative": "negative", "negat": "negative"}
    prob_dict = {}
    for i, (ix, p) in enumerate(zip(top_indices, top_probs)):
        next_word = tokenizer.decode(ix)
        next_word = next_word.split()
        for word, prob in zip(next_word, p):
            if word.lower() in mp:
                prob_dict[word.lower()] = prob.item()
    # prob_dict: {'Pos': 0.6447348594665527, 'Neg': 0.1729944795370102}

    # Normalize abbreviated keys (e.g. "pos") to "positive"/"negative"
    prob_dict = {mp[k]: v for k, v in prob_dict.items()}

    if "positive" in prob_dict and "negative" in prob_dict:
        sentiment = ["positive", prob_dict["positive"]] if prob_dict["positive"] > prob_dict["negative"] \
            else ["negative", prob_dict["negative"]]
    elif "positive" in prob_dict:
        sentiment = ["positive", prob_dict["positive"]]
    elif "negative" in prob_dict:
        sentiment = ["negative", prob_dict["negative"]]
    else:
        sentiment = ["no_sentiment", 0]

    output = Output()
    output.add_as_json([{'predictions': {sentiment[0].capitalize(): f"{sentiment[1]:.2f}"}}])
    print(f"out is: {output}")
    return output


def input_fn(input_data: Input):
    data = input_data.get_as_json()
    text = data["feedbackEvent_customerFeedbackText"]
    return text


def handle(inputs: Input) -> Output:
    global model, tokenizer

    if model is None:
        print(f"what we have: {inputs.get_properties()}")
        properties = inputs.get_properties()
        model_name = properties['model_id']
        print(f"model name: {model_name}")
        model, tokenizer = get_model(model_name)

    if inputs.is_empty():
        # Model server makes an empty call to warm up the model on startup
        return None

    print(f"we have inputs: {inputs} and {inputs.get_as_json()}")
    text = input_fn(inputs)
    print(f"we get text: {text}")
    o = Output()
    print(f"out is: {o}")
    o = o.add(inputs.get_as_json(), key="data")
    print(f"out2 is: {o}")
    return o  # Output().add(inputs.get_as_json(), key=)  # inference(text)
```

My serving.properties:

```
engine=Python
option.model_id=s3://tattletale-multistacking-all-prod-execution/cs/mistral-7b-1
option.tensor_parallel_degree=1
option.max_rolling_batch_size=16
option.rolling_batch=auto
option.enable_streaming=true
option.entryPoint=inference.py
```
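For completeness, below is a stripped-down, non-streaming handler sketch along the lines of what I am testing against the same serving.properties. The explicit `add_property("content-type", ...)` call and the placeholder payload are my assumptions, not something I have confirmed to be required:

```python
# Minimal non-streaming handler sketch (payload is a placeholder).
from djl_python import Input, Output


def handle(inputs: Input) -> Output:
    if inputs.is_empty():
        # Warm-up call from the model server
        return None

    data = inputs.get_as_json()
    out = Output()
    out.add_as_json([{"predictions": {"Positive": "0.80"}}])  # placeholder payload
    out.add_property("content-type", "application/json")      # assumption: may not be required
    return out
```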

Steps to reproduce

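For reference, a minimal sketch of how I invoke the endpoint (the endpoint name is a placeholder; the request key matches `input_fn` above):

```python
# Invocation sketch; endpoint name and sample text are placeholders.
import json
import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")
response = runtime.invoke_endpoint(
    EndpointName="mistral7b-sentiment",  # placeholder
    ContentType="application/json",
    Body=json.dumps({"feedbackEvent_customerFeedbackText": "The agent resolved my issue quickly, thank you!"}),
)
print(response["Body"].read().decode("utf-8"))
```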

What have you tried to solve it?

  1. I tried every suggestion I could find online, but the error persists.