|
package main |
|
|
|
|
|
|
|
import ( |
|
bert "github.com/go-skynet/go-bert.cpp" |
|
|
|
"github.com/mudler/LocalAI/pkg/grpc/base" |
|
pb "github.com/mudler/LocalAI/pkg/grpc/proto" |
|
) |
|
|
|
// Embeddings is a LocalAI backend that produces embedding vectors from a
// BERT model loaded through the go-bert.cpp binding.
type Embeddings struct {
	// SingleThread is embedded from LocalAI's grpc base package; presumably
	// it supplies default backend methods and single-threaded request
	// handling — NOTE(review): confirm against pkg/grpc/base.
	base.SingleThread

	// bert is the model handle assigned by Load and used by Embeddings.
	bert *bert.Bert
}
|
|
|
func (llm *Embeddings) Load(opts *pb.ModelOptions) error { |
|
model, err := bert.New(opts.ModelFile) |
|
llm.bert = model |
|
return err |
|
} |
|
|
|
func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) { |
|
|
|
if len(opts.EmbeddingTokens) > 0 { |
|
tokens := []int{} |
|
for _, t := range opts.EmbeddingTokens { |
|
tokens = append(tokens, int(t)) |
|
} |
|
return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads))) |
|
} |
|
|
|
return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads))) |
|
} |
|
|