Categories
Data

Query SRA Sequence Runs with Python

Retrieving data from SRA is a common task. NCBI provides the E-utilities API for querying and retrieving records, along with Entrez Direct (EDirect), a set of command-line tools built on top of it. The Python snippet below queries the NCBI SRA database using BioSample accessions and returns a table of the linked NCBI BioProject, BioSample, Run, download location and size.

import sys, os
import subprocess
import shlex
import pandas as pd
.....


def get_SRR_from_biosamples(csv: str, batch_size=10, debug=True):
    """Get SRA run IDs for BioSample accessions.

    `csv` is a text file with one BioSample accession per line. Returns a
    DataFrame with BioProject, BioSample, Run, download path and size (MB).
    """
    epost_cmd = 'epost -db biosample -format acc'
    elink_cmd = 'elink -target sra'
    efetch_cmd = 'efetch -db sra -format runinfo -mode xml'
    xtract_cmd = ('xtract -pattern Row -def "NA" -element '
                  'BioProject BioSample Run download_path size_MB')

    results = []

    # Read all accessions, skipping blank lines
    with open(csv, 'r') as fh:
        all_ids = [line.strip() for line in fh if line.strip()]
    print(f'Total samples: {len(all_ids)}')

    n_queried = 0
    for start in range(0, len(all_ids), batch_size):
        batch_num = start // batch_size + 1
        if debug and batch_num > 1:
            print('Debug mode. Stop execution after 1 batch.')
            break
        sample_ids = all_ids[start:start + batch_size]
        print(f'Processing batch {batch_num}: {sample_ids}')
        batch_results = []

        # Build the epost command per batch (appending to epost_cmd itself
        # would make every later batch carry the IDs of all previous ones)
        batch_epost = shlex.split(epost_cmd) + ['-id', ','.join(sample_ids)]

        # Chain the tools like `epost | elink | efetch | xtract`; closing each
        # upstream stdout in the parent lets SIGPIPE propagate if a step exits
        epost = subprocess.Popen(batch_epost, stdout=subprocess.PIPE,
                                 encoding='utf8')
        elink = subprocess.Popen(shlex.split(elink_cmd), stdin=epost.stdout,
                                 stdout=subprocess.PIPE, encoding='utf8')
        epost.stdout.close()
        efetch = subprocess.Popen(shlex.split(efetch_cmd), stdin=elink.stdout,
                                  stdout=subprocess.PIPE, encoding='utf8')
        elink.stdout.close()
        xtract = subprocess.Popen(shlex.split(xtract_cmd), stdin=efetch.stdout,
                                  stdout=subprocess.PIPE, encoding='utf8')
        efetch.stdout.close()

        # Wait for the whole pipeline instead of busy-polling epost
        output, _ = xtract.communicate()
        for proc in (epost, elink, efetch):
            proc.wait()

        for line in output.splitlines():
            if not line.startswith('PRJ'):  # e.g. "502 Bad Gateway" when the server is busy
                sys.stderr.write(f'Error processing {sample_ids}: {line}\n')
            else:
                if debug:
                    print(line)
                batch_results.append(line.split('\t'))  # xtract output is tab-separated
        print(f'\nTotal SRA Runs in batch {batch_num}: {len(batch_results)}.\n')
        results.extend(batch_results)
        n_queried += len(sample_ids)

    print(f'Total runs in collection: {len(results)} with {n_queried} samples.')

    data = pd.DataFrame(results, columns=['BioProject', 'BioSample', 'Run', 'Download', 'size_MB'])
    return data
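If EDirect is not available, the last step (xtract pulling five fields out of each runinfo Row) can be approximated in pure Python. The sketch below is a rough stand-in for `xtract -pattern Row -def "NA" -element ...`, not a reimplementation of xtract; the function name runinfo_rows is hypothetical, the root element name SraRunInfo is an assumption, and the XML fragment is hand-written with placeholder accessions.

```python
import xml.etree.ElementTree as ET

FIELDS = ['BioProject', 'BioSample', 'Run', 'download_path', 'size_MB']


def runinfo_rows(xml_text, fields=FIELDS, default='NA'):
    """Roughly mimic `xtract -pattern Row -def "NA" -element ...`.

    Returns one list per <Row>, using `default` for missing or empty fields.
    """
    root = ET.fromstring(xml_text)
    rows = []
    for row in root.iter('Row'):
        values = []
        for name in fields:
            text = row.findtext(name)  # None if the element is absent
            values.append(text.strip() if text and text.strip() else default)
        rows.append(values)
    return rows


# Hypothetical runinfo fragment; accessions and URL are placeholders
example = """<SraRunInfo>
  <Row>
    <Run>SRR0000001</Run>
    <BioProject>PRJNA000001</BioProject>
    <BioSample>SAMN00000001</BioSample>
    <download_path>https://example.org/SRR0000001</download_path>
  </Row>
</SraRunInfo>"""

print(runinfo_rows(example))
# [['PRJNA000001', 'SAMN00000001', 'SRR0000001', 'https://example.org/SRR0000001', 'NA']]
```

The missing size_MB field falls back to "NA", which keeps every row the same width so it can go straight into the DataFrame.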

The pipeline relies on four EDirect command-line tools, epost, elink, efetch and xtract, which must be installed and on your PATH. Python's subprocess module chains these steps together, much like a Linux shell pipe (epost | elink | efetch | xtract). The samples are queried in batches to avoid sending too many requests to NCBI in a short time, which can get further queries from your machine blocked. Once you have the run accessions, you can download the data with the prefetch tool from the SRA Toolkit, and prefetch can, of course, be wrapped and chained from Python in the same way.
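The two ideas in that paragraph, chaining processes like a shell pipe and querying in paced batches, can be separated from the EDirect specifics. Below is a minimal generic sketch; run_pipeline, batched, query_in_batches and demo_commands are hypothetical helper names, the one-second pause is an arbitrary choice (NCBI's E-utilities guidelines ask for at most 3 requests per second without an API key), and small Python one-liners stand in for the real tools so the example runs anywhere.

```python
import subprocess
import sys
import time


def run_pipeline(commands):
    """Run argument lists connected stdout -> stdin, like `a | b | c`.

    Returns the standard output of the last command as text.
    """
    procs = []
    upstream = None
    for cmd in commands:
        proc = subprocess.Popen(cmd, stdin=upstream,
                                stdout=subprocess.PIPE, text=True)
        if upstream is not None:
            # The child holds its own copy; closing ours lets SIGPIPE propagate
            upstream.close()
        upstream = proc.stdout
        procs.append(proc)
    output, _ = procs[-1].communicate()
    for proc in procs[:-1]:
        proc.wait()
    return output


def batched(items, size):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def query_in_batches(ids, make_commands, batch_size=10, pause=1.0):
    """Run one pipeline per batch of ids, sleeping `pause` seconds between batches.

    For the EDirect case, make_commands(batch) could return e.g.
    [['epost', '-db', 'biosample', '-format', 'acc', '-id', ','.join(batch)],
     ['elink', '-target', 'sra'], ...]
    """
    outputs = []
    for i, batch in enumerate(batched(ids, batch_size)):
        if i > 0:
            time.sleep(pause)  # space out requests to the remote server
        outputs.append(run_pipeline(make_commands(batch)))
    return outputs


def demo_commands(batch):
    """Stand-in pipeline: print the ids one per line, then upper-case them."""
    emit = [sys.executable, '-c',
            'import sys; print("\\n".join(sys.argv[1:]))', *batch]
    upper = [sys.executable, '-c',
             'import sys; sys.stdout.write(sys.stdin.read().upper())']
    return [emit, upper]


print(query_in_batches(['a', 'b', 'c'], demo_commands, batch_size=2, pause=0.1))
# ['A\nB\n', 'C\n']
```

Passing a make_commands callback keeps the batching and pacing logic independent of which tools are chained, so the same loop could drive a prefetch wrapper as well.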

Categories
Uncategorized

PyTorch Geometric and CUDA

PyTorch Geometric (PyG) is an add-on library for building graph neural networks with PyTorch. It supports CUDA, but its compiled extension packages (torch-scatter, torch-sparse and so on) must match both your PyTorch version and the CUDA version PyTorch was built with. Below is one error message I got after installing PyG:

from torch_geometric.data import Data
---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
...

OSError: /anaconda3/lib/python3.7/site-packages/torch_sparse/_version_cuda.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSs

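An undefined symbol in torch_sparse's compiled library usually means the extension was built for a different PyTorch/CUDA combination than the one installed. So, I checked the versions: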

import torch
print(torch.version.cuda, torch.__version__)
10.2 1.9.0

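Running $ nvidia-smi reported CUDA version 11.2. Note that this is the newest CUDA version the installed NVIDIA driver supports, not necessarily the version PyTorch uses, so PyTorch (built with CUDA 10.2) and the PyG extensions had ended up on a mix of versions. To fix the mess and get PyG working, I did the following: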
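Because the "CUDA Version" in the nvidia-smi header is the driver's upper limit, a PyTorch build using an older CUDA version can still run on that driver. The sketch below makes that comparison explicit; the function names are hypothetical, the header line format is assumed from typical nvidia-smi output (the driver number is a placeholder), and the "not newer than the driver" rule is a conservative check rather than NVIDIA's full compatibility policy.

```python
import re


def driver_cuda_version(nvidia_smi_output):
    """Return 'CUDA Version: X.Y' from nvidia-smi's header as (X, Y), or None."""
    match = re.search(r'CUDA Version:\s*(\d+)\.(\d+)', nvidia_smi_output)
    return (int(match.group(1)), int(match.group(2))) if match else None


def parse_version(text):
    """'10.2' -> (10, 2)"""
    major, minor = text.split('.')[:2]
    return int(major), int(minor)


def cuda_build_runs_on_driver(torch_cuda, nvidia_smi_output):
    """Conservative check: the build's CUDA version must not exceed the driver's.

    torch_cuda is what torch.version.cuda returns (None for CPU-only builds).
    """
    driver = driver_cuda_version(nvidia_smi_output)
    if torch_cuda is None or driver is None:
        return False
    return parse_version(torch_cuda) <= driver


# Header line in the style printed by nvidia-smi (driver number is a placeholder)
header = '| NVIDIA-SMI 460.00    Driver Version: 460.00    CUDA Version: 11.2     |'
print(driver_cuda_version(header))                # (11, 2)
print(cuda_build_runs_on_driver('10.2', header))  # True
```

In a real session the inputs would come from subprocess.run(['nvidia-smi'], capture_output=True, text=True).stdout and torch.version.cuda; strings are passed in here so the example runs without a GPU.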

$ pip uninstall torch
$ pip install torch==1.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.1+cu111.html
$ pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.1+cu111.html
$ pip install torch-geometric
$ apt-get install nvidia-modprobe

Note that there is no prebuilt wheel for CUDA 11.2 (cu112), so I used the closest earlier version (cu111); the driver is backward compatible with older CUDA versions. Now PyG works! Installing nvidia-modprobe (a utility that loads the NVIDIA kernel module and creates the /dev/nvidia* device files) fixed “RuntimeError: CUDA unknown error – this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero,” which I got when two Python sessions were running and both were trying to use CUDA.
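Picking "the closest version" means choosing the newest available CUDA tag that is not newer than what the driver supports. A small sketch of that rule follows; pick_wheel_tag and cuda_tag_version are hypothetical names, the tag list is illustrative rather than the actual contents of any wheel index, and it assumes the usual "cuXYZ" naming where the last digit is the minor version.

```python
def cuda_tag_version(tag):
    """'cu111' -> (11, 1), 'cu102' -> (10, 2); the last digit is the minor version."""
    digits = tag[2:]
    return int(digits[:-1]), int(digits[-1])


def pick_wheel_tag(driver_cuda, available_tags):
    """Return the newest 'cuXYZ' tag not newer than the driver's (X, Y), else 'cpu'."""
    candidates = [tag for tag in available_tags
                  if tag.startswith('cu') and cuda_tag_version(tag) <= driver_cuda]
    return max(candidates, key=cuda_tag_version, default='cpu')


# Driver reports CUDA 11.2, but only cu102 and cu111 wheels exist (illustrative list)
print(pick_wheel_tag((11, 2), ['cpu', 'cu102', 'cu111']))  # cu111
```

Comparing (major, minor) tuples rather than the raw strings avoids ordering mistakes such as "cu92" sorting after "cu111".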

Update: later testing produced these errors:

RuntimeError: Detected that PyTorch and torch_cluster were compiled with different CUDA versions. PyTorch has CUDA version 11.1 and torch_cluster has CUDA version 10.2. Please reinstall the torch_cluster that matches your PyTorch install.

RuntimeError: Detected that PyTorch and torch_spline_conv were compiled with different CUDA versions. PyTorch has CUDA version 11.1 and torch_spline_conv has CUDA version 10.2. Please reinstall the torch_spline_conv that matches your PyTorch install.

The following commands fixed them:

$ pip install --upgrade pip
$ CUDA=cu111
$ TORCH=1.9.1
$ pip install torch-cluster==1.5.9 -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
$ pip install torch-spline-conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
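The TORCH and CUDA values in these commands can be read from the installed PyTorch itself (torch.__version__ and torch.version.cuda), which avoids typing a mismatched pair by hand. The sketch below builds the same data.pyg.org index URL from those two strings; the helper names are hypothetical, and it does not check whether an index actually exists for a given combination (as the missing cu112 wheels showed, not every one does).

```python
def pyg_wheel_index(torch_version, torch_cuda):
    """Build the data.pyg.org wheel index URL used with `pip install -f`.

    torch_version: e.g. torch.__version__, such as '1.9.1+cu111'
    torch_cuda: e.g. torch.version.cuda, such as '11.1', or None for CPU builds
    """
    base = torch_version.split('+')[0]  # drop the local '+cu111' suffix
    cuda = 'cpu' if torch_cuda is None else 'cu' + torch_cuda.replace('.', '')
    return f'https://data.pyg.org/whl/torch-{base}+{cuda}.html'


def pyg_install_commands(torch_version, torch_cuda, packages):
    """Return one pip command per PyG extension package."""
    url = pyg_wheel_index(torch_version, torch_cuda)
    return [f'pip install {pkg} -f {url}' for pkg in packages]


for cmd in pyg_install_commands('1.9.1+cu111', '11.1', ['torch-cluster', 'torch-spline-conv']):
    print(cmd)
# pip install torch-cluster -f https://data.pyg.org/whl/torch-1.9.1+cu111.html
# pip install torch-spline-conv -f https://data.pyg.org/whl/torch-1.9.1+cu111.html
```

Deriving both values from one PyTorch install is the point: the errors above came from extensions built for CUDA 10.2 sitting next to a PyTorch built for 11.1.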