RNAquarium

Using LLMs to extract information from RNA studies in Zebrafish

Basics

Python Imports

Set Python imports, environment variables, and other crucial setup parameters here.

from alhazen.aliases import *
from alhazen.core import get_langchain_chatmodel, MODEL_TYPE
from alhazen.agent import AlhazenAgent
from alhazen.schema_sqla import *
from alhazen.tools.basic import AddCollectionFromEPMCTool, DeleteCollectionTool
from alhazen.tools.paperqa_emulation_tool import PaperQAEmulationTool
from alhazen.tools.metadata_extraction_tool import * 
from alhazen.tools.protocol_extraction_tool import *
from alhazen.toolkit import *
from alhazen.utils.jats_text_extractor import NxmlDoc

from alhazen.utils.ceifns_db import Ceifns_LiteratureDb, create_ceifns_database, drop_ceifns_database, list_databases
from alhazen.utils.searchEngineUtils import *


from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.pgvector import PGVector
from langchain_community.chat_models.ollama import ChatOllama
from langchain_google_vertexai import ChatVertexAI
from langchain_openai import ChatOpenAI

from bs4 import BeautifulSoup,Tag,Comment,NavigableString
from databricks import sql
from datetime import datetime
from importlib_resources import files
import json
import os
import pandas as pd
from pathlib import Path
import re
import requests

from sqlalchemy import create_engine, text, exists, func, or_, and_, not_, desc, asc
from sqlalchemy.orm import sessionmaker, aliased

from time import time,sleep
from tqdm import tqdm
from urllib.request import urlopen
from urllib.parse import quote_plus, quote, unquote
from urllib.error import URLError, HTTPError
import yaml

import pymde
import torch

from alhazen.utils.searchEngineUtils import load_paper_from_openalex, read_references_from_openalex 
from pyalex import config, Works, Work
config.email = "gully.burns@chanzuckerberg.com"

import local_resources.data_files.rnaquarium as rnaquarium
from alhazen.utils.queryTranslator import QueryTranslator, QueryType

Environment Variables

Remember to set environment variables for this code:

  • ALHAZEN_DB_NAME - the name of the PostgreSQL database you are storing information into
  • LOCAL_FILE_PATH - the location on disk where you save temporary files, downloaded models or other data.
os.environ['ALHAZEN_DB_NAME'] = 'rnaquarium'
os.environ['LOCAL_FILE_PATH'] = '/users/gully.burns/alhazen/'
if os.path.exists(os.environ['LOCAL_FILE_PATH']) is False:
    os.makedirs(os.environ['LOCAL_FILE_PATH'])
    
if os.environ.get('ALHAZEN_DB_NAME') is None: 
    raise Exception('Which database do you want to use for this application?')
db_name = os.environ['ALHAZEN_DB_NAME']

if os.environ.get('LOCAL_FILE_PATH') is None: 
    raise Exception('Where are you storing your local literature database?')
loc = os.environ['LOCAL_FILE_PATH']

Setup utils, agents, and tools

ldb = Ceifns_LiteratureDb(loc=loc, name=db_name)
llm = ChatOllama(model='mixtral:instruct') 
llm2 = ChatOpenAI(model='gpt-4-1106-preview') 
#llm3 = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)

cb = AlhazenAgent(llm2, llm2)
print('AGENT TOOLS')
for t in cb.tk.get_tools():
    print('\t'+type(t).__name__)

test_tk = MetadataExtractionToolkit(db=ldb, llm=llm2)
print('\nTESTING TOOLS')
for t in test_tk.get_tools():
    print('\t'+type(t).__name__)
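Note that the local Ollama models referenced above (and the stablelm-zephyr model used later for protocol extraction) must be pulled onto the machine running the Ollama server before they can be used, and ChatOpenAI requires the OPENAI_API_KEY environment variable to be set. Assuming a standard Ollama installation:

$ ollama pull mixtral:instruct
$ ollama pull stablelm-zephyr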

Building the database

Scripts to Build / Delete the database

If you need to restore a deleted database from backup, use the following shell commands:

$ createdb rnaquarium
$ psql -d rnaquarium -f /local/file/path/rnaquarium/backup<date_time>.sql
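The backup file itself can be produced beforehand with pg_dump; a minimal sketch, using the database name set in ALHAZEN_DB_NAME above:

$ pg_dump -d rnaquarium -f /local/file/path/rnaquarium/backup<date_time>.sql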
drop_ceifns_database(os.environ['ALHAZEN_DB_NAME'])
create_ceifns_database(os.environ['ALHAZEN_DB_NAME'])

Build the CEIFNS database from the 900 DOIs listed in the spreadsheet

Load data from the spreadsheet

df = pd.read_csv(files(rnaquarium).joinpath('RNAquarium_paper_list.tsv'), sep='\t')
dois = df['DOI'].to_list()
df
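European PMC expects bare DOI strings, so a quick sanity check on the spreadsheet values can save a round of failed queries. A minimal, purely illustrative sketch:

# Flag DOIs that carry a resolver prefix or stray whitespace,
# since the EPMC query below expects bare DOI strings
suspect = [d for d in dois if d != d.strip() or d.lower().startswith('http')]
print('%d DOIs may need cleaning before querying' % len(suspect))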

Run this cell to execute paged queries (40 DOIs per query) against European PMC for each of the DOIs in the spreadsheet loaded above.

addEMPCCollection_tool = [t for t in cb.tk.get_tools() if isinstance(t, AddCollectionFromEPMCTool)][0]
step = 40
for start_i in range(0, len(dois), step):
    query = ' OR '.join(['doi:\"'+dois[i]+'\"' for i in range(start_i, start_i+step) if i < len(dois)])
    addEMPCCollection_tool.run({'id': '0', 'name':'RNAquarium Papers', 'query':query, 'full_text':True})

Run this cell to check how many papers from the list are loaded in our database.

# Compare contents of database to the list of dois
missing_list = []
titles = []
for doi in dois:
    row = df[df['DOI']==doi]
    doi_in_db = ldb.session.query(SKE).filter(SKE.id=='doi:'+doi.lower()).all()
    if len(doi_in_db) == 0:
        print('DOI: '+doi)
        print('\t%s (%d) %s %s'%(row['Author'].iloc[0],row['Publication Year'].iloc[0],row['Title'].iloc[0],row['Journal Abbreviation'].iloc[0]))
        missing_list.append(doi)
        titles.append(row['Title'].iloc[0])
print('%d Missing DOIs'%(len(missing_list)))

Use OpenAlex to backfill papers that were missed on EPMC

ldb.session.rollback()
corpus = ldb.session.query(SKC).filter(SKC.id=='0').first()
count = 0
print(len(corpus.has_members))

papers_to_index = []
for i, doi in enumerate(missing_list):
    p = load_paper_from_openalex(doi)
    ldb.session.add(p)
    corpus.has_members.append(p)
    p.member_of.append(corpus)
    for item in p.has_representation:
        for f in item.has_part:
            #f.content = '\n'.join(self.sent_detector.tokenize(f.content))
            f.part_of = item.id
            ldb.session.add(f)
        item.represented_by = p.id
        ldb.session.add(item)
    papers_to_index.append(p)
    ldb.session.flush()

ldb.embed_expression_list(papers_to_index)

ldb.session.commit()
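After the OpenAlex backfill it is worth re-running the membership check; a minimal sketch along the lines of the earlier cell:

# Re-run the membership check to confirm the OpenAlex backfill worked
still_missing = [doi for doi in dois
                 if len(ldb.session.query(SKE).filter(SKE.id=='doi:'+doi.lower()).all()) == 0]
print('%d DOIs still missing after OpenAlex backfill' % len(still_missing))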

Get full text copies of all the papers in the collection

This invokes the agent directly to make it easy to run the retrieval tool.

cb.db.session.rollback()
cb.agent_executor.invoke({'input':'Retrieve full text for the collection with id="0".'})

Analyze Collections

Build a basic report of the composition of all collections in the database (listed by item type).

cb.db.report_collection_composition()
cb.db.report_non_full_text_for_collection(0)

Tests + Checks

Agent tool selection + execution + interpretation

# Use this cell to test the agent's tool selection, execution, and interpretation
cb.agent_executor.invoke({'input':'Hi who are you and what can you do?'})

Run MetaData Extraction Chain over listed papers

Here, we run various versions of the metadata extraction tool to examine performance over the RNAquarium dataset.

# Get the metadata extraction tool
t2 = [t for t in test_tk.get_tools() if isinstance(t, MetadataExtraction_EverythingEverywhere_Tool)][0]

# Hack to get the path to the metadata directory as a string
metadata_dir = str(files(rnaquarium).joinpath('temp'))[0:-4]

# Compile the answers from the metadata directory
#t2.compile_answers('cryoet', metadata_dir)

# Create a dataframe to store previously extracted metadata
df = pd.DataFrame()
for d_id in dois:
    item_types = set()
    #d_id = 'doi:'+d
    df2 = pd.DataFrame(t2.read_metadata_extraction_notes(d_id, 'rnaquarium')) 
    df = pd.concat([df, df2]) 
     
# Iterate over papers to run the metadata extraction tool
for d_id in dois[0:10]:
    item_types = set()
    #d_id = 'doi:'+d

    # Skip if the doi is already in the database
    if len(df)>0 and d_id in df.doi.unique():
        continue

    # Run the metadata extraction tool on the doi
    t2.run(tool_input={'paper_id': d_id, 'extraction_type': 'rnaquarium'})

    # Add the results to the dataframe
    df2 = pd.DataFrame(t2.read_metadata_extraction_notes(d_id, 'rnaquarium')) 
    df = pd.concat([df, df2])
q = cb.db.session.query(N) \
    .filter(N.id == NIA.Note_id) \
    .filter(N.type == 'MetadataExtractionNote') \
    .filter(N.name.like('rnaquarium_%')) 
l = []
for n in q.all():
    tup = json.loads(n.content)
    t, doi, label = n.name.split('__')
    tup['doi'] = 'doi:'+doi
    tup['extraction_type'] = t
    tup['run_label'] = label
    l.append(tup)
report_df = pd.DataFrame(l).set_index('doi')
report_df
report_df.to_csv(loc+'/rnaquarium_metadata_extraction_report.tsv', sep='\t')
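A quick way to gauge coverage of the extraction run is to count how many papers have notes for each run label; a small sketch, assuming the columns produced above:

# Rough coverage summary: papers with extraction notes, per run label
print(report_df.reset_index().groupby('run_label')['doi'].nunique())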
# Create a dataframe to store previously extracted metadata
df = pd.DataFrame()
for d_id in dois:
    df2 = pd.DataFrame(t2.read_metadata_extraction_notes(d_id, 'rnaquarium')) 
    df = pd.concat([df, df2]) 
df
ldb.session.rollback()
# USE WITH CAUTION - this will delete extracted metadata notes from the database
# clear all notes attached to the first 10 papers in the `dois` list
for d in list(set(dois[0:10])):
    d_id = 'doi:'+d
    e = ldb.session.query(SKE).filter(SKE.id==d_id).first()
    notes_to_delete = []
    if e is None:
        continue
    for n in ldb.read_notes_about_x(e):
        notes_to_delete.append(n.id)
    for n in notes_to_delete:
        ldb.delete_note(n)
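Since the cell above permanently removes notes, it may be safer to snapshot them to disk first; a minimal sketch using the same read_notes_about_x call (the output file name is just an example):

# Optional: dump the notes to a JSON file before deleting them
backup = []
for d in list(set(dois[0:10])):
    e = ldb.session.query(SKE).filter(SKE.id=='doi:'+d).first()
    if e is None:
        continue
    backup.extend([{'id': n.id, 'name': n.name, 'content': n.content}
                   for n in ldb.read_notes_about_x(e)])
with open(loc+'/rnaquarium_notes_backup.json', 'w') as f:
    json.dump(backup, f, indent=2)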

Protocol Modeling + Extraction

ldb = Ceifns_LiteratureDb(loc=loc, name=db_name)
slm = ChatOllama(model='stablelm-zephyr') 
llm = ChatOllama(model='mixtral:instruct') 
llm2 = ChatOpenAI(model='gpt-4-1106-preview') 
d = ("This tool attempts to draw a protocol design from the description of a scientific paper.")
t = ProcotolExtractionTool(db=ldb, llm=llm2, description=d)
t.run(tool_input={'paper_id': 'doi:10.1101/2022.04.12.488077', 'extraction_type': 'cryoet'})
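If the protocol tool attaches its output as notes on the paper, in the same way the metadata tool does, the result can be read back with read_notes_about_x; a sketch under that assumption:

# Read back any notes attached to the paper after protocol extraction
e = ldb.session.query(SKE).filter(SKE.id=='doi:10.1101/2022.04.12.488077').first()
if e is not None:
    for n in ldb.read_notes_about_x(e):
        print(n.name)
        print(n.content[:500])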
ldb.session.rollback()
rag_embeddings_list = [json.loads(e[0]) for e in ldb.session.execute(text("""
                        SELECT DISTINCT emb.embedding 
                         FROM langchain_pg_embedding as emb, 
                            "ScientificKnowledgeExpression" as ske,
                            "ScientificKnowledgeCollection_has_members" as skc_hm
                         WHERE cmetadata->>'i_type' = 'CitationRecord' AND
                            cmetadata->>'e_id' = ske.id AND 
                            ske.id = skc_hm.has_members_id AND
                            skc_hm."ScientificKnowledgeCollection_id"='0';
                        """)).fetchall()]
rag_embeddings_tensor = torch.FloatTensor(rag_embeddings_list)

proj_embeddings = pymde.preserve_neighbors(rag_embeddings_tensor, constraint=pymde.Standardized()).embed()
pymde.plot(proj_embeddings)
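The projection plot can also be written to disk alongside the other report outputs; a small sketch, assuming pymde.plot returns a matplotlib Axes (the file name is illustrative):

# Save the MDE projection of the citation-record embeddings
ax = pymde.plot(proj_embeddings)
ax.figure.savefig(loc+'/rnaquarium_citation_embeddings_mde.png', dpi=300, bbox_inches='tight')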