|
package grpc |
|
|
|
import ( |
|
"context" |
|
"fmt" |
|
"log" |
|
"net" |
|
|
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto" |
|
"google.golang.org/grpc" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type server struct { |
|
pb.UnimplementedBackendServer |
|
llm LLM |
|
} |
|
|
|
func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) { |
|
return newReply("OK"), nil |
|
} |
|
|
|
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
embeds, err := s.llm.Embeddings(in) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return &pb.EmbeddingResult{Embeddings: embeds}, nil |
|
} |
|
|
|
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
err := s.llm.Load(in) |
|
if err != nil { |
|
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err |
|
} |
|
return &pb.Result{Message: "Loading succeeded", Success: true}, nil |
|
} |
|
|
|
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
result, err := s.llm.Predict(in) |
|
return newReply(result), err |
|
} |
|
|
|
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
err := s.llm.GenerateImage(in) |
|
if err != nil { |
|
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err |
|
} |
|
return &pb.Result{Message: "Image generated", Success: true}, nil |
|
} |
|
|
|
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
err := s.llm.TTS(in) |
|
if err != nil { |
|
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err |
|
} |
|
return &pb.Result{Message: "Audio generated", Success: true}, nil |
|
} |
|
|
|
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
result, err := s.llm.AudioTranscription(in) |
|
if err != nil { |
|
return nil, err |
|
} |
|
tresult := &pb.TranscriptResult{} |
|
for _, s := range result.Segments { |
|
tks := []int32{} |
|
for _, t := range s.Tokens { |
|
tks = append(tks, int32(t)) |
|
} |
|
tresult.Segments = append(tresult.Segments, |
|
&pb.TranscriptSegment{ |
|
Text: s.Text, |
|
Id: int32(s.Id), |
|
Start: int64(s.Start), |
|
End: int64(s.End), |
|
Tokens: tks, |
|
}) |
|
} |
|
|
|
tresult.Text = result.Text |
|
return tresult, nil |
|
} |
|
|
|
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
resultChan := make(chan string) |
|
|
|
done := make(chan bool) |
|
go func() { |
|
for result := range resultChan { |
|
stream.Send(newReply(result)) |
|
} |
|
done <- true |
|
}() |
|
|
|
err := s.llm.PredictStream(in, resultChan) |
|
<-done |
|
|
|
return err |
|
} |
|
|
|
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
res, err := s.llm.TokenizeString(in) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
castTokens := make([]int32, len(res.Tokens)) |
|
for i, v := range res.Tokens { |
|
castTokens[i] = int32(v) |
|
} |
|
|
|
return &pb.TokenizationResponse{ |
|
Length: int32(res.Length), |
|
Tokens: castTokens, |
|
}, err |
|
} |
|
|
|
func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) { |
|
res, err := s.llm.Status() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return &res, nil |
|
} |
|
|
|
func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
err := s.llm.StoresSet(in) |
|
if err != nil { |
|
return &pb.Result{Message: fmt.Sprintf("Error setting entry: %s", err.Error()), Success: false}, err |
|
} |
|
return &pb.Result{Message: "Set key", Success: true}, nil |
|
} |
|
|
|
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
err := s.llm.StoresDelete(in) |
|
if err != nil { |
|
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err |
|
} |
|
return &pb.Result{Message: "Deleted key", Success: true}, nil |
|
} |
|
|
|
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
res, err := s.llm.StoresGet(in) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return &res, nil |
|
} |
|
|
|
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) { |
|
if s.llm.Locking() { |
|
s.llm.Lock() |
|
defer s.llm.Unlock() |
|
} |
|
res, err := s.llm.StoresFind(in) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return &res, nil |
|
} |
|
|
|
func StartServer(address string, model LLM) error { |
|
lis, err := net.Listen("tcp", address) |
|
if err != nil { |
|
return err |
|
} |
|
s := grpc.NewServer() |
|
pb.RegisterBackendServer(s, &server{llm: model}) |
|
log.Printf("gRPC Server listening at %v", lis.Addr()) |
|
if err := s.Serve(lis); err != nil { |
|
return err |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func RunServer(address string, model LLM) (func() error, error) { |
|
lis, err := net.Listen("tcp", address) |
|
if err != nil { |
|
return nil, err |
|
} |
|
s := grpc.NewServer() |
|
pb.RegisterBackendServer(s, &server{llm: model}) |
|
log.Printf("gRPC Server listening at %v", lis.Addr()) |
|
if err = s.Serve(lis); err != nil { |
|
return func() error { |
|
return lis.Close() |
|
}, err |
|
} |
|
|
|
return func() error { |
|
s.GracefulStop() |
|
return nil |
|
}, nil |
|
} |
|
|