|
package backend |
|
|
|
import ( |
|
"fmt" |
|
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
|
"github.com/mudler/LocalAI/pkg/grpc" |
|
model "github.com/mudler/LocalAI/pkg/model" |
|
) |
|
|
|
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { |
|
modelFile := backendConfig.Model |
|
|
|
grpcOpts := gRPCModelOpts(backendConfig) |
|
|
|
var inferenceModel interface{} |
|
var err error |
|
|
|
opts := modelOpts(backendConfig, appConfig, []model.Option{ |
|
model.WithLoadGRPCLoadModelOpts(grpcOpts), |
|
model.WithThreads(uint32(*backendConfig.Threads)), |
|
model.WithAssetDir(appConfig.AssetsDestination), |
|
model.WithModel(modelFile), |
|
model.WithContext(appConfig.Context), |
|
}) |
|
|
|
if backendConfig.Backend == "" { |
|
inferenceModel, err = loader.GreedyLoader(opts...) |
|
} else { |
|
opts = append(opts, model.WithBackendString(backendConfig.Backend)) |
|
inferenceModel, err = loader.BackendLoader(opts...) |
|
} |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
var fn func() ([]float32, error) |
|
switch model := inferenceModel.(type) { |
|
case grpc.Backend: |
|
fn = func() ([]float32, error) { |
|
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath) |
|
if len(tokens) > 0 { |
|
embeds := []int32{} |
|
|
|
for _, t := range tokens { |
|
embeds = append(embeds, int32(t)) |
|
} |
|
predictOptions.EmbeddingTokens = embeds |
|
|
|
res, err := model.Embeddings(appConfig.Context, predictOptions) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return res.Embeddings, nil |
|
} |
|
predictOptions.Embeddings = s |
|
|
|
res, err := model.Embeddings(appConfig.Context, predictOptions) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return res.Embeddings, nil |
|
} |
|
default: |
|
fn = func() ([]float32, error) { |
|
return nil, fmt.Errorf("embeddings not supported by the backend") |
|
} |
|
} |
|
|
|
return func() ([]float32, error) { |
|
embeds, err := fn() |
|
if err != nil { |
|
return embeds, err |
|
} |
|
|
|
for i := len(embeds) - 1; i >= 0; i-- { |
|
if embeds[i] == 0.0 { |
|
embeds = embeds[:i] |
|
} else { |
|
break |
|
} |
|
} |
|
return embeds, nil |
|
}, nil |
|
} |
|
|