Deepinfra / inference.py
API-Handler's picture
Upload 5 files
4e2263c verified
raw
history blame
4.94 kB
import json
import time
from typing import Dict, Generator, Iterable, Union

import requests
class ChatCompletionTester:
    """Smoke-test an OpenAI-style ``/chat/completions`` HTTP endpoint.

    Sends one non-streaming and one streaming request, both built by
    :meth:`create_test_payload`, and prints the outcome of each to stdout.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: Union[float, None] = 120.0,
    ):
        """Configure the server under test.

        Args:
            base_url: Root URL of the server. A trailing ``/`` is dropped so
                the endpoint URL never contains ``//``.
            timeout: Seconds passed to ``requests`` as the connect/read timeout
                of every request; ``None`` disables the timeout.
        """
        self.base_url = base_url.rstrip("/")
        self.endpoint = f"{self.base_url}/chat/completions"
        self.timeout = timeout

    def create_test_payload(self, stream: bool = False) -> Dict:
        """Create a sample payload for testing.

        Args:
            stream: Value for the request's ``"stream"`` flag.

        Returns:
            A new chat-completion request body (a fresh dict on every call).
        """
        return {
            "model": "mistralai/Mixtral-8x22B-Instruct-v0.1",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            "temperature": 0.7,
            "max_tokens": 4096,
            "stream": stream,
        }

    @staticmethod
    def _iter_stream_content(lines: Iterable[str]) -> Generator[str, None, None]:
        """Yield the non-empty ``choices[0].delta.content`` strings of an SSE stream.

        Args:
            lines: Decoded text lines of a server-sent-event body.

        Yields:
            Each non-empty text fragment, in order. Iteration stops at the
            ``data: [DONE]`` sentinel. Lines that are not ``data:`` fields, are
            not JSON objects, or carry no text (e.g. an empty ``choices`` list
            or a ``null`` content) are skipped.
        """
        for line in lines:
            if not line or not line.startswith("data:"):
                continue
            # SSE permits an optional space after "data:"; strip() also drops a stray "\r".
            payload = line[len("data:"):].strip()
            # The sentinel is not valid JSON, so it must be recognised before json.loads.
            if payload == "[DONE]":
                return
            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            choices = data.get("choices")
            # Some servers send chunks with an empty "choices" list (e.g. usage-only chunks).
            if not isinstance(choices, list) or not choices:
                continue
            first = choices[0]
            delta = first.get("delta") if isinstance(first, dict) else None
            content = delta.get("content") if isinstance(delta, dict) else None
            if isinstance(content, str) and content:
                yield content

    def test_non_streaming(self) -> Union[Dict, None]:
        """Send a non-streaming request and print the reply text.

        Returns:
            The decoded JSON body when the server answers HTTP 200 and
            ``choices[0].message.content`` is present; ``None`` otherwise
            (the reason is printed).
        """
        print("\n=== Testing Non-Streaming Response ===")
        try:
            payload = self.create_test_payload(stream=False)
            print("Sending request...")
            response = requests.post(
                self.endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=self.timeout,
            )
            if response.status_code != 200:
                print(f"Error: Status code {response.status_code}")
                print(f"Response: {response.text}")
                return None
            result = response.json()
            content = result['choices'][0]['message']['content']
            print("\nResponse received successfully!")
            print(f"Content: {content}")
            return result
        except Exception as e:
            # Test boundary: any failure (network, bad JSON, missing keys) is reported as a failed test.
            print(f"Error during non-streaming test: {e}")
            return None

    def test_streaming(self) -> Union[str, None]:
        """Send a streaming request and echo the fragments as they arrive.

        Returns:
            The concatenated streamed text on HTTP 200 (possibly ``""``);
            ``None`` on a non-200 status or any error (the reason is printed).
        """
        print("\n=== Testing Streaming Response ===")
        try:
            payload = self.create_test_payload(stream=True)
            print("Sending request...")
            # ``with`` guarantees the streamed connection is released even if parsing fails.
            with requests.post(
                self.endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                stream=True,
                timeout=self.timeout,
            ) as response:
                if response.status_code != 200:
                    print(f"Error: Status code {response.status_code}")
                    print(f"Response: {response.text}")
                    return None
                # Event streams are UTF-8 by spec; without an explicit charset requests would
                # otherwise decode text/* bodies as ISO-8859-1 and garble non-ASCII tokens.
                response.encoding = "utf-8"
                print("\nReceiving streaming response:")
                pieces = []
                for content in self._iter_stream_content(
                    response.iter_lines(decode_unicode=True)
                ):
                    print(content, end="", flush=True)
                    pieces.append(content)
            print("\n\nStreaming completed!")
            return "".join(pieces)
        except Exception as e:
            # Test boundary: any failure is reported as a failed test.
            print(f"Error during streaming test: {e}")
            return None

    def run_all_tests(self):
        """Run both streaming and non-streaming tests.

        Probes ``base_url`` first and returns early if the server cannot be
        reached. A streaming run that yields no text counts as a failure.
        """
        print("Starting API endpoint tests...")
        # Test server connectivity; any HTTP status counts as "reachable".
        try:
            requests.get(self.base_url, timeout=self.timeout)
            print("✓ Server is accessible")
        except requests.exceptions.RequestException:
            print("✗ Server is not accessible. Please ensure the FastAPI server is running.")
            return
        # perf_counter is monotonic, unlike time.time(), so the duration cannot go negative.
        start_time = time.perf_counter()
        non_streaming_result = self.test_non_streaming()
        if non_streaming_result:
            print("✓ Non-streaming test passed")
        else:
            print("✗ Non-streaming test failed")
        streaming_result = self.test_streaming()
        if streaming_result:
            print("✓ Streaming test passed")
        else:
            print("✗ Streaming test failed")
        elapsed = time.perf_counter() - start_time
        print(f"\nAll tests completed in {elapsed:.2f} seconds")
def main():
    """Smoke-test the default local server by running every check once."""
    ChatCompletionTester().run_all_tests()
if __name__ == "__main__":
    # Only kick off the smoke tests when executed as a script, never on import.
    main()