-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[refactor] move hf model download logic into separate python file; im…
…plement auto-download of safetensors in preference to pickles.
- Loading branch information
Showing
7 changed files
with
105 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import argparse | ||
import os | ||
from huggingface_hub import snapshot_download | ||
|
||
def check_safetensors_present(model_id, revision):
    """Return True if the model repo contains at least one .safetensors file.

    The file listing for ``model_id`` at ``revision`` is fetched through
    ``HfApi.list_repo_files``. No token is passed explicitly here.
    """
    from huggingface_hub import HfApi

    repo_files = HfApi().list_repo_files(repo_id=model_id, revision=revision)
    # Compare on the extension reported by splitext rather than a plain
    # suffix check, so a file named exactly ".safetensors" is not counted.
    return any(
        os.path.splitext(name)[1] == '.safetensors' for name in repo_files
    )
|
||
if __name__ == '__main__':
    # Command-line entry point: download a snapshot of a Hugging Face model
    # repository and print the local path of the downloaded snapshot.
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_id', type=str, default=None)
    parser.add_argument('--revision', type=str, default=None)
    parser.add_argument('--allow_patterns', type=str, default=None)
    parser.add_argument('--cache_dir', type=str, default=None)
    args = parser.parse_args()

    repo_id = args.repo_id
    if not repo_id:
        # Explicit validation instead of `assert`, which is stripped when
        # Python runs with -O. parser.error prints usage to stderr and exits
        # with a non-zero status.
        parser.error("Please provide a repo_id")

    # Empty strings (e.g. from unset environment variables forwarded by a
    # shell wrapper) are treated the same as a missing option.
    revision = args.revision or "main"
    cache_dir = args.cache_dir or None
    allow_patterns = args.allow_patterns

    if not allow_patterns:
        # No explicit patterns: fetch config, tokenizer and model weights,
        # choosing safetensors weights when the repo has them and falling
        # back to pickle-based weight files otherwise.
        has_safetensors = check_safetensors_present(repo_id, revision)
        if has_safetensors:
            allow_patterns = "*.json,*.safetensors,*.model"
        else:
            allow_patterns = "*.json,*.bin,*.pth,*.model"

    # Split the comma-separated list, tolerating whitespace around entries
    # and ignoring empty entries such as a trailing comma.
    patterns = [p.strip() for p in allow_patterns.split(",") if p.strip()]

    path = snapshot_download(repo_id,
                             revision=revision,
                             cache_dir=cache_dir,
                             allow_patterns=patterns)
    # Print the download path so a caller can capture it.
    print(path)
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#!/bin/bash

# Launcher for the 'scalellm' server binary.
#
# When HF_MODEL_ID is set, the model is first downloaded from the Hugging Face
# hub via download_hf_models.py, and the resulting local path is passed to
# 'scalellm' as --model_path together with --model_id. Any arguments given to
# this script are forwarded to 'scalellm' unchanged.
#
# Environment variables read:
#   HF_MODEL_ID, HF_MODEL_REVISION, HF_MODEL_CACHE_DIR, HF_MODEL_ALLOW_PATTERN

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"

# Arguments to pass to the 'scalellm' command. An array is used so that values
# containing whitespace or glob characters are passed through intact.
ARGS=()

# Check if HF_MODEL_ID is defined; if so, download the model from the Hugging Face hub
if [ -n "$HF_MODEL_ID" ]; then
  echo "Downloading model from the Hugging Face hub for model id: $HF_MODEL_ID and revision: $HF_MODEL_REVISION"

  # The download script prints the local snapshot path on stdout; capture it.
  # Exit if the download fails.
  if ! MODEL_PATH=$(python3 "${SCRIPT_DIR}/download_hf_models.py" --repo_id "$HF_MODEL_ID" --revision "$HF_MODEL_REVISION" --cache_dir "$HF_MODEL_CACHE_DIR" --allow_patterns "$HF_MODEL_ALLOW_PATTERN"); then
    echo "Error downloading model from the Hugging Face hub for model id: $HF_MODEL_ID and revision: $HF_MODEL_REVISION" >&2
    exit 1
  fi
  ARGS+=(--model_path "$MODEL_PATH" --model_id "$HF_MODEL_ID")
fi

# Run the 'scalellm' with the specified arguments
"$SCRIPT_DIR/../build/src/server/scalellm" "${ARGS[@]}" "$@"
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters