Rare Disease Literature.

Applying AI to understand research trends in rare diseases.

Preliminaries

Here we set up the libraries and methods used to create and query the local Postgres database that stores information from the Alhazen tools and agent.

from alhazen.aliases import *
from alhazen.core import lookup_chat_models
from alhazen.agent import AlhazenAgent
from alhazen.tools.basic import AddCollectionFromEPMCTool
from alhazen.tools.paperqa_emulation_tool import *
from alhazen.toolkit import *

from alhazen.utils.ceifns_db import Ceifns_LiteratureDb, create_ceifns_database, drop_ceifns_database, list_databases
from alhazen.utils.searchEngineUtils import *

from langchain.vectorstores.pgvector import PGVector
from langchain_community.chat_models.ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_google_vertexai import ChatVertexAI

# timedelta, json, Date and Integer are used by the cells below; import them
# explicitly instead of relying on the star imports above to provide them.
# NOTE: import order is intentionally preserved -- star imports and the second
# ChatOllama import (langchain.chat_models) shadow earlier names.
from datetime import datetime, timedelta

from importlib_resources import files
import json
import os
import pandas as pd

from sqlalchemy import Date, Integer, func, text

from time import time
from tqdm import tqdm

from transformers import pipeline, AutoModel, AutoTokenizer
import torch
import os
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough, RunnableLambda
from operator import itemgetter
from langchain.chat_models import ChatOllama
from langchain.schema import get_buffer_string, OutputParserException, format_document
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from alhazen.utils.output_parsers import JsonEnclosedByTextOutputParser

#from paperqa.prompts import summary_prompt as paperqa_summary_prompt, qa_prompt as paperqa_qa_prompt, select_paper_prompt, citation_prompt, default_system_prompt
from langchain.schema import format_document
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.runnables import RunnableParallel
import local_resources.queries.rao_grantees as rao_files
from alhazen.utils.queryTranslator import QueryTranslator, QueryType

Remember to set the following environment variables before running this code:

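  • ALHAZEN_DB_NAME - the name of the Postgres database you are storing information into (used by the drop/create commands below).
  • LOCAL_FILE_PATH - the location on disk where you save files for your digital library, downloaded models or other data (the directory is created if it does not exist).
  • DATABRICKS_API_KEY - the API key for the Databricks serving endpoint that hosts the Llama 3 model used by the agent.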
import warnings

# Validate where the local digital library lives and make sure the directory
# exists before any database or file work happens. An empty value is rejected
# too (it previously slipped past the `is None` check and failed in makedirs).
if not os.environ.get('LOCAL_FILE_PATH'):
    raise RuntimeError('Where are you storing your local literature database?')
loc = os.environ['LOCAL_FILE_PATH']
# exist_ok=True replaces the separate existence check (and its check/create race).
os.makedirs(loc, exist_ok=True)

db_name = 'rare_as_one_diseases'
# The drop/create commands below act on ALHAZEN_DB_NAME, while the literature
# database handle and the agent connect to `db_name`. Warn when the two differ,
# because then the freshly created database is not the one being used.
_env_db_name = os.environ.get('ALHAZEN_DB_NAME')
if _env_db_name is not None and _env_db_name != db_name:
    warnings.warn(
        f"ALHAZEN_DB_NAME={_env_db_name!r} differs from db_name={db_name!r}; "
        "the drop/create commands and the literature database will target different databases."
    )

Run this command to destroy the database named by ALHAZEN_DB_NAME.

USE WITH CAUTION

# Irreversibly drops the database named by the ALHAZEN_DB_NAME environment variable.
target_db = os.environ['ALHAZEN_DB_NAME']
drop_ceifns_database(target_db)

Run this command to create a new, empty database.

# Creates a new, empty database named by the ALHAZEN_DB_NAME environment variable.
new_db = os.environ['ALHAZEN_DB_NAME']
create_ceifns_database(new_db)

This cell connects to the literature database, sets up the chat models, creates the Alhazen agent, and lists all the tools the agent has access to.

# Handle on the local literature database (files under `loc`, database `db_name`).
ldb = Ceifns_LiteratureDb(loc=loc, name=db_name)

# Chat models: the pre-configured registry, plus Llama-3-70B served through
# Databricks' OpenAI-compatible endpoint (requires DATABRICKS_API_KEY).
llms = lookup_chat_models()
databricks_endpoint = 'https://czi-shared-infra-czi-sci-general-prod-databricks.cloud.databricks.com/serving-endpoints'
llm_databricks_llama3 = ChatOpenAI(
    base_url=databricks_endpoint,
    api_key=os.environ['DATABRICKS_API_KEY'],
    model='databricks-meta-llama-3-70b-instruct',
)
# NOTE(review): despite its name this holds the 'gpt4_1106' registry entry
# (or None), not a DBRX model; it is not referenced elsewhere in this notebook.
llm_dbrx = llms.get('gpt4_1106')

# The same Llama 3 model drives both agent reasoning and tool calls.
cb = AlhazenAgent(db_name=db_name, agent_llm=llm_databricks_llama3, tool_llm=llm_databricks_llama3)
print('AGENT TOOLS')
for tool in cb.tk.get_tools():
    print('\t' + type(tool).__name__)

Build paper collections

This section builds a literature collection for each of the diseases in the Rare As One cohorts for cycles 1 and 2.

What diseases are we querying the literature for?

# Load the Rare As One disease table and keep only the identifier, display
# name and search-term columns (in the order they appear in the file).
cols_to_include = ['ID', 'CORPUS_NAME', 'TERMS']
disease_tsv = files(rao_files).joinpath('CZI_RAO_diseases.tsv')
df = pd.read_csv(disease_tsv, sep='\t')
df = df[[col for col in df.columns if col in cols_to_include]]

df

This command iterates over the list of collections and, for each one, runs a query against Europe PMC. The query is generated by processing the TERMS column of the dataframe with the QueryTranslator utility, which produces a boolean search over the TITLE_ABS field of the remote database (see https://www.ebi.ac.uk/europepmc/webservices/rest/fields for possible fields to search).

# Build one Europe PMC query per disease from its TERMS column, searching
# titles and abstracts only (TITLE_ABS).
sorted_df = df.sort_values('ID')
qt = QueryTranslator(sorted_df, 'ID', 'TERMS', 'CORPUS_NAME')
(corpus_ids, epmc_queries) = qt.generate_queries(QueryType.epmc, sections=['TITLE_ABS'])
# Names must come from the same ID-sorted frame the queries were built from.
# Taking them from the unsorted `df` mis-pairs names with IDs whenever the TSV
# is not already sorted by ID.
# NOTE(review): assumes generate_queries returns results in the row order of
# the frame it was given -- confirm against QueryTranslator.
corpus_names = sorted_df['CORPUS_NAME']

addEMPCCollection_tool = [t for t in cb.tk.get_tools() if isinstance(t, AddCollectionFromEPMCTool)][0]

# Collections with IDs below this value are skipped (e.g. already built in an
# earlier run); set to 0 to (re)build every collection.
first_corpus_id = 60
for (corpus_id, corpus_name, corpus_query) in zip(corpus_ids, corpus_names, epmc_queries):
    if corpus_id < first_corpus_id:
        continue
    addEMPCCollection_tool.run(tool_input={'id': corpus_id, 'name': corpus_name, 'query': corpus_query, 'full_text': False})

#
# Note - create a new corpus for collaborative discussions with CellXGene team (in particular Maximillian L.)
#
dmg_id = '83'
dmg_name = 'Diffuse Midline Glioma'
# "diffuse intrinsic pontine glioma" appeared twice in the original query; once is enough.
dmg_query = '"diffuse midline glioma" OR "diffuse intrinsic pontine glioma" OR "brainstem glioma"'
addEMPCCollection_tool.run(tool_input={'id': dmg_id, 'name': dmg_name, 'query': dmg_query, 'full_text': False})

#
# Ad-hoc questions: one through the PaperQA emulation tool, one through the agent.
#
PaperQA_tool = [t for t in cb.tk.get_tools() if isinstance(t, PaperQAEmulationTool)][0]
question = 'Write an essay to answer the question: "What is the connection between SEC61B and the unfolded protein response?"'
PaperQA_tool.run(tool_input={'question': question})
cb.agent_executor.invoke({'input': 'Write an essay to answer the question: "What known gene variants are associated with Primary Ciliary Dyskinesia?"'})

Unexplained performance bug

The query below takes about 6 minutes for skc_id=83, but less than 1 second for skc_id=79.

def _citation_embeddings_for_collection(skc_id):
    """Return citation-record embeddings for every paper in one collection.

    This is the same raw SQL for both collections being compared. The collection
    id is passed as a bound parameter instead of being pasted into the query text.

    Args:
        skc_id: id of the ScientificKnowledgeCollection, compared as text (e.g. '79').

    Returns:
        List of rows (collection name, expression id, expression content,
        pub_date, pub_type, embedding, fragment content), ordered by
        pub_date descending.
    """
    # Clear any failed transaction left on the shared session before querying.
    ldb.session.rollback()
    # TODO(review): the joins on emb.cmetadata->>'e_id' / ->>'f_id' compare
    # JSON-extracted text; compare EXPLAIN ANALYZE for '79' vs '83' to see
    # whether a missing expression index explains the slowdown.
    query = text(
        '''SELECT DISTINCT skc.name, ske.id, ske.content, ske.publication_date as pub_date, ske.type as pub_type, emb.embedding, skf.content 
    FROM langchain_pg_embedding as emb, 
        "ScientificKnowledgeExpression" as ske,
        "ScientificKnowledgeCollection_has_members" as skc_hm, 
        "ScientificKnowledgeCollection" as skc, 
        "ScientificKnowledgeFragment" as skf
    WHERE emb.cmetadata->>'i_type' = 'CitationRecord' AND
        emb.cmetadata->>'e_id' = ske.id AND 
        emb.cmetadata->>'f_id' = skf.id AND
        skc_hm."ScientificKnowledgeCollection_id" = skc.id AND
        ske.id = skc_hm.has_members_id AND (skc.id=:skc_id)
    ORDER BY pub_date DESC;''')
    return ldb.session.execute(query, {'skc_id': skc_id}).fetchall()


# Fast case (< 1 s) followed by the slow case (~6 min).
_citation_embeddings_for_collection('79')
_citation_embeddings_for_collection('83')
# Embed every distinct expression (paper) currently stored in the database.
q = ldb.session.query(SKE).distinct(SKE.id)
expression_total = q.count()
print(expression_total)
ldb.embed_expression_list(q.all())
# Ask the agent to fetch full-text copies of the papers in collection "0".
fulltext_request = 'Get full text copies of all papers in the collection with id="0".'
cb.agent_executor.invoke({'input': fulltext_request})
# Number of member papers per collection, ordered numerically by collection id.
q = (
    ldb.session.query(SKC.id, SKC.name, func.count(SKC_HM.has_members_id))
    .filter(SKC.id == SKC_HM.ScientificKnowledgeCollection_id)
    .group_by(SKC.id, SKC.name)
    .order_by(SKC.id.cast(Integer))
)
corpora_df = pd.DataFrame(q.all(), columns=['Corpus ID', 'Corpus Name', 'Paper Count'])

# Total number of expressions (papers) in the whole database.
paper_count = ldb.session.query(func.count(SKE.id)).first()
print('Count of all papers in database: %d' % paper_count[0])

corpora_df
# Dump every 'NoteAboutFragment' note: a separator line, the note id, then the
# 'response' and 'data' entries of its JSON content (None when absent).
q3 = ldb.session.query(N).filter(N.type == 'NoteAboutFragment')

for note in q3.all():
    payload = json.loads(note.content)
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    print(note.id)
    for field in ('response', 'data'):
        print(payload.get(field))
# Load every expression into memory and report how many there are.
# (Embedding them all in one go via ldb.embed_expression_list(skes) is left disabled.)
skes = ldb.session.query(SKE).all()
print(len(skes))
ldb.session.rollback()
# Download full text for every paper in one collection. Papers whose full text
# (.nxml, .pdf or .html) is already on disk are counted and skipped. The
# original ran the retriever for every paper and counted every paper as
# "missing".
# `tk` was undefined here; the toolkit lives on the agent, as in the other cells.
ft_retriever = [t for t in cb.tk.get_tools() if isinstance(t, RetrieveFullTextTool)][0]

target_corpus_id = '81'
# Full-text files are looked up under <loc><db_name>/ft/, named by DOI without
# the 'doi:' prefix (presumably where RetrieveFullTextTool saves them -- confirm).
ft_dir = loc + db_name + '/ft/'
ft_suffixes = ('.nxml', '.pdf', '.html')

for _, c in corpora_df.iterrows():
    if c['Corpus ID'] != target_corpus_id:
        continue
    print(c['Corpus Name'])
    ft_count = 0
    no_ft_count = 0
    doi_list = [e.id for e in ldb.list_expressions(collection_id=c['Corpus ID'])]
    for doi in doi_list:
        d2 = doi.replace('doi:', '')
        if any(os.path.exists(ft_dir + d2 + suffix) for suffix in ft_suffixes):
            ft_count += 1
            continue
        no_ft_count += 1
        try:
            ft_retriever.run(tool_input={'paper_id': doi})
        except Exception as e:
            # Best effort: report the failure and continue with the next paper.
            print(e)
    print(ft_count)
    print(no_ft_count)
# One row per 'section' fragment of every full-text item, ordered by paper
# and by section offset within the paper.
q = (
    ldb.session.query(SKE.id, SKI.id, SKI.type, SKF.id, SKF.type, SKF.offset, SKF.content)
    .filter(SKC.id == SKC_HM.ScientificKnowledgeCollection_id)
    .filter(SKC_HM.has_members_id == SKE.id)
    .filter(SKE.id == SKE_HR.ScientificKnowledgeExpression_id)
    .filter(SKE_HR.has_representation_id == SKI.id)
    .filter(SKI.id == SKI_HP.ScientificKnowledgeItem_id)
    .filter(SKI_HP.has_part_id == SKF.id)
    .filter(SKF.type == 'section')
    .filter(SKI.type.like('%FullText'))
    .order_by(SKE.id, SKF.offset)
)
items_df = pd.DataFrame(q.all(), columns=['doi', 'item_id', 'item_type', 'fragment_id', 'fragment_type', 'offset', 'content'])
# The (unrestricted) collection join repeats each row once per collection the
# paper belongs to; keep a single copy of each.
items_df = items_df.drop_duplicates().reset_index(drop=True)

items_df

Index the abstracts and run some simple semantic queries

Here we index each paper’s title and abstract to build a simple question / answer interface.

ldb.session.rollback()
# Embed the expressions of a single collection (id '81').
for _, corpus_row in tqdm(corpora_df.iterrows()):
    if corpus_row['Corpus ID'] == '81':
        collection_expressions = ldb.list_expressions(collection_id=corpus_row['Corpus ID'])
        ldb.embed_expression_list(collection_expressions)

# Simple semantic query against the full-text vector index.
question = 'What is known about genetics underlying Stiff Person Syndrome?'
ldb.query_vectorindex(question, k=10, collection_name='ScienceKnowledgeItem_FullText')

Attempting to reconstruct the PaperQA pipeline in our system:

  1. Embed paper sections + question
  2. Given the question, summarize the retrieved paper sections relative to the question
  3. Score and select relevant passages
  4. Put summaries into prompt
  5. Generate answer with prompt
# Point LangChain's PGVector store at the same Postgres database as `ldb`.
os.environ['PGVECTOR_CONNECTION_STRING'] = "postgresql+psycopg2:///" + ldb.name
vectorstore = PGVector.from_existing_index(embedding=ldb.embed_model, collection_name='ScienceKnowledgeItem')

# Top-15 matches restricted to collection 81.
# NOTE(review): skc_ids is filtered with the int 81 while collection ids
# elsewhere are strings ('81'); confirm how skc_ids is stored in the metadata.
retrieval_kwargs = {'k': 15, 'filter': {'skc_ids': 81}}
retriever = vectorstore.as_retriever(search_kwargs=retrieval_kwargs)
retriever.invoke(question)
# Human-turn prompt for the summarization step. Template variables:
#   {k}              - number of articles supplied / summaries requested
#   {context}        - the rendered list of retrieved documents
#   {question}       - the user's question
#   {summary_length} - word limit per summary
# The model is asked for a JSON list of {k} objects with DOI, SUMMARY and
# RELEVANCE SCORE fields.
hum_p = '''First, read through the following JSON encoding of {k} research articles: 

Each document has three attributes: (A) a digital object identifier ('DOI') code, (B) a CITATION string containing the authors, publication year, title and publication location, and the (C) CONTENT field with the title and abstract of the paper.  

```json:{context}```

Then, generate a JSON list of summaries of each article in order to help answer the following question:

Question: {question}

Do NOT directly answer the question, instead summarize to give evidence to help answer the question. 
Focus on specific details, including numbers, equations, or specific quotes. 
Reply "Not applicable" if text is irrelevant. 
Restrict each summary to {summary_length} words. 
Also, provide a score from 1-10 indicating relevance to question. Do not explain your score. 

Write this answer as JSON formatted output. Provide a list of {k} dict objects with the following fields: DOI, SUMMARY, RELEVANCE SCORE. 

Do not provide additional explanation for the answer.
Do not include any other response other than a JSON object.
'''
# System prompt setting the tone and audience for the summarization model.
sys_p = '''Answer in a direct and concise tone. Your audience is an expert, so be highly specific. If there are ambiguous terms or acronyms, first define them.'''

# Per-document rendering used inside the JSON-like context list given to the LLM.
# NOTE(review): keys/values are single-quoted and not escaped, so this is not
# strict JSON; content containing quotes will blur the structure.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="'DOI': '{ske_id}', CITATION: '{citation}', CONTENT:'{page_content}'")


def combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="},{\n"
):
    """Render retrieved documents as a JSON-like list of objects for a prompt.

    Each document is formatted with ``document_prompt``. Its ``ske_id`` and
    ``citation`` variables are presumably filled from the document metadata
    and ``page_content`` from the document text. The rendered strings are
    joined with ``document_separator`` and wrapped in ``[{`` ... ``}]``.

    Args:
        docs: Iterable of LangChain documents, e.g. a retriever's output.
        document_prompt: Template used to render each document.
        document_separator: Text placed between rendered documents. The
            default closes one object and opens the next.

    Returns:
        The rendered list as a string, or ``'[]'`` when there are no documents.
    """
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    if not doc_strings:
        # Without this guard an empty retrieval rendered as '[{}]', i.e. one
        # phantom empty article.
        return '[]'
    return '[{' + document_separator.join(doc_strings) + '}]'

template = ChatPromptTemplate.from_messages([
    ("system", sys_p),
    ("human", hum_p),
])

# Retrieve once, then fill the prompt. `k` is taken from the number of documents
# actually retrieved (the retriever returns up to its own `k`, 15 above). This
# keeps the prompt's "{k} research articles" / "list of {k} dict objects"
# consistent with the context it is given. A caller-supplied 'k' is ignored.
qa_chain = (
    RunnablePassthrough.assign(docs=itemgetter("question") | retriever)
    | RunnableParallel({
        "k": lambda x: len(x["docs"]),
        "question": itemgetter("question"),
        "summary_length": itemgetter("summary_length"),
        "context": lambda x: combine_documents(x["docs"]),
    })
    | {
        "summary": template | ChatOllama(model='mixtral') | JsonEnclosedByTextOutputParser(),
        "context": itemgetter("context"),
    }
)

# Named qa_input rather than `input` to avoid shadowing the builtin.
qa_input = {'question': question, 'summary_length': 1000}
out = qa_chain.invoke(qa_input, config={'callbacks': [ConsoleCallbackHandler()]})
print(json.dumps(out, indent=4))

Discourse Analysis

# Discourse tagger: a fine-tuned text-classification checkpoint used with the
# BioBERT v1.1 tokenizer. Override the checkpoint location with the
# DISCOURSE_TAGGER_PATH environment variable.
model_path = os.environ.get('DISCOURSE_TAGGER_PATH', '/Users/gully.burns/Documents/2024H1/models/discourse_tagger')
# model_max_length is the limit the pipeline's truncation=True uses. Setting it
# explicitly guarantees truncation to BERT's 512 positions even if the
# tokenizer config does not define it.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1",
                                          truncation=True,
                                          max_length=512,
                                          model_max_length=512)
# The pipeline emits generic LABEL_<i> names; map them to discourse categories.
labels = ['BACKGROUND', 'OBJECTIVE', 'METHODS', 'RESULTS', 'CONCLUSIONS']
lookup = {'LABEL_%d' % i: label for i, label in enumerate(labels)}

# The pipeline loads its own copy of the model from model_path. The separate
# AutoModel (headless, never used) that was loaded here only doubled memory use.
# Use the Apple-silicon GPU when available, otherwise fall back to CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
classifier = pipeline("text-classification",
                      model=model_path,
                      tokenizer=tokenizer,
                      truncation=True,
                      batch_size=8,
                      device=device)
# Try an out-of-the-box classifier on the data for discourse tagging.
# First split every CitationRecord fragment into sentences.

ldb.session.rollback()
one_year_ago = (datetime.now() - timedelta(days=1*365))

q = ldb.session.query(SKE, SKF) \
    .filter(SKC.id == SKC_HM.ScientificKnowledgeCollection_id) \
    .filter(SKC_HM.has_members_id == SKE.id) \
    .filter(SKE.id == SKE_HR.ScientificKnowledgeExpression_id) \
    .filter(SKE_HR.has_representation_id == SKI.id) \
    .filter(SKI.id == SKI_HP.ScientificKnowledgeItem_id) \
    .filter(SKI_HP.has_part_id == SKF.id) \
    .filter(SKI.type == 'CitationRecord') \
    .order_by(SKE.id)

# Optional restrictions:
#   .filter(SKC.name == 'The Stiff Person Syndrome' ) \
#   .filter(SKE.publication_date >= one_year_ago) \

s_list = []
seen_pairs = set()
for e, f in q.all():
    # The collection join returns a paper once per collection it belongs to.
    # Process each (paper, fragment) pair only once; otherwise the sentence
    # fragments/notes created later would get duplicate ids.
    pair = (e.id, f.id)
    if pair in seen_pairs:
        continue
    seen_pairs.add(pair)
    for s_idx, sentence in enumerate(ldb.sent_detector.tokenize(f.content)):
        s_list.append([e.id, f.id, s_idx, sentence])
sent_df = pd.DataFrame(s_list, columns=['doi', 'f_id', 's_id', 'text'])
sent_df
# Classify all sentences in one batched pipeline call (on whatever device the
# classifier was created with) and time the inference.
start = time()

df = sent_df

sentence_texts = df['text'].tolist()
preds = classifier(sentence_texts)
pred_df = pd.DataFrame(preds)
df['label'] = [lookup[p['label']] for p in preds]
df['score'] = [p['score'] for p in preds]

end = time()

print('Prediction time:', str(timedelta(seconds=end-start)))
df
# Generate fragment sentences and add them as Notes.
# Each classified sentence becomes a 'sentence' fragment of the item that owns
# its source fragment. The discourse label and score are attached as a JSON note.
ldb.session.rollback()
for _, row in df.iterrows():
    f_q = ldb.session.query(SKF).filter(SKF.id == row.f_id).first()
    # Find the owning item through the has_part link table (the same join the
    # sentence query used). The previous row.f_id.split('.')[0] heuristic
    # breaks whenever ids contain dots, as DOI-based ids ('doi:10.…') do.
    i_q = ldb.session.query(SKI) \
        .filter(SKI.id == SKI_HP.ScientificKnowledgeItem_id) \
        .filter(SKI_HP.has_part_id == row.f_id) \
        .first()
    sentence_id = f_q.id + '.' + str(row.s_id)
    # Offset of the sentence's first occurrence within the item's content.
    sentence_offset = i_q.content.find(row.text)
    sentence_length = len(row.text)
    sentence_fragment = ScientificKnowledgeFragment(id=sentence_id,
                                                    content=row.text,
                                                    offset=sentence_offset,
                                                    length=sentence_length,
                                                    type='sentence')
    i_q.has_part.append(sentence_fragment)
    note_content = {'discourse_label': row.label, 'score': row.score}
    n = Note(id=sentence_id + '.discourse_type',
             content=json.dumps(note_content, indent=4),
             format='json',
             type='NoteAboutFragment')
    sentence_fragment.has_notes.append(n)
    ldb.session.flush()
ldb.session.commit()

Running DRSM Classifiers.

# DRSM classifier: a fine-tuned text-classification checkpoint used with the
# BioBERT v1.1 tokenizer. Override the checkpoint location with the
# DRSM_CLASSIFIER_PATH environment variable.
model_path = os.environ.get('DRSM_CLASSIFIER_PATH', '/Users/gully.burns/Documents/2024H1/models/drsm_classifier')
# Explicit model_max_length makes the pipeline's truncation=True cut inputs at
# BERT's 512 positions even if the tokenizer config does not define it.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1",
                                          truncation=True,
                                          max_length=512,
                                          model_max_length=512)
# NOTE(review): these are the discourse-tagger labels copied from the discourse
# cell; confirm they match the DRSM classifier's label set before using `lookup`.
labels = ['BACKGROUND', 'OBJECTIVE', 'METHODS', 'RESULTS', 'CONCLUSIONS']
lookup = {'LABEL_%d' % i: label for i, label in enumerate(labels)}

# The pipeline loads its own model from model_path; no separate (unused)
# AutoModel copy is loaded. Use the Apple-silicon GPU when available, else CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
classifier = pipeline("text-classification",
                      model=model_path,
                      tokenizer=tokenizer,
                      truncation=True,
                      batch_size=8,
                      device=device)

Topic Modeling over the corpus.

What are the main topics being discussed in each paper?

# NOTE(review): this cell currently loads the DRSM classifier checkpoint, the
# same as the DRSM cell, with discourse labels; it looks like a placeholder for
# a real topic model. Override the checkpoint with TOPIC_MODEL_PATH.
model_path = os.environ.get('TOPIC_MODEL_PATH', '/Users/gully.burns/Documents/2024H1/models/drsm_classifier')
# Explicit model_max_length makes the pipeline's truncation=True cut inputs at
# BERT's 512 positions even if the tokenizer config does not define it.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1",
                                          truncation=True,
                                          max_length=512,
                                          model_max_length=512)
labels = ['BACKGROUND', 'OBJECTIVE', 'METHODS', 'RESULTS', 'CONCLUSIONS']
lookup = {'LABEL_%d' % i: label for i, label in enumerate(labels)}

# The pipeline loads its own model from model_path; no separate (unused)
# AutoModel copy is loaded. Use the Apple-silicon GPU when available, else CPU.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
classifier = pipeline("text-classification",
                      model=model_path,
                      tokenizer=tokenizer,
                      truncation=True,
                      batch_size=8,
                      device=device)

Search for and download Full Text Papers.

How many Stiff Person Syndrome papers were published in each of the last 10 years?

ldb.session.rollback()

ten_years_ago = (datetime.now() - timedelta(days=10*365))
print(ten_years_ago)

# publication_date is cast to Date to extract the year, which suggests it is
# not stored as a date. Compare it under the same cast rather than comparing
# the raw column against a datetime.
pub_date = SKE.publication_date.cast(Date)
pub_year = func.extract('year', pub_date)
q = ldb.session.query(pub_year, func.count(SKE.id)) \
    .filter(SKC.id == SKC_HM.ScientificKnowledgeCollection_id) \
    .filter(SKC_HM.has_members_id == SKE.id) \
    .filter(pub_date >= ten_years_ago.date()) \
    .filter(SKC.name == 'The Stiff Person Syndrome') \
    .group_by(pub_year) \
    .order_by(pub_year)
# Each row is (publication year, number of papers); the columns were
# previously mislabelled 'doi' and 'date'.
sps_pubcount_df = pd.DataFrame(q.all(), columns=['year', 'count'])
sps_pubcount_df

Run PaperQA

# Ask the agent for a short essay grounded in collection 70 (Primary Ciliary
# Dyskinesia). The disease name was misspelled "diskinesia" in the prompt.
cb.agent_executor.invoke({'input': 'Write a short essay on "What connections between primary ciliary dyskinesia and primary cilia have been studied?" based on the collection with ID="70".'})