0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-21 00:52:43 -05:00

fix(ml): handle missing context_length field (#6695)

* handle missing `context_length` field

* specify list type
This commit is contained in:
Mert 2024-01-27 19:50:50 -05:00 committed by GitHub
parent e2ac019f51
commit 2249f7d42a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 12 additions and 9 deletions

View file

@ -70,6 +70,8 @@ LOG_LEVELS: dict[str, int] = {
settings = Settings()
log_settings = LogSettings()
LOG_LEVEL = LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO)
class CustomRichHandler(RichHandler):
def __init__(self) -> None:
@ -81,6 +83,7 @@ class CustomRichHandler(RichHandler):
console=console,
rich_tracebacks=True,
tracebacks_suppress=[*self.excluded, concurrent.futures],
tracebacks_show_locals=LOG_LEVEL == logging.DEBUG,
)
# hack to exclude certain modules from rich tracebacks
@ -96,7 +99,7 @@ class CustomRichHandler(RichHandler):
log = logging.getLogger("ml.log")
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
log.setLevel(LOG_LEVEL)
# patches this issue https://github.com/encode/uvicorn/discussions/1803

View file

@ -144,11 +144,11 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
def _load(self) -> None:
super()._load()
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
context_length: int = text_cfg.get("context_length", 77)
pad_token: int = self.tokenizer_cfg["pad_token"]
context_length = self.model_cfg["text_cfg"]["context_length"]
pad_token = self.tokenizer_cfg["pad_token"]
size = self.preprocess_cfg["size"]
size: list[int] | int = self.preprocess_cfg["size"]
self.size = size[0] if isinstance(size, list) else size
self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
@ -157,7 +157,7 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
pad_id = self.tokenizer.token_to_id(pad_token)
pad_id: int = self.tokenizer.token_to_id(pad_token)
self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
self.tokenizer.enable_truncation(max_length=context_length)
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")

View file

@ -29,6 +29,9 @@ _OPENCLIP_MODELS = {
"ViT-L-14-quickgelu__dfn2b",
"ViT-H-14-quickgelu__dfn5b",
"ViT-H-14-378-quickgelu__dfn5b",
"XLM-Roberta-Large-ViT-H-14__frozen_laion5b_s13b_b90k",
"nllb-clip-base-siglip__v1",
"nllb-clip-large-siglip__v1",
}
@ -37,9 +40,6 @@ _MCLIP_MODELS = {
"XLM-Roberta-Large-Vit-B-32",
"XLM-Roberta-Large-Vit-B-16Plus",
"XLM-Roberta-Large-Vit-L-14",
"XLM-Roberta-Large-ViT-H-14__frozen_laion5b_s13b_b90k",
"nllb-clip-base-siglip__v1",
"nllb-clip-large-siglip__v1",
}