Merge pull request #68 from PGScatalog/dev
v0.4.3 release
nebfield committed Dec 5, 2023
2 parents 5a25766 + ec0739c commit 6da7eb0
Showing 7 changed files with 206 additions and 114 deletions.
11 changes: 11 additions & 0 deletions pgscatalog_utils/ancestry/ancestry_analysis.py
@@ -26,14 +26,25 @@ def ancestry_analysis():
loc_related_ids=args.ref_related, nPCs=maxPCs)
loc_ref_psam = args.psam
reference_df = extract_ref_psam_cols(loc_ref_psam, args.d_ref, reference_df, keepcols=[args.ref_label])
assert reference_df.shape[0] > 100, "Error: too few reference panel samples. This is an arbitrary threshold " \
"for input QC; however, it is inadvisable to run this analysis with limited " \
"reference panel diversity as empirical percentiles are calculated."

loc_target_sscores = args.target_pcs
target_df = read_pcs(loc_pcs=loc_target_sscores, dataset=args.d_target, nPCs=maxPCs)
assert target_df.shape[0] >= 1, "Error: NO target samples found in PCs file."

# Load PGS data & merge with PCA data
pgs = read_pgs(args.scorefile, onlySUM=True)
scorecols = list(pgs.columns)

## There should be perfect target sample overlap
assert all([x in pgs.loc['reference'].index for x in reference_df.index.get_level_values(1)]), \
"Error: PGS data missing for reference samples with PCA data."
reference_df = pd.merge(reference_df, pgs, left_index=True, right_index=True)

assert all([x in pgs.loc[args.d_target].index for x in target_df.index.get_level_values(1)]), \
"Error: PGS data missing for target samples with PCA data."
target_df = pd.merge(target_df, pgs, left_index=True, right_index=True)
del pgs # clear raw PGS from memory
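Aside: a minimal sketch (not part of this diff) of the MultiIndex overlap check used in the assertions above, on made-up sample IDs; the real frames are indexed by ('sampleset', 'IID').

import pandas as pd

# toy aggregated PGS frame indexed like the real one: (sampleset, IID)
pgs = pd.DataFrame(
    {"PGS000001_SUM": [1.2, 0.7, -0.3]},
    index=pd.MultiIndex.from_tuples(
        [("reference", "ref1"), ("reference", "ref2"), ("target", "tgt1")],
        names=["sampleset", "IID"],
    ),
)
reference_iids = ["ref1", "ref2"]  # IIDs that also have PCA data

# .loc['reference'] drops the first index level, leaving IID as the index,
# so every reference sample with PCA data can be checked for a matching PGS row
assert all(iid in pgs.loc["reference"].index for iid in reference_iids)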

11 changes: 7 additions & 4 deletions pgscatalog_utils/ancestry/read.py
@@ -18,7 +18,7 @@ def read_pcs(loc_pcs: list[str],dataset: str, loc_related_ids=None, nPCs=None):

for i, path in enumerate(loc_pcs):
logger.debug("Reading PCA projection: {}".format(path))
df = pd.read_csv(path, sep='\t')
df = pd.read_csv(path, sep='\t', converters={"IID": str}, header=0)
df['sampleset'] = dataset
df.set_index(['sampleset', 'IID'], inplace=True)

@@ -46,13 +46,16 @@ def read_pcs(loc_pcs: list[str],dataset: str, loc_related_ids=None, nPCs=None):
IDs_related = [x.strip() for x in infile.readlines()]
proj.loc[proj.index.get_level_values(level=1).isin(IDs_related), 'Unrelated'] = False
else:
proj['Unrelated'] = np.nan
# if unrelated is all nan -> dtype is float64
# if unrelated is only true / false -> dtype is bool
# if unrelated contains None, dtype stays bool, and pd.concat warning disappears
proj['Unrelated'] = None

return proj
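Aside: the dtype comment above can be reproduced on a toy frame. Assigning the scalar np.nan creates a float64 column, while assigning None creates an object column; per the comment in the diff, the latter is what avoids the pd.concat dtype warning when reference (boolean) and target frames are combined later. A minimal sketch with made-up data:

import numpy as np
import pandas as pd

proj = pd.DataFrame({"PC1": [0.1, -0.2]})

proj["Unrelated"] = np.nan
print(proj["Unrelated"].dtype)  # float64: NaN forces a float column

proj = proj.drop(columns="Unrelated")
proj["Unrelated"] = None
print(proj["Unrelated"].dtype)  # object: None is stored as-is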


def extract_ref_psam_cols(loc_psam, dataset: str, df_target, keepcols=['SuperPop', 'Population']):
psam = pd.read_csv(loc_psam, sep='\t')
psam = pd.read_csv(loc_psam, sep='\t', header=0)

match (psam.columns[0]):
# handle case of #IID -> IID (happens when #FID is present)
@@ -76,7 +79,7 @@ def read_pgs(loc_aggscore, onlySUM: bool):
:return:
"""
logger.debug('Reading aggregated score data: {}'.format(loc_aggscore))
df = pd.read_csv(loc_aggscore, sep='\t', index_col=['sampleset', 'IID'])
df = pd.read_csv(loc_aggscore, sep='\t', index_col=['sampleset', 'IID'], converters={"IID": str}, header=0)
if onlySUM:
df = df[[x for x in df.columns if x.endswith('_SUM')]]
rn = [x.rstrip('_SUM') for x in df.columns]
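Aside: the converters={"IID": str} arguments added in this file presumably guard against pandas inferring purely numeric sample IDs as integers, which would break alignment against string IIDs read elsewhere. A small sketch with made-up data:

import io
import pandas as pd

tsv = "sampleset\tIID\tPGS000001_SUM\nmystudy\t000123\t1.5\n"

plain = pd.read_csv(io.StringIO(tsv), sep="\t")
as_str = pd.read_csv(io.StringIO(tsv), sep="\t", converters={"IID": str})

print(plain["IID"].tolist())   # [123] -> parsed as int64, leading zeros lost
print(as_str["IID"].tolist())  # ['000123'] -> stays a string, matches other files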
54 changes: 37 additions & 17 deletions pgscatalog_utils/ancestry/tools.py
@@ -35,8 +35,20 @@ def choose_pval_threshold(args):
return set_threshold


def compare_ancestry(ref_df: pd.DataFrame, ref_pop_col: str, target_df: pd.DataFrame, ref_train_col=None, n_pcs=4, method='RandomForest',
covariance_method='EmpiricalCovariance', p_threshold=None):
def get_covariance_method(method_name):
match method_name:
case 'MinCovDet':
covariance_model = MinCovDet()
case 'EmpiricalCovariance':
covariance_model = EmpiricalCovariance()
case _:
assert False, "Invalid covariance method"

return covariance_model


def compare_ancestry(ref_df: pd.DataFrame, ref_pop_col: str, target_df: pd.DataFrame, ref_train_col=None, n_pcs=4,
method='RandomForest', covariance_method='EmpiricalCovariance', p_threshold=None):
"""
Function to compare target sample ancestry to a reference panel with PCA data
:param ref_df: reference dataset
@@ -52,7 +64,7 @@ def compare_ancestry(ref_df: pd.DataFrame, ref_pop_col: str, target_df: pd.DataF
# Check that datasets have the correct columns
assert method in comparison_method_threshold.keys(), 'comparison method parameter must be Mahalanobis or RF'
if method == 'Mahalanobis':
assert covariance_method in _mahalanobis_methods, 'ovariance estimation method must be MinCovDet or EmpiricalCovariance'
assert covariance_method in _mahalanobis_methods, 'covariance estimation method must be MinCovDet or EmpiricalCovariance'

cols_pcs = ['PC{}'.format(x + 1) for x in range(0, n_pcs)]
assert all([col in ref_df.columns for col in cols_pcs]), \
@@ -73,14 +85,27 @@ def compare_ancestry(ref_df: pd.DataFrame, ref_pop_col: str, target_df: pd.DataF
else:
ref_train_df = ref_df

# Check if PCs only capture target/reference stratification
# Check outlier-ness of target with regard to the reference PCA space
compare_info = {}
for col_pc in cols_pcs:
mwu_pc = mannwhitneyu(ref_train_df[col_pc], target_df[col_pc])
compare_info[col_pc] = {'U': mwu_pc.statistic, 'pvalue': mwu_pc.pvalue}
if mwu_pc.pvalue < 1e-4:
logger.warning("{} *may* be capturing target/reference stratification (Mann-Whitney p-value={}), "
"use visual inspection of PC plot to confirm".format(col_pc, mwu_pc.pvalue))
pop = 'ALL'
ref_covariance_model = get_covariance_method(covariance_method)
ref_covariance_fit = ref_covariance_model.fit(ref_train_df[cols_pcs])
colname_dist = 'Mahalanobis_dist_{}'.format(pop)
colname_pval = 'Mahalanobis_P_{}'.format(pop)
target_df[colname_dist] = ref_covariance_fit.mahalanobis(target_df[cols_pcs])
target_df[colname_pval] = chi2.sf(target_df[colname_dist], n_pcs - 1)
compare_info['Mahalanobis_P_ALL'] = dict(target_df[colname_pval].describe())
logger.info('Mahalanobis Probability Distribution (train: all reference samples): {}'.format(
compare_info['Mahalanobis_P_ALL']))

## Check if PCs only capture target/reference stratification
if target_df.shape[0] >= 20:
for col_pc in cols_pcs:
mwu_pc = mannwhitneyu(ref_train_df[col_pc], target_df[col_pc])
compare_info[col_pc] = {'U': mwu_pc.statistic, 'pvalue': mwu_pc.pvalue}
if mwu_pc.pvalue < 1e-4:
logger.warning("{} *may* be capturing target/reference stratification (Mann-Whitney p-value={}), "
"use visual inspection of PC plot to confirm".format(col_pc, mwu_pc.pvalue))

# Run Ancestry Assignment methods
if method == 'Mahalanobis':
@@ -93,13 +118,7 @@ def compare_ancestry(ref_df: pd.DataFrame, ref_pop_col: str, target_df: pd.DataF
colname_dist = 'Mahalanobis_dist_{}'.format(pop)
colname_pval = 'Mahalanobis_P_{}'.format(pop)

match covariance_method:
case 'MinCovDet':
covariance_model = MinCovDet()
case 'EmpiricalCovariance':
covariance_model = EmpiricalCovariance()
case _:
assert False, "Invalid covariance method"
covariance_model = get_covariance_method(covariance_method)

covariance_fit = covariance_model.fit(ref_train_df.loc[ref_train_df[ref_pop_col] == pop, cols_pcs])

@@ -379,6 +398,7 @@ def nLL_mu_and_var(theta, df, c_score, l_predictors):
return sum(np.log(np.sqrt(f_var(df[l_predictors], theta_var))) +
(1/2)*(x - f_mu(df[l_predictors], theta_mu))**2/f_var(df[l_predictors], theta_var))
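Aside: the return expression in nLL_mu_and_var above is the Gaussian negative log-likelihood with the additive 0.5*log(2*pi) constant dropped (it does not affect the optimum). A quick check against scipy with made-up numbers:

import numpy as np
from scipy.stats import norm

x, mu, var = 1.3, 0.2, 0.8
partial_nll = np.log(np.sqrt(var)) + 0.5 * (x - mu) ** 2 / var  # per-sample term above
full_nll = -norm.logpdf(x, loc=mu, scale=np.sqrt(var))
assert np.isclose(full_nll, partial_nll + 0.5 * np.log(2 * np.pi))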


def grdnt_mu_and_var(theta, df, c_score, l_predictors):
"""Gradient used to optimize the nLL_mu_and_var fit function.
Adapted from https://github.com/broadinstitute/palantir-workflows/blob/v0.14/ImputationPipeline/ScoringTasks.wdl,
40 changes: 24 additions & 16 deletions pgscatalog_utils/download/download_file.py
@@ -16,10 +16,12 @@ def get_with_user_agent(url: str) -> requests.Response:
return requests.get(url, headers=config.headers())


def download_file(url: str, local_path: str, overwrite: bool, ftp_fallback: bool) -> None:
def download_file(url: str, local_path: str, overwrite: bool,
ftp_fallback: bool) -> None:
if config.OUTDIR.joinpath(local_path).exists():
if not overwrite:
logger.warning(f"{config.OUTDIR.joinpath(local_path)} exists and overwrite is false, skipping download")
logger.warning(
f"{config.OUTDIR.joinpath(local_path)} exists and overwrite is false, skipping download")
return
elif overwrite:
logger.warning(f"Overwriting {config.OUTDIR.joinpath(local_path)}")
@@ -28,18 +30,24 @@ def download_file(url: str, local_path: str, overwrite: bool, ftp_fallback: bool
attempt: int = 0

while attempt < config.MAX_RETRIES:
response: requests.Response = get_with_user_agent(url)
match response.status_code:
case 200:
with open(config.OUTDIR.joinpath(local_path), "wb") as f:
f.write(response.content)
logger.info("HTTPS download complete")
attempt = 0
break
case _:
logger.warning(f"HTTP status {response.status_code} at download attempt {attempt}")
attempt += 1
time.sleep(5)
try:
response: requests.Response = get_with_user_agent(url)
match response.status_code:
case 200:
with open(config.OUTDIR.joinpath(local_path), "wb") as f:
f.write(response.content)
logger.info("HTTPS download complete")
break
case _:
logger.warning(
f"HTTP status {response.status_code} at download attempt {attempt}")
attempt += 1
time.sleep(5)
except requests.RequestException as e:
logger.warning(f"Connection error: {e}")
attempt += 1
time.sleep(5)
logger.warning(f"Retrying download {attempt=} of {config.MAX_RETRIES}")

if attempt > config.MAX_RETRIES:
if ftp_fallback:
@@ -67,9 +75,9 @@ def _ftp_fallback_download(url: str, local_path: str) -> None:
except Exception as e:
if "421" in str(e):
retries += 1
logger.debug(f"FTP server is busy. Waiting and retrying. Retry {retries} of {config.MAX_RETRIES}")
logger.debug(
f"FTP server is busy. Waiting and retrying. Retry {retries} of {config.MAX_RETRIES}")
time.sleep(config.DOWNLOAD_WAIT_TIME)
else:
logger.critical(f"Download failed: {e}")
raise Exception
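Aside: the try/except added around the HTTPS request means connection-level failures (timeouts, DNS errors, resets) now count as retryable attempts instead of crashing the loop. A stripped-down sketch of the same pattern; the function and constant names here are illustrative, not the module's own:

import time
import requests

MAX_RETRIES = 5   # illustrative stand-in for config.MAX_RETRIES
WAIT_SECONDS = 5

def fetch_with_retries(url: str) -> bytes | None:
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.get(url, timeout=30)
            if response.status_code == 200:
                return response.content
            print(f"HTTP status {response.status_code} at attempt {attempt}")
        except requests.RequestException as e:  # connection errors, timeouts, ...
            print(f"Connection error: {e}")
        time.sleep(WAIT_SECONDS)
    return None  # caller can fall back to FTP, as the module does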

33 changes: 26 additions & 7 deletions pgscatalog_utils/match/combine_matches.py
@@ -23,9 +23,12 @@ def combine_matches():
config.OUTDIR = args.outdir

with pl.StringCache():
scorefile = read_scorefile(path=args.scorefile, chrom=None) # chrom=None to read all variants
scorefile = read_scorefile(path=args.scorefile,
chrom=None) # chrom=None to read all variants
logger.debug("Reading matches")
matches = pl.concat([pl.scan_ipc(x, memory_map=False, rechunk=False) for x in args.matches], rechunk=False)
matches = pl.concat(
[pl.scan_ipc(x, memory_map=False, rechunk=False) for x in args.matches],
rechunk=False)

logger.debug("Labelling match candidates")
params: dict[str, bool] = make_params_dict(args)
@@ -49,7 +52,20 @@ def _check_duplicate_vars(matches: pl.LazyFrame):
.collect()
.get_column('count')
.to_list())
assert max_occurrence == [1], "Duplicate IDs in final matches"

match n := max_occurrence[0]:
case None:
logger.critical("No variant matches found")
logger.critical(
"Did you set the correct genome build? Did you impute your genomes?")
raise ValueError
case _ if n > 1:
logger.critical("Duplicate IDs in final matches")
logger.critical(
"Please double check your genomes for duplicates and try again")
raise ValueError
case _:
logger.info("Scoring files are valid (no duplicate variants found)")
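Aside: a toy sketch (not part of this diff) of the duplicate-ID check that the match statement above reports on, written with eager polars; the real check runs inside a lazy pipeline and the exact aggregation names vary between polars versions. Data and column names are made up.

import polars as pl

matches = pl.DataFrame({"accession": ["PGS000001", "PGS000001"],
                        "ID": ["1:100:A:G", "2:200:C:T"]})

max_occurrence = (matches.group_by(["accession", "ID"])
                  .agg(pl.count().alias("count"))
                  .get_column("count")
                  .max())

# a repeated (accession, ID) pair would push max_occurrence above 1
if max_occurrence is None:
    raise ValueError("No variant matches found")
elif max_occurrence > 1:
    raise ValueError("Duplicate IDs in final matches")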


def _parse_args(args=None):
@@ -61,18 +77,21 @@ def _parse_args(args=None):
parser.add_argument('-m', '--matches', dest='matches', required=True, nargs='+',
help='<Required> List of match files')
parser.add_argument('--min_overlap', dest='min_overlap', required=True,
type=float, help='<Required> Minimum proportion of variants to match before error')
type=float,
help='<Required> Minimum proportion of variants to match before error')
parser.add_argument('-IDs', '--filter_IDs', dest='filter',
help='<Optional> Path to file containing list of variant IDs that can be included in the final scorefile.'
'[useful for limiting scoring files to variants present in multiple datasets]')
parser = add_match_args(parser) # params for labelling matches
parser = add_match_args(parser) # params for labelling matches
parser.add_argument('--outdir', dest='outdir', required=True,
help='<Required> Output directory')
parser.add_argument('--split', dest='split', default=False, action='store_true',
help='<Optional> Write scorefiles split per chromosome?')
parser.add_argument('--combined', dest='combined', default=False, action='store_true',
parser.add_argument('--combined', dest='combined', default=False,
action='store_true',
help='<Optional> Write scorefiles in combined format?')
parser.add_argument('-n', dest='n_threads', default=1, help='<Optional> n threads for matching', type=int)
parser.add_argument('-n', dest='n_threads', default=1,
help='<Optional> n threads for matching', type=int)
parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
help='<Optional> Extra logging information')
return parser.parse_args(args)
(Diffs for the remaining 2 changed files are not shown.)
