|
package startup |
|
|
|
import ( |
|
"errors" |
|
"fmt" |
|
"os" |
|
"path/filepath" |
|
"strings" |
|
|
|
"github.com/mudler/LocalAI/core/config" |
|
"github.com/mudler/LocalAI/core/gallery" |
|
"github.com/mudler/LocalAI/embedded" |
|
"github.com/mudler/LocalAI/pkg/downloader" |
|
"github.com/mudler/LocalAI/pkg/utils" |
|
"github.com/rs/zerolog/log" |
|
) |
|
|
|
|
|
|
|
|
|
func InstallModels(galleries []config.Gallery, modelLibraryURL string, modelPath string, enforceScan bool, downloadStatus func(string, string, string, float64), models ...string) error { |
|
|
|
var err error |
|
|
|
lib, _ := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath) |
|
|
|
for _, url := range models { |
|
|
|
|
|
if modelLibraryURL != "" { |
|
if lib[url] != "" { |
|
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url]) |
|
url = lib[url] |
|
} |
|
} |
|
|
|
url = embedded.ModelShortURL(url) |
|
uri := downloader.URI(url) |
|
|
|
switch { |
|
case embedded.ExistsInModelsLibrary(url): |
|
modelYAML, e := embedded.ResolveContent(url) |
|
|
|
if e != nil { |
|
log.Error().Err(e).Msg("error resolving model content") |
|
err = errors.Join(err, e) |
|
continue |
|
} |
|
|
|
log.Debug().Msgf("[startup] resolved embedded model: %s", url) |
|
md5Name := utils.MD5(url) |
|
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" |
|
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); err != nil { |
|
log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") |
|
err = errors.Join(err, e) |
|
} |
|
case uri.LooksLikeOCI(): |
|
log.Debug().Msgf("[startup] resolved OCI model to download: %s", url) |
|
|
|
|
|
ociName := strings.TrimPrefix(url, downloader.OCIPrefix) |
|
ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix) |
|
ociName = strings.ReplaceAll(ociName, "/", "__") |
|
ociName = strings.ReplaceAll(ociName, ":", "__") |
|
|
|
|
|
if _, e := os.Stat(filepath.Join(modelPath, ociName)); errors.Is(e, os.ErrNotExist) { |
|
modelDefinitionFilePath := filepath.Join(modelPath, ociName) |
|
e := uri.DownloadFile(modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { |
|
utils.DisplayDownloadFunction(fileName, current, total, percent) |
|
}) |
|
if e != nil { |
|
log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") |
|
err = errors.Join(err, e) |
|
} |
|
} |
|
|
|
log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName) |
|
case uri.LooksLikeURL(): |
|
log.Debug().Msgf("[startup] downloading %s", url) |
|
|
|
|
|
fileName, e := uri.FilenameFromUrl() |
|
if e != nil { |
|
log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL") |
|
err = errors.Join(err, e) |
|
continue |
|
} |
|
|
|
modelPath := filepath.Join(modelPath, fileName) |
|
|
|
if e := utils.VerifyPath(fileName, modelPath); e != nil { |
|
log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path") |
|
err = errors.Join(err, e) |
|
continue |
|
} |
|
|
|
|
|
if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) { |
|
e := uri.DownloadFile(modelPath, "", 0, 0, func(fileName, current, total string, percent float64) { |
|
utils.DisplayDownloadFunction(fileName, current, total, percent) |
|
}) |
|
if e != nil { |
|
log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model") |
|
err = errors.Join(err, e) |
|
} |
|
} |
|
default: |
|
if _, e := os.Stat(url); e == nil { |
|
log.Debug().Msgf("[startup] resolved local model: %s", url) |
|
|
|
md5Name := utils.MD5(url) |
|
|
|
modelYAML, e := os.ReadFile(url) |
|
if e != nil { |
|
log.Error().Err(e).Str("filepath", url).Msg("error reading model definition") |
|
err = errors.Join(err, e) |
|
continue |
|
} |
|
|
|
modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" |
|
if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil { |
|
log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") |
|
err = errors.Join(err, e) |
|
} |
|
} else { |
|
|
|
e, found := installModel(galleries, url, modelPath, downloadStatus, enforceScan) |
|
if e != nil && found { |
|
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url) |
|
err = errors.Join(err, e) |
|
} else if !found { |
|
log.Warn().Msgf("[startup] failed resolving model '%s'", url) |
|
err = errors.Join(err, fmt.Errorf("failed resolving model '%s'", url)) |
|
} |
|
} |
|
} |
|
} |
|
return err |
|
} |
|
|
|
func installModel(galleries []config.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64), enforceScan bool) (error, bool) { |
|
models, err := gallery.AvailableGalleryModels(galleries, modelPath) |
|
if err != nil { |
|
return err, false |
|
} |
|
|
|
model := gallery.FindModel(models, modelName, modelPath) |
|
if model == nil { |
|
return err, false |
|
} |
|
|
|
if downloadStatus == nil { |
|
downloadStatus = utils.DisplayDownloadFunction |
|
} |
|
|
|
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") |
|
err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus, enforceScan) |
|
if err != nil { |
|
return err, true |
|
} |
|
|
|
return nil, true |
|
} |
|
|