import marimo

# Marimo version string recorded in this notebook file.
__generated_with = "0.21.1"
# The notebook application; each @app.cell function below registers one cell.
app = marimo.App(width="medium")


@app.cell
def _():
    # Empty cell: runs no code and defines nothing.
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""
    ## MS2 Phage Lysis Protein Mutation Analysis

    HTGAA wants us to analyze/find mutations of the lysis protein. I want to see how many mutations of the lysis protein there are that don't disrupt the capsid and replicase proteins whose genes it overlaps.
    """)
    return


@app.cell
def _():
    from Bio import SeqIO
    from Bio import Entrez

    # Contact address sent along with the Entrez request.
    Entrez.email = "jay.handfield@gmail.com"
    # Download the GenBank flat file for accession NC_001417 as plain text.
    stream = Entrez.efetch(db="nucleotide", id="NC_001417", rettype="gb", retmode="text")
    try:
        record = SeqIO.read(stream, "gb")
    finally:
        # Release the network handle even when parsing fails.
        stream.close()
    # Last expression is the cell's displayed output.
    (record.name, record.seq)
    return (record,)


@app.cell
def _(record):
    # Display the length of the downloaded sequence.
    len(record.seq)
    return


@app.cell
def _(record):
    # Show (feature index, gene name, start, end) for every feature of type "gene".
    features = record.features
    [
        (idx, feat.qualifiers["gene"][0], feat.location.start, feat.location.end)
        for idx, feat in enumerate(features)
        if feat.type == "gene"
    ]
    return


@app.cell
def _(record):
    # Pick three gene features by their position in the record's feature list.
    # NOTE(review): indices 3, 5 and 7 are specific to this record's feature
    # ordering -- confirm against the gene listing if the accession changes.
    capsid_gene = record.features[3]
    lysis_gene = record.features[5]
    replicate_gene = record.features[7]
    (capsid_gene, lysis_gene, replicate_gene)
    return capsid_gene, lysis_gene, replicate_gene


@app.cell
def _(record):
    # Display the Python type of the record's sequence object.
    type(record.seq)
    return


@app.cell
def _(capsid_gene, lysis_gene, record, replicate_gene):
    # Cut each gene's nucleotides out of the full sequence using the
    # feature's start/end coordinates.
    lysis_sequence, capsid_sequence, replicase_sequence = (
        record.seq[gene.location.start:gene.location.end]
        for gene in (lysis_gene, capsid_gene, replicate_gene)
    )

    return capsid_sequence, lysis_sequence, replicase_sequence


@app.cell
def _(lysis_sequence):
    # Display the protein translation of the lysis gene; cds=True asks
    # Biopython to treat the input as a complete coding sequence.
    lysis_sequence.translate(cds=True)
    return


@app.cell
def _():
    from Bio.Data import CodonTable
    # Standard DNA codon table, shared with later cells.
    standard_table = CodonTable.standard_dna_table
    # Print the codon -> amino-acid and amino-acid -> codon mappings for inspection.
    print(standard_table.forward_table, standard_table.back_table)
    return (standard_table,)


@app.cell
def _(lysis_sequence, standard_table):
    # Split the lysis gene into consecutive triplets starting at its first nucleotide.
    lysis_codons = [lysis_sequence[i:i+3] for i in range(0, len(lysis_sequence), 3)]
    # Translate every codon except stops. The stop set comes from the codon
    # table itself (TAA, TAG and TGA) instead of a single hard-coded codon, and
    # codons are looked up as plain strings so the dict access does not depend
    # on how the sequence slice type hashes.
    lysis_aminos = [
        standard_table.forward_table[str(c)]
        for c in lysis_codons
        if str(c) not in standard_table.stop_codons
    ]
    return lysis_aminos, lysis_codons


@app.cell
def _(lysis_codons, standard_table):
    # Show (index, codon) for every lysis codon missing from the forward table.
    [
        (position, triplet)
        for position, triplet in enumerate(lysis_codons)
        if triplet not in standard_table.forward_table
    ]
    return


@app.cell
def _(lysis_codons):
    # Display the list of lysis codons.
    lysis_codons
    return


@app.cell
def _(lysis_aminos):
    # Display the translated lysis amino acids.
    lysis_aminos
    return


@app.cell
def _():
    # Provide the `mo` module used by the markdown cells.
    import marimo as mo

    return (mo,)


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""
    For each amino acid and each nucleotide position that encodes that amino acid, there is a set of changes that preserve the amino acid, a set that change it to a different amino acid, and a set that break it, i.e. turn the codon into a stop codon or some other invalid codon.
    """)
    return


@app.cell
def _(standard_table):
    from collections import defaultdict
    from collections import namedtuple

    ## I need to handle the stop codon in the same way as other amino acids
    ## since the shifted frame spans stop codon for other proteins
    ## To do this I will wrap/extend the basic table with a stop codon mapping
    ## where I represent the stop codon as '*'

    codon_mapping = {**standard_table.forward_table}
    for codon in standard_table.stop_codons:
        codon_mapping[codon] = '*'

    # Both tables are keyed by (amino, codon position):
    #   samesense_mutations -> nucleotides found at that position in any codon
    #                          for that amino acid
    #   missense_mutations  -> nucleotides that, substituted at that position in
    #                          some codon for that amino acid, give a different
    #                          amino acid (or a stop, '*')
    samesense_mutations = defaultdict(set)
    missense_mutations = defaultdict(set)
    for (codon, amino) in codon_mapping.items():
        for (pos, nt) in enumerate(codon):
            samesense_mutations[(amino, pos)].add(nt)
            for alt_nt in ['A', 'T', 'G', 'C']:
                mutated_codon = codon[:pos] + alt_nt + codon[pos+1:]
                if (mutated_codon in codon_mapping) and (codon_mapping[mutated_codon] != amino):
                    missense_mutations[(amino, pos)].add(alt_nt)
    # Spot-check glycine at codon position 2. The key must be the 2-tuple
    # (amino, pos); any other key shape silently creates an empty entry in the
    # defaultdict.
    samesense_mutations[('G', 2)], missense_mutations[('G', 2)]
    # `nt` is kept in the return tuple because another cell lists it as an input.
    return codon_mapping, missense_mutations, nt, samesense_mutations


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""
    Now I have the samesense and missense mutations, so I need to iterate over the relevant absolute positions in the DNA and check for the intersection of samesense and missense mutations at a given location.  Those are the "safe" mutations for that position.

    Just realized I don't really handle the stop codon, which is smack in the middle there and has to translate into a stop codon.  I guess I can add that into the maps by hand if I need to.  Let's try to get the code going for even the first few positions.
    """)
    return


@app.cell
def _(codon_mapping, missense_mutations, record, samesense_mutations):
    # Going to just do this for whole sections and then snip out the sections I care about.
    def find_allowed_missense(dna_seq):
        """Trace, per nucleotide, substitutions that are samesense in frame 0
        and missense in the frame shifted by one nucleotide.

        Frame 0 codons start at indices 0, 3, 6, ... of ``dna_seq``; the shifted
        frame's codons start at indices 1, 4, 7, ....  The lookup tables are
        aggregated per (amino, codon position), so the result is a coarse
        candidate set rather than a codon-exact one.

        Args:
            dna_seq: nucleotide sequence (a string or a sliceable sequence object).

        Returns:
            A list with one tuple per index ``i`` that lies in a complete
            frame-0 codon, so list position ``k`` describes ``dna_seq[k]``:
            ``(i, nt, nt_pos, codon, amino, samesense, shift_pos, shift_codon,
            shift_amino, missense, allowed)``.  ``shift_codon`` and
            ``shift_amino`` are ``None`` where no complete shifted codon covers
            ``i`` (the sequence edges).

        Raises:
            KeyError: if a codon is not present in ``codon_mapping``.
        """
        no_mutations = frozenset()
        seq_len = len(dna_seq)
        trace = []
        for (i, nt) in enumerate(dna_seq):
            nt_pos = i % 3
            codon_start = i - nt_pos
            if codon_start + 3 > seq_len:
                # Remaining nucleotides do not form a complete frame-0 codon.
                break
            codon = dna_seq[codon_start:codon_start + 3]
            amino = codon_mapping[str(codon)]
            # .get() keeps lookups from inserting empty entries into the shared tables.
            samesense = samesense_mutations.get((amino, nt_pos), no_mutations)

            # Position of this nucleotide inside the shifted-frame codon that
            # actually contains it, and where that codon starts.
            shift_pos = (nt_pos - 1) % 3
            shift_codon_start = i - shift_pos
            if shift_codon_start < 0 or shift_codon_start + 3 > seq_len:
                shift_codon = None
                shift_amino = None
                missense = no_mutations
            else:
                shift_codon = dna_seq[shift_codon_start:shift_codon_start + 3]
                shift_amino = codon_mapping[str(shift_codon)]
                # The missense table is keyed by (amino, codon position).
                missense = missense_mutations.get((shift_amino, shift_pos), no_mutations)
            allowed_misense = missense.intersection(samesense)
            trace.append((i, nt, nt_pos, codon, amino, samesense, shift_pos, shift_codon, shift_amino, missense, allowed_misense))
        return trace

    # Smoke test on the first 13 nucleotides; displayed as the cell output.
    find_allowed_missense(record.seq[0:13])
    return (find_allowed_missense,)


@app.cell
def _(find_allowed_missense, lysis_gene, record):
    # Trace the whole genome, then keep only the entries for the lysis gene.
    # Trace position k describes record.seq[k] and feature locations are
    # 0-based with an exclusive end, so the gene is exactly [start:end] --
    # the same bounds used to slice the gene's sequence elsewhere in the notebook.
    allowed_mutations_trace = find_allowed_missense(record.seq)[lysis_gene.location.start:lysis_gene.location.end]
    return (allowed_mutations_trace,)


@app.cell
def _(allowed_mutations_trace):
    # Keep only trace entries whose final element (the allowed-mutation set) is non-empty.
    nontrivial_mutatations = [
        entry for entry in allowed_mutations_trace if len(entry[-1]) > 0
    ]
    return (nontrivial_mutatations,)


@app.cell
def _(lysis_sequence, nontrivial_mutatations):
    # Compare the number of positions with an allowed mutation to the gene length.
    len(nontrivial_mutatations), len(lysis_sequence)
    return


@app.cell
def _(allowed_mutations_trace):
    # Peek at the first five trace entries.
    # TODO (from the author): figure out total possible number of nt mutations
    allowed_mutations_trace[0:5]
    return


@app.cell
def _(lysis_aminos, lysis_sequence):
    # Display the lysis nucleotide sequence next to its amino acids.
    lysis_sequence, lysis_aminos
    return


@app.cell
def _(capsid_gene, record, replicate_gene):
    ## Want to focus on the sequence from the capsid protein to the replicase, with lysis in between and frameshifted.
    ## In particular we want the capsid gene start to be index zero for frameshift/codon purposes.
    ## NOTE: indices into focus_sequence are therefore relative to capsid_gene.location.start,
    ## not absolute positions in record.seq.
    focus_sequence = record.seq[capsid_gene.location.start: replicate_gene.location.end]
    return (focus_sequence,)


@app.cell
def _(codon_mapping, pl):
    # Second attempt: one function that takes a whole sequence and an absolute
    # position and returns every point mutation at that position as a table.
    # Mutations are computed on the fly instead of being looked up in a
    # pre-computed table, which makes individual positions easy to inspect.
    def mutations_for(codon, pos):
        """List the three single-nucleotide substitutions at ``pos`` of ``codon``.

        Args:
            codon: three-nucleotide codon (a string or a sequence slice).
            pos: 0-based position within the codon to mutate.

        Returns:
            A list of rows ``[codon, pos, nt, alt_nt, amino, mutated_amino]``,
            one per alternative nucleotide, always in A, C, G, T order so the
            output is the same on every run.

        Raises:
            KeyError: if the codon or a mutated codon is not in ``codon_mapping``.
        """
        # Work on a plain string so dict lookups and concatenation do not
        # depend on the behaviour of the sequence slice type.
        codon = str(codon)
        amino = codon_mapping[codon]
        nt = codon[pos]
        mutations = []
        for alt_nt in "ACGT":
            if alt_nt == nt:
                continue
            mutated_codon = codon[:pos] + alt_nt + codon[pos+1:]
            mutated_amino = codon_mapping[mutated_codon]
            mutations.append([codon, pos, nt, alt_nt, amino, mutated_amino])
        return mutations

    def point_mutations_for_loc_and_shift(whole_sequence, nt_index, frameshift=0):
        """Tabulate every point mutation at ``nt_index`` in the given reading frame.

        The nucleotide at ``nt_index`` is treated as codon position
        ``(nt_index + frameshift) % 3``; the codon containing it is read from
        ``whole_sequence`` and each alternative nucleotide is substituted in.

        Args:
            whole_sequence: nucleotide sequence to index into.
            nt_index: 0-based index of the nucleotide to mutate.
            frameshift: offset added to ``nt_index`` before taking it modulo 3.

        Returns:
            A polars DataFrame with one row per alternative nucleotide and the
            columns sequence_index, frameshift, codon, nt_position, nt,
            mutated_nt, amino, mutated_amino.

        Raises:
            ValueError: if no complete codon in this frame contains ``nt_index``.
            KeyError: if the codon is not in ``codon_mapping``.
        """
        # Feature positions are int subclasses; store a plain int in the table.
        nt_index = int(nt_index)
        nt_codon_pos = (nt_index + frameshift) % 3
        codon_start = nt_index - nt_codon_pos
        # A negative start would silently wrap around in the slice below, and
        # a start near the end would yield a truncated codon.
        if nt_index < 0 or codon_start < 0 or codon_start + 3 > len(whole_sequence):
            raise ValueError(
                f"no complete codon covers index {nt_index} with frameshift {frameshift}"
            )
        codon = str(whole_sequence[codon_start:codon_start+3])
        results = [[nt_index, frameshift] + m for m in mutations_for(codon, nt_codon_pos)]
        return pl.DataFrame(results, schema=["sequence_index", "frameshift", "codon", "nt_position", "nt", "mutated_nt", "amino", "mutated_amino"], orient="row")

    return (point_mutations_for_loc_and_shift,)


@app.cell
def _(focus_sequence, point_mutations_for_loc_and_shift):
    # Smoke test: mutations of the first nucleotide of focus_sequence in frame 0.
    point_mutations_for_loc_and_shift(focus_sequence, 0, 0 )
    return


@app.cell
def _(capsid_gene, lysis_gene, pl, point_mutations_for_loc_and_shift, record):
    # Scan every nucleotide of the lysis gene in two reading frames.
    #
    # The scan runs on record.seq with absolute positions so that the
    # sequence_index column is directly comparable with lysis_gene.location
    # (a sequence that starts at the capsid gene must not be indexed with
    # absolute positions).
    #
    # point_mutations_for_loc_and_shift treats index i as codon position
    # (i + frameshift) % 3, so the shift that puts a gene's first nucleotide at
    # codon position 0 is (-start) % 3.
    # NOTE(review): the "main" frame is anchored on the capsid gene start;
    # confirm the replicase gene shares that frame before trusting rows that
    # fall past the end of the capsid gene.
    _lysis_positions = range(int(lysis_gene.location.start), int(lysis_gene.location.end))
    _main_shift = (-int(capsid_gene.location.start)) % 3
    _lysis_shift = (-int(lysis_gene.location.start)) % 3
    main_protein_mutations = pl.concat([point_mutations_for_loc_and_shift(record.seq, i, _main_shift) for i in _lysis_positions])
    lysis_protein_mutations = pl.concat([point_mutations_for_loc_and_shift(record.seq, i, _lysis_shift) for i in _lysis_positions])
    return lysis_protein_mutations, main_protein_mutations


@app.cell
def _(main_protein_mutations, pl):
    # Samesense rows: the mutation leaves the main-frame amino acid unchanged.
    _amino_unchanged = pl.col("amino") == pl.col("mutated_amino")
    main_samesense_mutations = main_protein_mutations.filter(_amino_unchanged)
    main_samesense_mutations
    return (main_samesense_mutations,)


@app.cell
def _(lysis_protein_mutations, pl):
    # Missense rows: the mutation changes the lysis-frame amino acid.
    _amino_changed = pl.col("amino") != pl.col("mutated_amino")
    lysis_missense_mutations = lysis_protein_mutations.filter(_amino_changed)
    lysis_missense_mutations
    return (lysis_missense_mutations,)


@app.cell
def _(lysis_missense_mutations, main_samesense_mutations, pl):
    # A mutation is "allowed" when the same substitution at the same position
    # appears in both tables: missense for lysis, samesense for the main frame.
    _join_keys = [pl.col(name) for name in ("sequence_index", "nt", "mutated_nt")]
    allowed_mutations = lysis_missense_mutations.join(
        main_samesense_mutations,
        on=_join_keys,
    )
    allowed_mutations
    return (allowed_mutations,)


@app.cell
def _(allowed_mutations, lysis_gene, pl):
    # Add a column giving each position as an offset from the lysis gene start.
    _lysis_index = (pl.col("sequence_index") - lysis_gene.location.start).alias("lysis_index")
    allowed_mutations_extended = allowed_mutations.with_columns([_lysis_index])
    allowed_mutations_extended
    return (allowed_mutations_extended,)


@app.cell
def _(allowed_mutations_extended, pl):
    # Distinct (lysis offset, resulting amino acid) pairs among the allowed mutations.
    allowed_mutations_extended.select([pl.col("lysis_index"), pl.col("mutated_amino")]).unique()
    return


@app.cell
def _(allowed_mutations_extended, pl):
    # Distinct resulting amino acids among the allowed mutations.
    allowed_mutations_extended.select([pl.col("mutated_amino")]).unique()
    return


@app.cell
def _(capsid_gene, lysis_gene, record):
    # Show the nucleotides at record indices 1725-1729 next to the lysis gene
    # bounds and the capsid gene end.
    _window = range(1725, 1730)
    (
        list(zip(_window, record.seq[1725:1730])),
        (lysis_gene.location.start, lysis_gene.location.end),
        capsid_gene.location.end,
    )
    return


@app.cell
def _(lysis_gene, point_mutations_for_loc_and_shift, record):
    # Point mutations at the first lysis nucleotide, frameshift 0, in absolute record coordinates.
    point_mutations_for_loc_and_shift(record.seq, lysis_gene.location.start, 0)
    return


@app.cell
def _(lysis_gene, point_mutations_for_loc_and_shift, record):
    # Same position as the previous check, but with frameshift 1.
    point_mutations_for_loc_and_shift(record.seq, lysis_gene.location.start, 1)
    return


@app.cell
def _(point_mutations_for_loc_and_shift, record):
    # Point mutations at record index 1334 with frameshift 0.
    # NOTE(review): 1334 is presumably the capsid gene start -- confirm against capsid_gene.location.start.
    point_mutations_for_loc_and_shift(record.seq, 1334, 0)
    return


@app.cell
def _(capsid_gene, capsid_sequence):
    # Display the capsid gene start next to its first six nucleotides.
    capsid_gene.location.start, capsid_sequence[0:6]
    return


@app.cell
def _(capsid_gene, lysis_gene, replicate_gene):
    # Display the start positions of the three genes (lysis, capsid, replicase).
    lysis_gene.location.start, capsid_gene.location.start, replicate_gene.location.start
    return


@app.cell
def _(capsid_sequence, lysis_sequence, replicase_sequence):
    # Display the first ten nucleotides of each gene.
    capsid_sequence[0:10], lysis_sequence[0:10], replicase_sequence[0:10]
    return


@app.cell
def _(capsid_gene, lysis_gene, replicate_gene):
    # Offsets of the lysis and replicase starts relative to the capsid start
    # (both differences are capsid minus the other gene, so the sign is negative
    # when the other gene starts later). Their value modulo 3 shows the relative frame.
    capsid_gene.location.start - lysis_gene.location.start,  capsid_gene.location.start-replicate_gene.location.start
    return


@app.cell
def _(point_mutations_for_loc_and_shift):
    def allowed_shifted_mutations(whole_sequence, nt_point, frameshift):
        """Find substitutions at ``nt_point`` that are samesense in frame 0 and
        missense in the frame given by ``frameshift``.

        Built on the DataFrame returned by ``point_mutations_for_loc_and_shift``
        (columns mutated_nt, amino, mutated_amino).

        Args:
            whole_sequence: nucleotide sequence to index into.
            nt_point: 0-based index of the nucleotide to mutate.
            frameshift: frameshift argument passed through for the second frame.

        Returns:
            ``(nt_point, nt, base_samesense, shifted_missense, allowed,
            missense_aminos)`` where the three middle sets hold alternative
            nucleotides and ``missense_aminos`` maps a label such as
            ``"1677 A to G"`` to the amino acid the shifted frame would encode.
        """
        # NOTE(review): which of the two frames is the lysis frame depends on
        # the gene start positions modulo 3 -- confirm for the sequence passed in.
        base_rows = point_mutations_for_loc_and_shift(whole_sequence, nt_point, 0).rows(named=True)
        shifted_rows = point_mutations_for_loc_and_shift(whole_sequence, nt_point, frameshift).rows(named=True)
        base_samesense = {
            row["mutated_nt"] for row in base_rows if row["amino"] == row["mutated_amino"]
        }
        shifted_missense = {
            row["mutated_nt"] for row in shifted_rows if row["amino"] != row["mutated_amino"]
        }
        allowed = base_samesense.intersection(shifted_missense)
        # Label with the nucleotide actually at nt_point, not a name leaked
        # from another cell.
        original_nt = whole_sequence[nt_point]
        missense_aminos = {
            f'{nt_point} {original_nt} to {row["mutated_nt"]}': row["mutated_amino"]
            for row in shifted_rows
            if row["mutated_nt"] in allowed
        }
        return (nt_point, original_nt, base_samesense, shifted_missense, allowed, missense_aminos)

    return (allowed_shifted_mutations,)


@app.cell
def _(allowed_shifted_mutations, lysis_gene, record):
    # Allowed mutations at the first lysis nucleotide, comparing frameshift 0 with frameshift 1.
    allowed_shifted_mutations(record.seq, lysis_gene.location.start, 1)
    return


@app.cell
def _(allowed_shifted_mutations, lysis_gene, record):
    # Same check for the first four nucleotides of the lysis gene.
    [allowed_shifted_mutations(record.seq, i, 1) for i in range(lysis_gene.location.start, lysis_gene.location.start+4)]
    return


@app.cell
def _(allowed_shifted_mutations, lysis_gene, record):
    # Run the check for every nucleotide of the lysis gene.
    mutations_scan = [allowed_shifted_mutations(record.seq, i, 1) for i in range(lysis_gene.location.start, lysis_gene.location.end)]
    return (mutations_scan,)


@app.cell
def _(mutations_scan):
    # Show scan entries whose last element has more than one item.
    # NOTE(review): positions with exactly one item are excluded by `> 1` -- confirm that is intended.
    [m for m in mutations_scan if len(m[-1]) > 1]
    return


@app.cell
def _(point_mutations_for_loc_and_shift, record):
    # Compare the point mutations at record index 1799 for frameshift 0 and 1.
    # Uses point_mutations_for_loc_and_shift, the name the notebook actually defines.
    (point_mutations_for_loc_and_shift(record.seq, 1799, 0 ), point_mutations_for_loc_and_shift(record.seq, 1799, 1))
    return


@app.cell
def _(point_mutations_for_loc_and_shift, record):
    # Point mutations at record index 1799 with frameshift 0.
    # Uses point_mutations_for_loc_and_shift, the name the notebook actually defines.
    point_mutations_for_loc_and_shift(record.seq, 1799, 0 )
    return


@app.cell
def _(point_mutations_for_loc_and_shift, record):
    # Point mutations at record index 1799 with frameshift 1.
    # Uses point_mutations_for_loc_and_shift, the name the notebook actually defines.
    point_mutations_for_loc_and_shift(record.seq, 1799, 1)
    return


@app.cell
def _(samesense_mutations):
    # Spot-check the samesense table entry for amino acid 'L' at codon position 2.
    samesense_mutations[('L', 2)]
    return


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""
    ## Take 3
    Finally realizing that I really just want to represent everything as a table. There isn't really a great row-oriented table in Python, so I am going to use polars and see how far I can push the rows and algorithms to be vertical.
    """)
    return


@app.cell
def _():
    # Provide `pl` (polars) to the table-based cells.
    import polars as pl
    # NOTE(review): numpy is imported but not used in this cell.
    import numpy as np

    # Small demo frame holding the integers 1..9, shown as the cell output.
    gene_positions = pl.DataFrame({'sequence_location': range(1,10)})
    gene_positions
    return (pl,)


@app.cell
def _(codon_mapping):
    # Spot-check the extended codon table for codon GAA.
    codon_mapping['GAA']
    return


# Run the notebook's cells when the file is executed as a script.
if __name__ == "__main__":
    app.run()
