Skip to content

API

ZipStrain provides a Python API for programmatic access, enabling integration into larger bioinformatics pipelines and workflows. The API also allows for downstream analyses and visualizations.


Database

zipstrain.database

This module provides classes and functions to manage profile and comparison databases for efficient data handling. The ProfileDatabase class manages profiles, while the GenomeComparisonDatabase class handles comparisons between profiles. See the documentation of each class for more details.

GeneComparisonConfig

Bases: BaseModel

Configuration for gene-level comparisons between profiles.

Attributes:

Name Type Description
scope str

The scope of the comparison in format "GENOME:GENE" (e.g., "all:gene1" compares gene1 across all genomes, "genome1:gene1" compares gene1 only in genome1 across samples).

gene_db_id str

An ID given to the gene fasta file used for profiling.

reference_genome_id str

An ID given to the reference fasta file used for profiling.

stb_file_loc str

Location of the scaffold-to-genome mapping file.

min_cov int

Minimum coverage threshold for considering a position.

min_gene_compare_len int

Minimum gene length required for comparison.

Source code in zipstrain/src/zipstrain/database.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
class GeneComparisonConfig(BaseModel):
    """
    Configuration for gene-level comparisons between profiles.

    Attributes:
        scope (str): The scope of the comparison in format "GENOME:GENE" (e.g., "all:gene1" compares gene1 across all genomes, "genome1:gene1" compares gene1 only in genome1 across samples).
        gene_db_id (str): An ID given to the gene fasta file used for profiling.
        reference_genome_id (str): An ID given to the reference fasta file used for profiling.
        stb_file_loc (str): Location of the scaffold-to-genome mapping file.
        min_cov (int): Minimum coverage threshold for considering a position.
        min_gene_compare_len (int): Minimum gene length required for comparison.
    """
    # Reject unknown fields so typos in serialized configs fail loudly.
    model_config = ConfigDict(extra="forbid")
    gene_db_id: str = Field(default="", description="An ID given to the gene fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with gene database IDs in the Profile Database.")
    reference_genome_id: str = Field(description="An ID given to the reference fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with reference IDs in the Profile Database.")
    scope: str = Field(description="Scope in format GENOME:GENE (e.g., 'all:gene1', 'genome1:gene1')")
    stb_file_loc: str = Field(description="Location of the scaffold-to-genome mapping file")
    min_cov: int = Field(default=5, description="Minimum coverage threshold")
    min_gene_compare_len: int = Field(default=100, description="Minimum gene length for comparison")

    @field_validator("scope")
    @classmethod
    def validate_scope(cls, v: str) -> str:
        """Validate that scope follows GENOME:GENE format.

        Raises:
            ValueError: If the scope is missing the ':' separator, has more
                than one separator, or has an empty genome/gene component.
        """
        if ":" not in v:
            raise ValueError("Scope must be in format 'GENOME:GENE' (e.g., 'all:gene1' or 'genome1:gene1')")
        parts = v.split(":")
        if len(parts) != 2:
            raise ValueError("Scope must have exactly one ':' separator")
        genome_part, gene_part = parts
        if not genome_part or not gene_part:
            raise ValueError("Both genome and gene parts must be non-empty")
        return v

    @staticmethod
    def _scope_parts_compatible(part_a: str, part_b: str) -> bool:
        """Two scope components are compatible if either is the 'all' wildcard or they are identical."""
        return part_a == "all" or part_b == "all" or part_a == part_b

    def is_compatible(self, other: GeneComparisonConfig) -> bool:
        """
        Check if this gene comparison configuration is compatible with another.
        Two configurations are compatible if they have the same parameters, except for scope.
        Scope can be different as long as they are not disjoint. Also, 'all' is compatible with any scope.

        Args:
            other (GeneComparisonConfig): The other gene comparison configuration to check compatibility with.
        """
        # All non-scope parameters must match exactly.
        for key, value in self.__dict__.items():
            if key != "scope" and value != getattr(other, key):
                return False
        self_genome_scope, self_gene_scope = self.scope.split(":")
        other_genome_scope, other_gene_scope = other.scope.split(":")
        # BUGFIX: compare the genome and gene components independently.
        # The previous check returned True whenever either *genome* scope was
        # "all" (silently accepting disjoint gene scopes) and never honoured
        # an "all" *gene* scope — contradicting get_maximal_scope_config,
        # which assumes non-"all" components of compatible configs are equal.
        return (self._scope_parts_compatible(self_genome_scope, other_genome_scope)
                and self._scope_parts_compatible(self_gene_scope, other_gene_scope))

    @classmethod
    def from_dict(cls, config_dict: dict) -> GeneComparisonConfig:
        """Create a GeneComparisonConfig from a dictionary.

        Deprecated fields are ignored for backward compatibility.
        """
        cleaned_config = _drop_deprecated_fields(config_dict, _DEPRECATED_GENE_COMPARE_CONFIG_FIELDS)
        return cls(**cleaned_config)

    @classmethod
    def from_json(cls, json_file_dir: str) -> GeneComparisonConfig:
        """Create a GeneComparisonConfig instance from a json file.

        Deprecated fields are ignored for backward compatibility.
        """
        with open(json_file_dir, 'r') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)

    def to_json(self, json_file_dir: str) -> None:
        """Writes the current object to a json file."""
        with open(json_file_dir, "w") as f:
            # model_dump() is pydantic v2's supported serialization API
            # (equivalent to the previous self.__dict__ for these flat fields).
            json.dump(self.model_dump(), f)

    def to_dict(self) -> dict:
        """Returns the dictionary representation of the current object."""
        return self.model_dump()

    def get_maximal_scope_config(self, other: GeneComparisonConfig) -> GeneComparisonConfig:
        """
        Get a new GeneComparisonConfig object with the maximal scope that is compatible with the two configurations.

        Args:
            other (GeneComparisonConfig): The other gene comparison configuration to get the maximal scope with.

        Returns:
            GeneComparisonConfig: The new gene comparison configuration with the maximal scope.

        Raises:
            ValueError: If the two configurations are not compatible.
        """
        if not self.is_compatible(other):
            raise ValueError("The two comparison configurations are not compatible.")

        self_genome_scope, self_gene_scope = self.scope.split(":")
        other_genome_scope, other_gene_scope = other.scope.split(":")

        # For each component prefer the narrower (non-"all") value;
        # compatibility guarantees two non-"all" values are equal.
        new_genome_scope = other_genome_scope if self_genome_scope == "all" else self_genome_scope
        new_gene_scope = other_gene_scope if self_gene_scope == "all" else self_gene_scope

        curr_config_dict = self.to_dict()
        curr_config_dict["scope"] = f"{new_genome_scope}:{new_gene_scope}"
        return GeneComparisonConfig(**curr_config_dict)
from_dict(config_dict) classmethod

Create a GeneComparisonConfig from a dictionary.

Deprecated fields are ignored for backward compatibility.

Source code in zipstrain/src/zipstrain/database.py
370
371
372
373
374
375
376
377
@classmethod
def from_dict(cls, config_dict: dict) -> GeneComparisonConfig:
    """Create a GeneComparisonConfig from a dictionary.

    Deprecated fields are ignored for backward compatibility.
    """
    cleaned_config = _drop_deprecated_fields(config_dict, _DEPRECATED_GENE_COMPARE_CONFIG_FIELDS)
    return cls(**cleaned_config)
from_json(json_file_dir) classmethod

Create a GeneComparisonConfig instance from a json file.

Deprecated fields are ignored for backward compatibility.

Source code in zipstrain/src/zipstrain/database.py
379
380
381
382
383
384
385
386
387
@classmethod
def from_json(cls,json_file_dir:str)->GeneComparisonConfig:
    """Create a GeneComparisonConfig instance from a json file.

    Deprecated fields are ignored for backward compatibility.
    """
    with open(json_file_dir, 'r') as f:
        config_dict = json.load(f)
    return cls.from_dict(config_dict)
get_maximal_scope_config(other)

Get a new GeneComparisonConfig object with the maximal scope that is compatible with the two configurations.

Parameters:

Name Type Description Default
other GeneComparisonConfig

The other gene comparison configuration to get the maximal scope with.

required

Returns:

Name Type Description
GeneComparisonConfig GeneComparisonConfig

The new gene comparison configuration with the maximal scope.

Source code in zipstrain/src/zipstrain/database.py
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
def get_maximal_scope_config(self, other: GeneComparisonConfig) -> GeneComparisonConfig:
    """
    Get a new GeneComparisonConfig object with the maximal scope that is compatible with the two configurations.

    Args:
        other (GeneComparisonConfig): The other gene comparison configuration to get the maximal scope with.

    Returns:
        GeneComparisonConfig: The new gene comparison configuration with the maximal scope.
    """
    if not self.is_compatible(other):
        raise ValueError("The two comparison configurations are not compatible.")

    self_genome_scope, self_gene_scope = self.scope.split(":")
    other_genome_scope, other_gene_scope = other.scope.split(":")

    if self_genome_scope == "all" and other_genome_scope == "all":
        new_genome_scope = "all"
    elif self_genome_scope == "all":
        new_genome_scope = other_genome_scope
    elif other_genome_scope == "all":
        new_genome_scope = self_genome_scope
    else:
        new_genome_scope = self_genome_scope  # They must be equal if compatible

    if self_gene_scope == "all" and other_gene_scope == "all":
        new_gene_scope = "all"
    elif self_gene_scope == "all":
        new_gene_scope = other_gene_scope
    elif other_gene_scope == "all":
        new_gene_scope = self_gene_scope
    else:
        new_gene_scope = self_gene_scope  # They must be equal if compatible

    curr_config_dict=self.to_dict()
    curr_config_dict["scope"]=f"{new_genome_scope}:{new_gene_scope}"
    return GeneComparisonConfig(**curr_config_dict)
is_compatible(other)

Check if this gene comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope. Scope can be different as long as they are not disjoint. Also, 'all' is compatible with any scope.

Parameters:

Name Type Description Default
other GeneComparisonConfig

The other gene comparison configuration to check compatibility with.

required
Source code in zipstrain/src/zipstrain/database.py
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
def is_compatible(self, other: GeneComparisonConfig) -> bool:
    """
    Check if this gene comparison configuration is compatible with another.
    Two configurations are compatible if they have the same parameters, except for scope.
    Scope can be different as long as they are not disjoint. Also, 'all' is compatible with any scope.

    Args:
        other (GeneComparisonConfig): The other gene comparison configuration to check compatibility with.
    """
    attrs=self.__dict__
    for key in attrs:
        if key!="scope":
            if attrs[key] != getattr(other, key):
                return False
    self_genome_scope, self_gene_scope = self.scope.split(":")
    other_genome_scope, other_gene_scope = other.scope.split(":")
    if self_genome_scope == "all" or other_genome_scope == "all":
        return True
    return self_genome_scope == other_genome_scope and self_gene_scope == other_gene_scope
to_dict()

Returns the dictionary representation of the current object

Source code in zipstrain/src/zipstrain/database.py
394
395
396
def to_dict(self)->dict:
    """Returns the dictionary representation of the current object"""
    return copy.copy(self.__dict__)
to_json(json_file_dir)

Writes the current object to a json file

Source code in zipstrain/src/zipstrain/database.py
389
390
391
392
def to_json(self,json_file_dir:str)->None:
    """Writes the the current object to a json file"""
    with open(json_file_dir,"w") as f:
        json.dump(self.__dict__,f)
validate_scope(v) classmethod

Validate that scope follows GENOME:GENE format.

Source code in zipstrain/src/zipstrain/database.py
336
337
338
339
340
341
342
343
344
345
346
347
348
@field_validator("scope")
@classmethod
def validate_scope(cls, v: str) -> str:
    """Validate that scope follows GENOME:GENE format."""
    if ":" not in v:
        raise ValueError("Scope must be in format 'GENOME:GENE' (e.g., 'all:gene1' or 'genome1:gene1')")
    parts = v.split(":")
    if len(parts) != 2:
        raise ValueError("Scope must have exactly one ':' separator")
    genome_part, gene_part = parts
    if not genome_part or not gene_part:
        raise ValueError("Both genome and gene parts must be non-empty")
    return v

GeneComparisonDatabase

GeneComparisonDatabase object holds a reference to a gene comparison parquet file. The methods in this class serve to provide functionality for working with the gene comparison data in an easy and efficient manner. The comparison parquet file is the result of running gene-level comparisons, and optionally concatenating multiple compare parquet files from single comparisons. This parquet file must contain the following columns:

  • genome
  • gene
  • total_positions
  • share_allele_pos
  • ani
  • sample_1
  • sample_2

A GeneComparisonDatabase object needs a GeneComparisonConfig object to specify the parameters used for the comparison.

Parameters:

Name Type Description Default
profile_db ProfileDatabase

The profile database used for the comparison.

required
config GeneComparisonConfig

The gene comparison configuration used for the comparison.

required
comp_db_loc str | None

The location of the comparison database parquet file. If None, an empty comparison database is created.

None
Source code in zipstrain/src/zipstrain/database.py
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
class GeneComparisonDatabase:
    """
    GeneComparisonDatabase object holds a reference to a gene comparison parquet file. The methods in this class serve to provide
    functionality for working with the gene comparison data in an easy and efficient manner.
    The comparison parquet file is the result of running gene-level comparisons, and optionally concatenating multiple compare parquet files from single comparisons.
    This parquet file must contain the following columns:

    - genome
    - gene
    - total_positions
    - share_allele_pos
    - ani
    - sample_1
    - sample_2

    A GeneComparisonDatabase object needs a GeneComparisonConfig object to specify the parameters used for the comparison.

    Args:
        profile_db (ProfileDatabase): The profile database used for the comparison.
        config (GeneComparisonConfig): The gene comparison configuration used for the comparison.
        comp_db_loc (str|None): The location of the comparison database parquet file. If
            None, an empty comparison database is created.
    """
    # Required columns of the comparison parquet file (see class docstring).
    COLUMN_NAMES = [
        "genome",
        "gene",
        "total_positions",
        "share_allele_pos",
        "ani",
        "sample_1",
        "sample_2"
    ]

    def __init__(self,
                 profile_db: ProfileDatabase,
                 config: GeneComparisonConfig,
                 comp_db_loc: str|None = None,
                 ):
        self.profile_db = profile_db
        self.config = config
        if comp_db_loc is not None:
            self.comp_db_loc = pathlib.Path(comp_db_loc)
            # Lazily scan the parquet file; nothing is loaded until .collect().
            self._comp_db = pl.scan_parquet(self.comp_db_loc)
        else:
            # No file on disk yet: start from an empty, correctly-typed frame.
            # (BUGFIX: comp_db_loc was previously assigned None twice.)
            self.comp_db_loc = None
            self._comp_db = pl.LazyFrame({
                "genome": [],
                "gene": [],
                "total_positions": [],
                "share_allele_pos": [],
                "ani": [],
                "sample_1": [],
                "sample_2": []
            }, schema={
                "genome": pl.Utf8,
                "gene": pl.Utf8,
                "total_positions": pl.Int64,
                "share_allele_pos": pl.Int64,
                "ani": pl.Float64,
                "sample_1": pl.Utf8,
                "sample_2": pl.Utf8
            })

    @property
    def comp_db(self):
        # Read-only access to the underlying LazyFrame.
        return self._comp_db

    def _validate_db(self) -> None:
        """Validate the gene comparison database structure and content.

        Raises:
            ValueError: If required columns are missing or the comparison
                database references profiles absent from the profile database.
        """
        self.profile_db._validate_db()

        if set(self._comp_db.collect_schema()) != set(self.COLUMN_NAMES):
            raise ValueError(f"Your comparison database must provide these extra columns: {set(self.COLUMN_NAMES) - set(self._comp_db.collect_schema())}")

        # Check if all profile names exist in the profile database
        profile_names_in_comp_db = set(self.get_all_profile_names())
        profile_names_in_profile_db = set(self.profile_db.db.select("profile_name").collect(engine="streaming").to_series().to_list())
        if not profile_names_in_comp_db.issubset(profile_names_in_profile_db):
            raise ValueError(f"The following profile names are in the comparison database but not in the profile database: {profile_names_in_comp_db - profile_names_in_profile_db}")

    def get_all_profile_names(self) -> set[str]:
        """
        Get all profile names that are in the comparison database.

        Returns:
            set[str]: Set of all profile names in the comparison database.
        """
        # A profile may appear on either side of a comparison.
        return set(self.comp_db.select(pl.col("sample_1")).collect(engine="streaming").to_series().to_list()).union(
            set(self.comp_db.select(pl.col("sample_2")).collect(engine="streaming").to_series().to_list())
        )

    def get_remaining_pairs(self) -> pl.LazyFrame:
        """
        Get pairs of profiles that are in the profile database but not in the comparison database.

        Returns:
            pl.LazyFrame: LazyFrame with columns profile_1 and profile_2 containing remaining pairs.
        """
        profiles = self.profile_db.db.select("profile_name")
        # All unordered pairs, canonicalized as profile_1 < profile_2.
        pairs = profiles.join(profiles, how="cross").rename({"profile_name": "profile_1", "profile_name_right": "profile_2"}).filter(pl.col("profile_1") < pl.col("profile_2"))
        # Canonicalize already-compared pairs the same way so the anti-join
        # matches regardless of which sample was sample_1 vs sample_2.
        samplepairs = self.comp_db.group_by("sample_1", "sample_2").agg().with_columns(
            pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"), 
            pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2")
        ).select(["profile_1", "profile_2"])

        remaining_pairs = pairs.join(samplepairs, on=["profile_1", "profile_2"], how="anti").sort(["profile_1", "profile_2"])
        return remaining_pairs

    def is_complete(self) -> bool:
        """
        Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.

        Returns:
            bool: True if all pairs have been compared, False otherwise.
        """
        return self.get_remaining_pairs().collect(engine="streaming").is_empty()

    def add_comp_database(self, comp_database: GeneComparisonDatabase) -> None:
        """Merge the provided gene comparison database into the current database.

        Args:
            comp_database (GeneComparisonDatabase): The gene comparison database to merge.

        Raises:
            ValueError: If the provided database is invalid or incompatible.
        """
        try:
            comp_database._validate_db()
        except Exception as e:
            raise ValueError(f"The comparison database provided is not valid: {e}")

        if not self.config.is_compatible(comp_database.config):
            raise ValueError("The comparison database provided is not compatible with the current comparison database.")

        # unique() drops rows duplicated between the two databases.
        self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
        self.config = self.config.get_maximal_scope_config(comp_database.config)

    def save_new_compare_database(self, output_path: str) -> None:
        """Save the database to a parquet file.

        Args:
            output_path (str): The path to save the parquet file to.

        Raises:
            ValueError: If the output path is the same as the current database location.
        """
        output_path = pathlib.Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # The new database must be written to a new location
        if self.comp_db_loc is not None and str(self.comp_db_loc.absolute()) == str(output_path.absolute()):
            raise ValueError("The output path must be different from the current database location.")

        self.comp_db.sink_parquet(output_path)

    def update_compare_database(self) -> None:
        """Overwrites the comparison database saved on the disk to the current comparison database object.

        Raises:
            Exception: If comp_db_loc is not set or if update fails.
        """
        if self.comp_db_loc is None:
            raise Exception("comp_db_loc attribute is not determined yet!")
        try:
            # Write to a temp file in the same directory, then atomically
            # replace the old database. BUGFIX: mkstemp creates the file
            # securely, unlike the deprecated, race-prone tempfile.mktemp.
            fd, tmp_name = tempfile.mkstemp(suffix=".parquet", prefix="tmp_gene_comp_db_", dir=str(self.comp_db_loc.parent))
            os.close(fd)  # only the unique path is needed; sink_parquet reopens it
            tmp_path = pathlib.Path(tmp_name)
            self.comp_db.sink_parquet(tmp_path)
            os.replace(tmp_path, self.comp_db_loc)
            self._comp_db = pl.scan_parquet(self.comp_db_loc)
        except Exception as e:
            raise Exception(f"Something went wrong when updating the comparison database: {e}")

    def dump_obj(self, output_path: str) -> None:
        """Dump the current object to a json file.

        Args:
            output_path (str): The path to save the json file to.
        """
        obj_dict = {
            "profile_db_loc": str(self.profile_db.db_loc.absolute()) if self.profile_db.db_loc is not None else None,
            "config": self.config.to_dict(),
            "comp_db_loc": str(self.comp_db_loc.absolute()) if self.comp_db_loc is not None else None
        }
        with open(output_path, "w") as f:
            json.dump(obj_dict, f, indent=4)

    @classmethod
    def load_obj(cls, json_path: str) -> GeneComparisonDatabase:
        """Load a GeneComparisonDatabase object from a json file.

        Args:
            json_path (str): The path to the json file.

        Returns:
            GeneComparisonDatabase: The loaded GeneComparisonDatabase object.
        """
        with open(json_path, "r") as f:
            obj_dict = json.load(f)

        return cls(
            profile_db=ProfileDatabase(db_loc=obj_dict["profile_db_loc"]),
            config=GeneComparisonConfig.from_dict(obj_dict["config"]),
            comp_db_loc=obj_dict["comp_db_loc"]
        )

    def to_complete_input_table(self) -> pl.LazyFrame:
        """This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. 
        The table contains the following columns:

        - sample_name_1
        - sample_name_2
        - profile_location_1
        - profile_location_2

        Returns:
            pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
        """
        lf = self.get_remaining_pairs().rename({"profile_1": "sample_name_1", "profile_2": "sample_name_2"})
        # Join twice against the profile table to attach each sample's
        # profile location; left joins keep pairs even if a location is missing.
        return (lf.join(self.profile_db.db.select(["profile_name", "profile_location"]), 
                       left_on="sample_name_1", right_on="profile_name", how="left")
                .rename({"profile_location": "profile_location_1"})
                .join(self.profile_db.db.select(["profile_name", "profile_location"]), 
                      left_on="sample_name_2", right_on="profile_name", how="left")
                .rename({"profile_location": "profile_location_2"})
               ).sort(["sample_name_1", "sample_name_2"])
add_comp_database(comp_database)

Merge the provided gene comparison database into the current database.

Parameters:

Name Type Description Default
comp_database GeneComparisonDatabase

The gene comparison database to merge.

required

Raises:

Type Description
ValueError

If the provided database is invalid or incompatible.

Source code in zipstrain/src/zipstrain/database.py
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
def add_comp_database(self, comp_database: GeneComparisonDatabase) -> None:
    """Merge the provided gene comparison database into the current database.

    Args:
        comp_database (GeneComparisonDatabase): The gene comparison database to merge.

    Raises:
        ValueError: If the provided database is invalid or incompatible.
    """
    try:
        comp_database._validate_db()
    except Exception as e:
        raise ValueError(f"The comparison database provided is not valid: {e}")

    if not self.config.is_compatible(comp_database.config):
        raise ValueError("The comparison database provided is not compatible with the current comparison database.")

    self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
    self.config = self.config.get_maximal_scope_config(comp_database.config)
dump_obj(output_path)

Dump the current object to a json file.

Parameters:

Name Type Description Default
output_path str

The path to save the json file to.

required
Source code in zipstrain/src/zipstrain/database.py
849
850
851
852
853
854
855
856
857
858
859
860
861
def dump_obj(self, output_path: str) -> None:
    """Dump the current object to a json file.

    Args:
        output_path (str): The path to save the json file to.
    """
    obj_dict = {
        "profile_db_loc": str(self.profile_db.db_loc.absolute()) if self.profile_db.db_loc is not None else None,
        "config": self.config.to_dict(),
        "comp_db_loc": str(self.comp_db_loc.absolute()) if self.comp_db_loc is not None else None
    }
    with open(output_path, "w") as f:
        json.dump(obj_dict, f, indent=4)
get_all_profile_names()

Get all profile names that are in the comparison database.

Returns:

Type Description
set[str]

set[str]: Set of all profile names in the comparison database.

Source code in zipstrain/src/zipstrain/database.py
758
759
760
761
762
763
764
765
766
767
def get_all_profile_names(self) -> set[str]:
    """
    Get all profile names that are in the comparison database.

    Returns:
        set[str]: Set of all profile names in the comparison database.
    """
    return set(self.comp_db.select(pl.col("sample_1")).collect(engine="streaming").to_series().to_list()).union(
        set(self.comp_db.select(pl.col("sample_2")).collect(engine="streaming").to_series().to_list())
    )
get_remaining_pairs()

Get pairs of profiles that are in the profile database but not in the comparison database.

Returns:

Type Description
LazyFrame

pl.LazyFrame: LazyFrame with columns profile_1 and profile_2 containing remaining pairs.

Source code in zipstrain/src/zipstrain/database.py
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
def get_remaining_pairs(self) -> pl.LazyFrame:
    """
    Get pairs of profiles that are in the profile database but not in the comparison database.

    Returns:
        pl.LazyFrame: LazyFrame with columns profile_1 and profile_2 containing remaining pairs.
    """
    profiles = self.profile_db.db.select("profile_name")
    pairs = profiles.join(profiles, how="cross").rename({"profile_name": "profile_1", "profile_name_right": "profile_2"}).filter(pl.col("profile_1") < pl.col("profile_2"))
    samplepairs = self.comp_db.group_by("sample_1", "sample_2").agg().with_columns(
        pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"), 
        pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2")
    ).select(["profile_1", "profile_2"])

    remaining_pairs = pairs.join(samplepairs, on=["profile_1", "profile_2"], how="anti").sort(["profile_1", "profile_2"])
    return remaining_pairs
is_complete()

Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.

Returns:

Name Type Description
bool bool

True if all pairs have been compared, False otherwise.

Source code in zipstrain/src/zipstrain/database.py
786
787
788
789
790
791
792
793
def is_complete(self) -> bool:
    """
    Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.

    Returns:
        bool: True if all pairs have been compared, False otherwise.
    """
    return self.get_remaining_pairs().collect(engine="streaming").is_empty()
load_obj(json_path) classmethod

Load a GeneComparisonDatabase object from a json file.

Parameters:

Name Type Description Default
json_path str

The path to the json file.

required

Returns:

Name Type Description
GeneComparisonDatabase GeneComparisonDatabase

The loaded GeneComparisonDatabase object.

Source code in zipstrain/src/zipstrain/database.py
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
@classmethod
def load_obj(cls, json_path: str) -> GeneComparisonDatabase:
    """Load a GeneComparisonDatabase object from a json file.

    Args:
        json_path (str): The path to the json file.

    Returns:
        GeneComparisonDatabase: The loaded GeneComparisonDatabase object.
    """
    with open(json_path, "r") as f:
        obj_dict = json.load(f)

    return cls(
        profile_db=ProfileDatabase(db_loc=obj_dict["profile_db_loc"]),
        config=GeneComparisonConfig.from_dict(obj_dict["config"]),
        comp_db_loc=obj_dict["comp_db_loc"]
    )
save_new_compare_database(output_path)

Save the database to a parquet file.

Parameters:

Name Type Description Default
output_path str

The path to save the parquet file to.

required

Raises:

Type Description
ValueError

If the output path is the same as the current database location.

Source code in zipstrain/src/zipstrain/database.py
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
def save_new_compare_database(self, output_path: str) -> None:
    """Save the database to a parquet file.

    Args:
        output_path (str): The path to save the parquet file to.

    Raises:
        ValueError: If the output path is the same as the current database location.
    """
    output_path = pathlib.Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The new database must be written to a new location
    if self.comp_db_loc is not None and str(self.comp_db_loc.absolute()) == str(output_path.absolute()):
        raise ValueError("The output path must be different from the current database location.")

    self.comp_db.sink_parquet(output_path)
to_complete_input_table()

This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

  • sample_name_1
  • sample_name_2
  • profile_location_1
  • profile_location_2

Returns:

Type Description
LazyFrame

pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.

Source code in zipstrain/src/zipstrain/database.py
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
def to_complete_input_table(self) -> pl.LazyFrame:
    """This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. 
    The table contains the following columns:

    - sample_name_1
    - sample_name_2
    - profile_location_1
    - profile_location_2

    Returns:
        pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
    """
    lf = self.get_remaining_pairs().rename({"profile_1": "sample_name_1", "profile_2": "sample_name_2"})
    return (lf.join(self.profile_db.db.select(["profile_name", "profile_location"]), 
                   left_on="sample_name_1", right_on="profile_name", how="left")
            .rename({"profile_location": "profile_location_1"})
            .join(self.profile_db.db.select(["profile_name", "profile_location"]), 
                  left_on="sample_name_2", right_on="profile_name", how="left")
            .rename({"profile_location": "profile_location_2"})
           ).sort(["sample_name_1", "sample_name_2"])
update_compare_database()

Overwrites the comparison database saved on disk with the current in-memory comparison database.

Raises:

Type Description
Exception

If comp_db_loc is not set or if update fails.

Source code in zipstrain/src/zipstrain/database.py
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
def update_compare_database(self) -> None:
    """Overwrites the comparison database saved on the disk to the current comparison database object.

    Raises:
        Exception: If comp_db_loc is not set or if update fails.
    """
    if self.comp_db_loc is None:
        raise Exception("comp_db_loc attribute is not determined yet!")
    try:
        tmp_path = pathlib.Path(tempfile.mktemp(suffix=".parquet", prefix="tmp_gene_comp_db_", dir=str(self.comp_db_loc.parent)))
        self.comp_db.sink_parquet(tmp_path)
        os.replace(tmp_path, self.comp_db_loc)
        self._comp_db = pl.scan_parquet(self.comp_db_loc)
    except Exception as e:
        raise Exception(f"Something went wrong when updating the comparison database: {e}")

GenomeComparisonConfig

Bases: BaseModel

This class defines objects that hold all the options needed to describe the parameters used to compare profiles:

Attributes:

Name Type Description
gene_db_id str

The ID of the gene fasta database to use for the comparison. The file name is a good choice for this ID.

reference_id str

The ID of the reference fasta database to use for the comparison. The file name is a good choice for this ID.

scope str

The scope of the comparison: 'all' if all covered positions are desired; otherwise, a comma-separated list of genome names.

min_cov int

Minimum coverage that a base on the reference fasta must have in order to be compared.

min_gene_compare_len int

Minimum length of a gene that needs to be covered at min_cov to be considered for gene similarity calculations

stb_file_loc str

The location of the scaffold to bin file.

Source code in zipstrain/src/zipstrain/database.py
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
class GenomeComparisonConfig(BaseModel):
    """
    This class defines objects that hold all the options needed to describe
    the parameters used to compare profiles:

    Attributes:
        gene_db_id (str): The ID of the gene fasta database to use for the comparison. The file name is a good choice for this ID.
        reference_id (str): The ID of the reference fasta database to use for the comparison. The file name is a good choice for this ID.
        scope (str): The scope of the comparison: 'all' if all covered positions are desired; otherwise, a comma-separated list of genome names.
        min_cov (int): Minimum coverage that a base on the reference fasta must have in order to be compared.
        min_gene_compare_len (int): Minimum length of a gene that must be covered at min_cov to be considered for gene similarity calculations.
        stb_file_loc (str): The location of the scaffold-to-bin file.
    """
    model_config = ConfigDict(extra="forbid")
    gene_db_id:str= Field(default="",description="An ID given to the gene fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with gene database IDs in the Profile Database.")
    reference_id:str= Field(description="An ID given to the reference fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with reference IDs in the Profile Database.")
    scope: str =Field(description="The scope of the comparison: 'all' if all covered positions are desired; otherwise, a comma-separated list of genome names.")
    min_cov: int =Field(description="Minimum coverage a base on the reference fasta that must have in order to be compared.")
    min_gene_compare_len: int=Field(description="Minimum length of a gene that needs to be covered at min_cov to be considered for gene similarity calculations")
    stb_file_loc:str=Field(description="The location of the scaffold to bin file.")

    def is_compatible(self, other: GenomeComparisonConfig) -> bool:
        """
        Check if this comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope.
        Scope can be different as long as they are not disjoint. Also, all is compatible with any scope.
        Args:
            other (GenomeComparisonConfig): The other comparison configuration to check compatibility with.
        Returns:
            bool: True if the configurations are compatible, False otherwise.
        """
        attrs=self.__dict__
        for key in attrs:
            if key!="scope":
                if attrs[key] != getattr(other, key):
                    return False
        if other.scope != "all" and self.scope != "all":
            if (set(other.scope.split(",")).intersection(set(self.scope.split(","))) == set()):
                return False
        return True

    @classmethod
    def from_dict(cls, config_dict: dict) -> GenomeComparisonConfig:
        """Create a GenomeComparisonConfig from a dictionary.

        Deprecated fields are ignored for backward compatibility.
        """
        cleaned_config = _drop_deprecated_fields(config_dict, _DEPRECATED_GENOME_COMPARE_CONFIG_FIELDS)
        return cls(**cleaned_config)

    @classmethod
    def from_json(cls,json_file_dir:str)->GenomeComparisonConfig:
        """Create a GenomeComparisonConfig instance from a json file.

        Deprecated fields are ignored for backward compatibility.
        """
        with open(json_file_dir, 'r') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)

    def to_json(self,json_file_dir:str)->None:
        """Writes the the current object to a json file"""
        with open(json_file_dir,"w") as f:
            json.dump(self.__dict__,f)

    def to_dict(self)->dict:
        """Returns the dictionary representation of the current object"""
        return copy.copy(self.__dict__)


    def get_maximal_scope_config(self, other: GenomeComparisonConfig) -> GenomeComparisonConfig:
        """
        Get a new GenomeComparisonConfig object with the maximal scope that is compatible with the two configurations.
        Args:
            other (GenomeComparisonConfig): The other comparison configuration to get the maximal scope with.
        Returns:
            GenomeComparisonConfig: The new comparison configuration with the maximal scope.
        """
        if not self.is_compatible(other):
            raise ValueError("The two comparison configurations are not compatible.")

        new_scope=None
        if other.scope == "all" and self.scope == "all":
            new_scope="all"

        elif other.scope == "all":
            new_scope=self.scope.split(",")

        elif self.scope == "all":
            new_scope=other.scope.split(",")

        else:
            new_scope=list(set(self.scope.split(",")).intersection(set(other.scope.split(","))))
        curr_config_dict=self.to_dict()
        curr_config_dict["scope"]=new_scope if new_scope=="all" else ",".join(sorted(new_scope))
        return GenomeComparisonConfig(**curr_config_dict)
from_dict(config_dict) classmethod

Create a GenomeComparisonConfig from a dictionary.

Deprecated fields are ignored for backward compatibility.

Source code in zipstrain/src/zipstrain/database.py
260
261
262
263
264
265
266
267
@classmethod
def from_dict(cls, config_dict: dict) -> GenomeComparisonConfig:
    """Create a GenomeComparisonConfig from a dictionary.

    Deprecated fields are ignored for backward compatibility.
    """
    cleaned_config = _drop_deprecated_fields(config_dict, _DEPRECATED_GENOME_COMPARE_CONFIG_FIELDS)
    return cls(**cleaned_config)
from_json(json_file_dir) classmethod

Create a GenomeComparisonConfig instance from a json file.

Deprecated fields are ignored for backward compatibility.

Source code in zipstrain/src/zipstrain/database.py
269
270
271
272
273
274
275
276
277
@classmethod
def from_json(cls,json_file_dir:str)->GenomeComparisonConfig:
    """Create a GenomeComparisonConfig instance from a json file.

    Deprecated fields are ignored for backward compatibility.
    """
    with open(json_file_dir, 'r') as f:
        config_dict = json.load(f)
    return cls.from_dict(config_dict)
get_maximal_scope_config(other)

Get a new GenomeComparisonConfig object with the maximal scope that is compatible with the two configurations. Args: other (GenomeComparisonConfig): The other comparison configuration to get the maximal scope with. Returns: GenomeComparisonConfig: The new comparison configuration with the maximal scope.

Source code in zipstrain/src/zipstrain/database.py
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
def get_maximal_scope_config(self, other: GenomeComparisonConfig) -> GenomeComparisonConfig:
    """
    Get a new GenomeComparisonConfig object with the maximal scope that is compatible with the two configurations.
    Args:
        other (GenomeComparisonConfig): The other comparison configuration to get the maximal scope with.
    Returns:
        GenomeComparisonConfig: The new comparison configuration with the maximal scope.
    """
    if not self.is_compatible(other):
        raise ValueError("The two comparison configurations are not compatible.")

    new_scope=None
    if other.scope == "all" and self.scope == "all":
        new_scope="all"

    elif other.scope == "all":
        new_scope=self.scope.split(",")

    elif self.scope == "all":
        new_scope=other.scope.split(",")

    else:
        new_scope=list(set(self.scope.split(",")).intersection(set(other.scope.split(","))))
    curr_config_dict=self.to_dict()
    curr_config_dict["scope"]=new_scope if new_scope=="all" else ",".join(sorted(new_scope))
    return GenomeComparisonConfig(**curr_config_dict)
is_compatible(other)

Check if this comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope. Scope can be different as long as they are not disjoint. Also, all is compatible with any scope. Args: other (GenomeComparisonConfig): The other comparison configuration to check compatibility with. Returns: bool: True if the configurations are compatible, False otherwise.

Source code in zipstrain/src/zipstrain/database.py
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def is_compatible(self, other: GenomeComparisonConfig) -> bool:
    """
    Check if this comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope.
    Scope can be different as long as they are not disjoint. Also, all is compatible with any scope.
    Args:
        other (GenomeComparisonConfig): The other comparison configuration to check compatibility with.
    Returns:
        bool: True if the configurations are compatible, False otherwise.
    """
    attrs=self.__dict__
    for key in attrs:
        if key!="scope":
            if attrs[key] != getattr(other, key):
                return False
    if other.scope != "all" and self.scope != "all":
        if (set(other.scope.split(",")).intersection(set(self.scope.split(","))) == set()):
            return False
    return True
to_dict()

Returns the dictionary representation of the current object

Source code in zipstrain/src/zipstrain/database.py
284
285
286
def to_dict(self)->dict:
    """Returns the dictionary representation of the current object"""
    return copy.copy(self.__dict__)
to_json(json_file_dir)

Writes the current object to a json file

Source code in zipstrain/src/zipstrain/database.py
279
280
281
282
def to_json(self,json_file_dir:str)->None:
    """Writes the the current object to a json file"""
    with open(json_file_dir,"w") as f:
        json.dump(self.__dict__,f)

GenomeComparisonDatabase

GenomeComparisonDatabase object holds a reference to a comparison parquet file. The methods in this class serve to provide functionality for working with the comparison data in an easy and efficient manner. The comparison parquet file is the result of running compare, and optionally concatenating multiple compare parquet files from single comparisons. This parquet file must contain the following columns:

  • genome

  • total_positions

  • share_allele_pos

  • genome_pop_ani

  • max_consecutive_length

  • shared_genes_count

  • identical_gene_count

  • perc_id_genes

  • sample_1

  • sample_2

A GenomeComparisonDatabase object needs a GenomeComparisonConfig object to specify the parameters used for the comparison.

Parameters:

Name Type Description Default
profile_db ProfileDatabase

The profile database used for the comparison.

required
config GenomeComparisonConfig

The comparison configuration used for the comparison.

required
comp_db_loc str | None

The location of the comparison database parquet file. If None, an empty comparison database is created.

None
Source code in zipstrain/src/zipstrain/database.py
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
class GenomeComparisonDatabase:
    """
    GenomeComparisonDatabase object holds a reference to a comparison parquet file. The methods in this class serve to provide
    functionality for working with the comparison data in an easy and efficient manner.
    The comparison parquet file is the result of running compare, and optionally concatenating multiple compare parquet files from single comparisons.
    This parquet file must contain the following columns:

    - genome

    - total_positions

    - share_allele_pos

    - genome_pop_ani

    - max_consecutive_length

    - shared_genes_count

    - identical_gene_count

    - perc_id_genes

    - sample_1

    - sample_2

    A GenomeComparisonDatabase object needs a GenomeComparisonConfig object to specify the parameters used for the comparison.

    Args:
        profile_db (ProfileDatabase): The profile database used for the comparison.
        config (GenomeComparisonConfig): The comparison configuration used for the comparison.
        comp_db_loc (str|None): The location of the comparison database parquet file. If
            None, an empty comparison database is created.

    """
    COLUMN_NAMES = [
        "genome",
        "total_positions",
        "share_allele_pos",
        "genome_pop_ani",
        "max_consecutive_length",
        "shared_genes_count",
        "identical_gene_count",
        "perc_id_genes",
        "sample_1",
        "sample_2"
    ]
    COLUMN_DTYPES = {
        "genome": pl.Utf8,
        "total_positions": pl.Int64,
        "share_allele_pos": pl.Int64,
        "genome_pop_ani": pl.Float64,
        "max_consecutive_length": pl.Int64,
        "shared_genes_count": pl.Int64,
        "identical_gene_count": pl.Int64,
        "perc_id_genes": pl.Float64,
        "sample_1": pl.Utf8,
        "sample_2": pl.Utf8,
    }

    def __init__(self,
                 profile_db: ProfileDatabase,
                 config: GenomeComparisonConfig,
                 comp_db_loc: str|None = None,
                 ):
        self.profile_db = profile_db
        self.config = config
        if comp_db_loc is not None:
            self.comp_db_loc = pathlib.Path(comp_db_loc)
            self._comp_db = self._normalize_comp_db_columns(pl.scan_parquet(self.comp_db_loc))
        else:
            self.comp_db_loc = None
            self._comp_db=pl.LazyFrame({
                "genome": [],
                "total_positions": [],
                "share_allele_pos": [],
                "genome_pop_ani": [],
                "max_consecutive_length": [],
                "shared_genes_count": [],
                "identical_gene_count": [],
                "perc_id_genes": [],
                "sample_1": [],
                "sample_2": []
            }, schema=self.COLUMN_DTYPES)
            self.comp_db_loc=None

    @property
    def comp_db(self):
        return self._comp_db

    def _normalize_comp_db_columns(self, comp_db: pl.LazyFrame) -> pl.LazyFrame:
        schema = comp_db.collect_schema()
        casts = [
            pl.col(col_name).cast(dtype, strict=False)
            for col_name, dtype in self.COLUMN_DTYPES.items()
            if col_name in schema
        ]
        if casts:
            comp_db = comp_db.with_columns(casts)

        missing = [col_name for col_name in self.COLUMN_NAMES if col_name not in schema]
        if missing:
            comp_db = comp_db.with_columns(
                [pl.lit(None).cast(self.COLUMN_DTYPES[col_name]).alias(col_name) for col_name in missing]
            )
        return comp_db.select(self.COLUMN_NAMES)

    def _validate_db(self)->None:
        self.profile_db._validate_db()

        if set(self._comp_db.collect_schema()) != set(self.COLUMN_NAMES):
            raise ValueError(f"Your comparison database must provide these extra columns: { set(self.COLUMN_NAMES)-set(self._comp_db.collect_schema())}")
        #check if all profile names exist in the profile database
        profile_names_in_comp_db = set(self.get_all_profile_names())
        profile_names_in_profile_db = set(self.profile_db.db.select("profile_name").collect(engine="streaming").to_series().to_list())
        if not profile_names_in_comp_db.issubset(profile_names_in_profile_db):
            raise ValueError(f"The following profile names are in the comparison database but not in the profile database: {profile_names_in_comp_db - profile_names_in_profile_db}")

    def get_all_profile_names(self) -> set[str]:
        """
        Get all profile names that are in the comparison database.
        """
        return set(self.comp_db.select(pl.col("sample_1")).collect(engine="streaming").to_series().to_list()).union(
            set(self.comp_db.select(pl.col("sample_2")).collect(engine="streaming").to_series().to_list())
        )
    def get_remaining_pairs(self) -> pl.LazyFrame:
        """
        Get pairs of profiles that are in the profile database but not in the comparison database.
        """
        profiles = self.profile_db.db.select("profile_name")
        pairs=profiles.join(profiles,how="cross").rename({"profile_name":"profile_1","profile_name_right":"profile_2"}).filter(pl.col("profile_1")<pl.col("profile_2"))
        samplepairs = self.comp_db.group_by("sample_1", "sample_2").agg().with_columns(pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"), pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2")).select(["profile_1", "profile_2"])

        remaining_pairs = pairs.join(samplepairs, on=["profile_1", "profile_2"], how="anti").sort(["profile_1","profile_2"])
        return remaining_pairs

    def is_complete(self) -> bool:
        """
        Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.
        """
        return self.get_remaining_pairs().collect(engine="streaming").is_empty()

    def add_comp_database(self, comp_database: GenomeComparisonDatabase) -> None:
        """Merge the provided comparison database into the current database.

        Args:
            comp_database (GenomeComparisonDatabase): The comparison database to merge.
        """
        try:
            comp_database._validate_db()

        except Exception as e:
            raise ValueError(f"The comparison database provided is not valid: {e}")

        if not self.config.is_compatible(comp_database.config):
            raise ValueError("The comparison database provided is not compatible with the current comparison database.")

        self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
        self.config = self.config.get_maximal_scope_config(comp_database.config)


    def save_new_compare_database(self, output_path: str) -> None:
        """Save the database to a parquet file."""
        output_path = pathlib.Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # The new database must be written to a new location
        if self.comp_db_loc is not None and str(self.comp_db_loc.absolute()) == str(output_path.absolute()):
            raise ValueError("The output path must be different from the current database location.")

        self.comp_db.sink_parquet(output_path)


    def update_compare_database(self)->None:
        """Overwrites the comparison database saved on the disk to the current comparison database object
        """
        if self.comp_db_loc is None:
            raise Exception("comp_db_loc attribute is not determined yet!")
        try:
            tmp_path=pathlib.Path(tempfile.mktemp(suffix=".parquet",prefix="tmp_comp_db_",dir=str(self.comp_db_loc.parent)))
            self.comp_db.sink_parquet(tmp_path)
            os.replace(tmp_path,self.comp_db_loc)
            self._comp_db=pl.scan_parquet(self.comp_db_loc)
        except Exception as e:
            raise Exception(f"Something went wrong when updating the comparison database:{e}")

    def dump_obj(self, output_path: str) -> None:
        """Dump the current object to a json file.

        Args:
            output_path (str): The path to save the json file to.
        """
        obj_dict = {
            "profile_db_loc": str(self.profile_db.db_loc.absolute()) if self.profile_db.db_loc is not None else None,
            "config": self.config.to_dict(),
            "comp_db_loc": str(self.comp_db_loc.absolute()) if self.comp_db_loc is not None else None
        }
        with open(output_path, "w") as f:
            json.dump(obj_dict, f, indent=4)

    @classmethod
    def load_obj(cls, json_path: str) -> GenomeComparisonDatabase:
        """Load a GenomeComparisonDatabase object from a json file.

        Args:
            json_path (str): The path to the json file.

        Returns:
            GenomeComparisonDatabase: The loaded GenomeComparisonDatabase object.
        """
        with open(json_path, "r") as f:
            obj_dict = json.load(f)

        return cls(profile_db=ProfileDatabase(db_loc=obj_dict["profile_db_loc"]) , 
                   config=GenomeComparisonConfig.from_dict(obj_dict["config"]), 
                   comp_db_loc=obj_dict["comp_db_loc"])


    def to_complete_input_table(self)->pl.LazyFrame:
        """This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

        - sample_name_1

        - sample_name_2

        - profile_location_1

        - profile_location_2

        Returns:
            pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
        """
        lf=self.get_remaining_pairs().rename({"profile_1":"sample_name_1","profile_2":"sample_name_2"})
        return (lf.join(self.profile_db.db.select(["profile_name","profile_location"]),left_on="sample_name_1",right_on="profile_name",how="left")
                .rename({"profile_location":"profile_location_1"})
                .join(self.profile_db.db.select(["profile_name","profile_location"]),left_on="sample_name_2",right_on="profile_name",how="left")
                .rename({"profile_location":"profile_location_2"})
               ).sort(["sample_name_1","sample_name_2"])
add_comp_database(comp_database)

Merge the provided comparison database into the current database.

Parameters:

Name Type Description Default
comp_database GenomeComparisonDatabase

The comparison database to merge.

required
Source code in zipstrain/src/zipstrain/database.py
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
def add_comp_database(self, comp_database: GenomeComparisonDatabase) -> None:
    """Merge the provided comparison database into the current database.

    Args:
        comp_database (GenomeComparisonDatabase): The comparison database to merge.
    """
    try:
        comp_database._validate_db()

    except Exception as e:
        raise ValueError(f"The comparison database provided is not valid: {e}")

    if not self.config.is_compatible(comp_database.config):
        raise ValueError("The comparison database provided is not compatible with the current comparison database.")

    self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
    self.config = self.config.get_maximal_scope_config(comp_database.config)
dump_obj(output_path)

Dump the current object to a json file.

Parameters:

Name Type Description Default
output_path str

The path to save the json file to.

required
Source code in zipstrain/src/zipstrain/database.py
622
623
624
625
626
627
628
629
630
631
632
633
634
def dump_obj(self, output_path: str) -> None:
    """Dump the current object to a json file.

    Args:
        output_path (str): The path to save the json file to.
    """
    obj_dict = {
        "profile_db_loc": str(self.profile_db.db_loc.absolute()) if self.profile_db.db_loc is not None else None,
        "config": self.config.to_dict(),
        "comp_db_loc": str(self.comp_db_loc.absolute()) if self.comp_db_loc is not None else None
    }
    with open(output_path, "w") as f:
        json.dump(obj_dict, f, indent=4)
get_all_profile_names()

Get all profile names that are in the comparison database.

Source code in zipstrain/src/zipstrain/database.py
554
555
556
557
558
559
560
def get_all_profile_names(self) -> set[str]:
    """
    Get all profile names that are in the comparison database.
    """
    return set(self.comp_db.select(pl.col("sample_1")).collect(engine="streaming").to_series().to_list()).union(
        set(self.comp_db.select(pl.col("sample_2")).collect(engine="streaming").to_series().to_list())
    )
get_remaining_pairs()

Get pairs of profiles that are in the profile database but not in the comparison database.

Source code in zipstrain/src/zipstrain/database.py
561
562
563
564
565
566
567
568
569
570
def get_remaining_pairs(self) -> pl.LazyFrame:
    """
    Get pairs of profiles that are in the profile database but not in the comparison database.
    """
    profiles = self.profile_db.db.select("profile_name")
    pairs=profiles.join(profiles,how="cross").rename({"profile_name":"profile_1","profile_name_right":"profile_2"}).filter(pl.col("profile_1")<pl.col("profile_2"))
    samplepairs = self.comp_db.group_by("sample_1", "sample_2").agg().with_columns(pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"), pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2")).select(["profile_1", "profile_2"])

    remaining_pairs = pairs.join(samplepairs, on=["profile_1", "profile_2"], how="anti").sort(["profile_1","profile_2"])
    return remaining_pairs
is_complete()

Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.

Source code in zipstrain/src/zipstrain/database.py
572
573
574
575
576
def is_complete(self) -> bool:
    """
    Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.
    """
    return self.get_remaining_pairs().collect(engine="streaming").is_empty()
load_obj(json_path) classmethod

Load a GenomeComparisonDatabase object from a json file.

Parameters:

Name Type Description Default
json_path str

The path to the json file.

required

Returns:

Name Type Description
GenomeComparisonDatabase GenomeComparisonDatabase

The loaded GenomeComparisonDatabase object.

Source code in zipstrain/src/zipstrain/database.py
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
@classmethod
def load_obj(cls, json_path: str) -> GenomeComparisonDatabase:
    """Load a GenomeComparisonDatabase object from a json file.

    Args:
        json_path (str): The path to the json file.

    Returns:
        GenomeComparisonDatabase: The loaded GenomeComparisonDatabase object.
    """
    with open(json_path, "r") as f:
        obj_dict = json.load(f)

    return cls(profile_db=ProfileDatabase(db_loc=obj_dict["profile_db_loc"]) , 
               config=GenomeComparisonConfig.from_dict(obj_dict["config"]), 
               comp_db_loc=obj_dict["comp_db_loc"])
save_new_compare_database(output_path)

Save the database to a parquet file.

Source code in zipstrain/src/zipstrain/database.py
597
598
599
600
601
602
603
604
605
606
def save_new_compare_database(self, output_path: str) -> None:
    """Save the database to a parquet file.

    Args:
        output_path (str): Destination parquet file; must differ from the current database location.

    Raises:
        ValueError: If output_path points at the current database file.
    """
    target = pathlib.Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Refuse to overwrite the database this object is currently backed by.
    same_location = (
        self.comp_db_loc is not None
        and str(self.comp_db_loc.absolute()) == str(target.absolute())
    )
    if same_location:
        raise ValueError("The output path must be different from the current database location.")

    self.comp_db.sink_parquet(target)
to_complete_input_table()

This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

  • sample_name_1

  • sample_name_2

  • profile_location_1

  • profile_location_2

Returns:

Type Description
LazyFrame

pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.

Source code in zipstrain/src/zipstrain/database.py
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
def to_complete_input_table(self)->pl.LazyFrame:
    """This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

    - sample_name_1

    - sample_name_2

    - profile_location_1

    - profile_location_2

    Returns:
        pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
    """
    # Name/location lookup shared by both joins below.
    locations = self.profile_db.db.select(["profile_name", "profile_location"])

    pairs = self.get_remaining_pairs().rename(
        {"profile_1": "sample_name_1", "profile_2": "sample_name_2"}
    )

    # Attach the on-disk location of each side of the pair.
    pairs = pairs.join(
        locations, left_on="sample_name_1", right_on="profile_name", how="left"
    ).rename({"profile_location": "profile_location_1"})
    pairs = pairs.join(
        locations, left_on="sample_name_2", right_on="profile_name", how="left"
    ).rename({"profile_location": "profile_location_2"})

    return pairs.sort(["sample_name_1", "sample_name_2"])
update_compare_database()

Overwrites the comparison database saved on the disk to the current comparison database object

Source code in zipstrain/src/zipstrain/database.py
609
610
611
612
613
614
615
616
617
618
619
620
def update_compare_database(self)->None:
    """Overwrites the comparison database saved on the disk to the current comparison database object.

    Writes to a sibling temporary file first, then atomically replaces the
    target with ``os.replace``, so a failed write never corrupts the existing
    database file.

    Raises:
        Exception: If comp_db_loc is not set, or if writing/replacing fails.
    """
    if self.comp_db_loc is None:
        raise Exception("comp_db_loc attribute is not determined yet!")
    # tempfile.mkstemp (unlike the deprecated, race-prone tempfile.mktemp)
    # actually creates the file; close the fd so polars can write to the path.
    fd, tmp_name = tempfile.mkstemp(suffix=".parquet", prefix="tmp_comp_db_", dir=str(self.comp_db_loc.parent))
    os.close(fd)
    tmp_path = pathlib.Path(tmp_name)
    try:
        self.comp_db.sink_parquet(tmp_path)
        os.replace(tmp_path, self.comp_db_loc)
        # Re-scan so the lazy frame is backed by the freshly written file.
        self._comp_db = pl.scan_parquet(self.comp_db_loc)
    except Exception as e:
        # Clean up the orphaned temp file before surfacing the failure.
        tmp_path.unlink(missing_ok=True)
        raise Exception(f"Something went wrong when updating the comparison database:{e}") from e

ProfileDatabase

The profile database simply holds profile information. Does not need to be specific to a comparison database. The data behind a profile is stored in a parquet file. It is basically a table with the following columns:

  • profile_name: An arbitrary name given to the profile (Usually sample name or name of the parquet file)

  • profile_location: The location of the profile

  • reference_db_id: The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.

  • gene_db_id: The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.

Parameters:

Name Type Description Default
db_loc str | None

The location of the profile database parquet file. If None, an empty database is created.

None
Source code in zipstrain/src/zipstrain/database.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class ProfileDatabase:
    """
    The profile database simply holds profile information. Does not need to be specific to a comparison database.
    The data behind a profile is stored in a parquet file. It is basically a table with the following columns:

    - profile_name: An arbitrary name given to the profile (Usually sample name or name of the parquet file)

    - profile_location: The location of the profile

    - reference_db_id: The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.

    - gene_db_id: The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.

    Args:
        db_loc (str|None): The location of the profile database parquet file. If None, an empty database is created.

    """
    def __init__(self,
                 db_loc: str | None = None,
                 ):
        if db_loc is not None:
            self.db_loc = pathlib.Path(db_loc)
            # Lazily scan the parquet file, restricted to the canonical column set.
            self._db = pl.scan_parquet(self.db_loc).select(_PROFILE_DB_COLUMNS)
        else:
            # Start from an empty, correctly-typed table.
            self._db = pl.LazyFrame({
                "profile_name": [],
                "profile_location": [],
                "reference_db_id": [],
                "gene_db_id": []
            }, schema={
                "profile_name": pl.Utf8,
                "profile_location": pl.Utf8,
                "reference_db_id": pl.Utf8,
                "gene_db_id": pl.Utf8
            })
            self.db_loc = None

    @property
    def db(self) -> pl.LazyFrame:
        """The underlying lazy frame holding the profile table."""
        return self._db

    def _validate_db(self, check_profile_exists: bool = True) -> None:
        """Simple method to see if the database has the minimum required structure.

        Args:
            check_profile_exists (bool): Also verify that every profile_location exists on disk.

        Raises:
            ValueError: If a required column is missing, or a profile file does not exist.
        """
        ### Check if the database has the required columns
        # Collect the schema once instead of once per column.
        schema_names = self.db.collect_schema().names()
        for col in _PROFILE_DB_COLUMNS:
            if col not in schema_names:
                raise ValueError(f"Missing required column: {col}")

        if check_profile_exists:
            # Check if the profile exists in the database
            db_path_validated = self.db.select(pl.col("profile_location")).collect(engine="streaming").with_columns(
                (pl.col("profile_location").map_elements(lambda x: pathlib.Path(x).exists(), return_dtype=pl.Boolean)).alias("profile_exists")
            ).filter(~ pl.col("profile_exists"))
            if db_path_validated.height != 0:
                raise ValueError(f"There are {db_path_validated.height} profiles that do not exist: {db_path_validated['profile_location'].to_list()}")
            ### add log later

    def add_profile(self,
                    data: dict
                    ) -> None:
        """Add a profile to the database.
        The data dictionary must contain the following and only the following keys:

        - profile_name

        - profile_location

        - reference_db_id

        - gene_db_id

        Args:
            data (dict): The profile data to add.

        Raises:
            ValueError: If the profile data fails validation.
        """
        try:
            profile_item = ProfileItem(**data)  # validates keys, path existence and non-empty IDs
            lf = pl.LazyFrame({
                "profile_name": [profile_item.profile_name],
                "profile_location": [profile_item.profile_location],
                "reference_db_id": [profile_item.reference_db_id],
                "gene_db_id": [profile_item.gene_db_id]
            })
            # De-duplicate so adding the same profile twice is a no-op.
            self._db = pl.concat([self.db, lf]).unique()
            self._validate_db()
        except Exception as e:
            raise ValueError(f"The profile data provided is not valid: {e}") from e


    def add_database(self, profile_database: ProfileDatabase) -> None:
        """Merge the provided profile database into the current database.

        Args:
            profile_database (ProfileDatabase): The profile database to merge.

        Raises:
            ValueError: If the provided database fails validation.
        """
        try:
            profile_database._validate_db()

        except Exception as e:
            raise ValueError(f"The profile database provided is not valid: {e}") from e

        self._db = pl.concat([self._db, profile_database.db]).unique()


    def save_as_new_database(self, output_path: str) -> None:
        """Save the database to a parquet file.

        Args:
            output_path (str): The path to save the database to.

        Raises:
            ValueError: If output_path is the current database location, or the write fails.
        """
        # The new database must be written to a new location
        if self.db_loc is not None and str(self.db_loc.absolute()) == str(pathlib.Path(output_path).absolute()):
            raise ValueError("The output path must be different from the current database location.")

        try:
            self.db.sink_parquet(output_path)
            self.db_loc = pathlib.Path(output_path)
        except Exception as e:
            # Surface write failures instead of silently swallowing them;
            # db_loc is only updated on success.
            raise ValueError(f"Failed to save the database to {output_path}: {e}") from e

    def update_database(self)->None:
        """Overwrites the database saved on the disk to the current database object.

        Raises:
            Exception: If db_loc is not set, or the write fails.
        """
        if self.db_loc is None:
            raise Exception("db_loc attribute is not determined yet!")
        try:
            # Collect with the streaming engine to bound memory, then overwrite in place.
            self.db.collect(engine="streaming").write_parquet(self.db_loc)
        except Exception as e:
            raise Exception(f"Something went wrong when updating the database:{e}") from e


    @classmethod
    def from_csv(cls, csv_path: str) -> ProfileDatabase:
        """Create a ProfileDatabase instance from a CSV file with exactly same columns as the required columns for a profile database.

        Args:
            csv_path (str): The path to the CSV file.

        Returns:
            ProfileDatabase: The created ProfileDatabase instance.
        """
        # Collect eagerly to avoid a clash when using to_csv on the same file.
        lf = pl.scan_csv(csv_path).select(_PROFILE_DB_COLUMNS).collect().lazy()
        prof_db = cls()
        prof_db._db = lf
        prof_db._validate_db()
        return prof_db

    def to_csv(self, output_dir: str) -> None:
        """Writes the current database object to a CSV file.

        Args:
            output_dir (str): The path to save the CSV file.

        Returns:
            None
        """
        self.db.sink_csv(output_dir, engine="streaming")
add_database(profile_database)

Merge the provided profile database into the current database.

Parameters:

Name Type Description Default
profile_database ProfileDatabase

The profile database to merge.

required
Source code in zipstrain/src/zipstrain/database.py
144
145
146
147
148
149
150
151
152
153
154
155
156
def add_database(self, profile_database: ProfileDatabase) -> None:
    """Merge the provided profile database into the current database.

    Args:
        profile_database (ProfileDatabase): The profile database to merge.
    """
    # Only merge databases that pass structural validation.
    try:
        profile_database._validate_db()
    except Exception as e:
        raise ValueError(f"The profile database provided is not valid: {e}")

    # Union of both tables, with duplicate rows removed.
    merged = pl.concat([self._db, profile_database.db])
    self._db = merged.unique()
add_profile(data)

Add a profile to the database. The data dictionary must contain the following and only the following keys:

  • profile_name

  • profile_location

  • reference_db_id

  • gene_db_id

Parameters:

Name Type Description Default
data dict

The profile data to add.

required
Source code in zipstrain/src/zipstrain/database.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
def add_profile(self,
                data: dict
                ) -> None:
    """Add a profile to the database.
    The data dictionary must contain the following and only the following keys:

    - profile_name

    - profile_location

    - reference_db_id

    - gene_db_id

    Args:
        data (dict): The profile data to add.
    """
    try:
        # Pydantic model enforces the exact key set and value constraints.
        item = ProfileItem(**data)
        new_row = pl.LazyFrame({
            "profile_name": [item.profile_name],
            "profile_location": [item.profile_location],
            "reference_db_id": [item.reference_db_id],
            "gene_db_id": [item.gene_db_id]
        })
        # Append and de-duplicate, then re-check the database structure.
        self._db = pl.concat([self.db, new_row]).unique()
        self._validate_db()
    except Exception as e:
        raise ValueError(f"The profile data provided is not valid: {e}")
from_csv(csv_path) classmethod

Create a ProfileDatabase instance from a CSV file with exactly same columns as the required columns for a profile database.

Parameters:

Name Type Description Default
csv_path str

The path to the CSV file.

required

Returns:

Name Type Description
ProfileDatabase ProfileDatabase

The created ProfileDatabase instance.

Source code in zipstrain/src/zipstrain/database.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
@classmethod
def from_csv(cls, csv_path: str) -> ProfileDatabase:
    """Create a ProfileDatabase instance from a CSV file with exactly same columns as the required columns for a profile database.

    Args:
        csv_path (str): The path to the CSV file.

    Returns:
        ProfileDatabase: The created ProfileDatabase instance.
    """
    # Materialize eagerly so a later to_csv on the same path cannot clash,
    # then switch back to lazy for the rest of the pipeline.
    frame = pl.scan_csv(csv_path).select(_PROFILE_DB_COLUMNS).collect()
    instance = cls()
    instance._db = frame.lazy()
    instance._validate_db()
    return instance
save_as_new_database(output_path)

Save the database to a parquet file.

Parameters:

Name Type Description Default
output_path str

The path to save the database to.

required
Source code in zipstrain/src/zipstrain/database.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def save_as_new_database(self, output_path: str) -> None:
    """Save the database to a parquet file.

    Args:
        output_path (str): The path to save the database to.

    Raises:
        ValueError: If output_path is the current database location, or the write fails.
    """
    #The new database must be written to a new location
    if self.db_loc is not None and str(self.db_loc.absolute()) == str(pathlib.Path(output_path).absolute()):
        raise ValueError("The output path must be different from the current database location.")

    try:
        self.db.sink_parquet(output_path)
        self.db_loc = pathlib.Path(output_path)
    except Exception as e:
        # Surface write failures instead of silently swallowing them;
        # db_loc is only updated on success.
        raise ValueError(f"Failed to save the database to {output_path}: {e}") from e
to_csv(output_dir)

Writes the current database object to a CSV file.

Parameters:

Name Type Description Default
output_dir str

The path to save the CSV file.

required

Returns:

Type Description
None

None

Source code in zipstrain/src/zipstrain/database.py
204
205
206
207
208
209
210
211
212
213
def to_csv(self,output_dir:str)->None:
    """Writes the current database object to a CSV file.

    Args:
        output_dir (str): The path to save the CSV file.

    Returns:
        None
    """
    # Stream the lazy frame straight to disk without materializing it in memory.
    self.db.sink_csv(output_dir,engine="streaming")
update_database()

Overwrites the database saved on the disk to the current database object

Source code in zipstrain/src/zipstrain/database.py
177
178
179
180
181
182
183
184
185
def update_database(self)->None:
    """Overwrites the database saved on the disk to the current database object.

    Raises:
        Exception: If db_loc is not set, or the write fails.
    """
    if self.db_loc is None:
        raise Exception("db_loc attribute is not determined yet!")
    try:
        # Collect with the streaming engine to bound memory, then overwrite in place.
        self.db.collect(engine="streaming").write_parquet(self.db_loc)
    except Exception as e:
        # Chain the original exception so the underlying cause is preserved.
        raise Exception(f"Something went wrong when updating the database:{e}") from e

ProfileItem

Bases: BaseModel

This class describes all necessary attributes of a profile and makes sure they comply with the necessary formatting.

Source code in zipstrain/src/zipstrain/database.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class ProfileItem(BaseModel):
    """
    This class describes all necessary attributes of a profile and makes sure they comply with the necessary formatting.
    """
    # Reject any keys beyond the four declared fields.
    model_config = ConfigDict(extra="forbid")

    profile_name: str = Field(description="An arbitrary name given to the profile (Usually sample name or name of the parquet file)")
    profile_location: str = Field(description="The location of the profile")
    reference_db_id: str = Field(description="The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.")
    gene_db_id: str = Field(default="", description="The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.")

    @field_validator("profile_location")
    def check_file_exists(cls, v):
        # The profile file must already exist on disk.
        if os.path.exists(v):
            return v
        raise ValueError(f"The file {v} does not exist.")

    @field_validator("reference_db_id", "gene_db_id")
    def check_reference_db_id(cls, v):
        # Both identifiers must be non-empty strings.
        if v:
            return v
        raise ValueError("The reference_db_id and gene_db_id cannot be empty.")

Profile

zipstrain.profile

This module provides functions and utilities to profile a bamfile. By profile we mean generating gene, genome, and nucleotide counts at each position on the reference. This is a fundamental step for downstream analysis in zipstrain.

adjust_for_sequence_errors(mpile_frame, null_model)

Adjust the mpile frame for sequence errors based on the null model.

Parameters:

Name Type Description Default
mpile_frame LazyFrame

The input LazyFrame containing coverage data.

required
null_model LazyFrame

The null model LazyFrame containing error counts.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Adjusted LazyFrame with sequence errors accounted for.

Source code in zipstrain/src/zipstrain/profile.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def adjust_for_sequence_errors(mpile_frame:pl.LazyFrame, null_model:pl.LazyFrame) -> pl.LazyFrame:
    """
    Adjust the mpile frame for sequence errors based on the null model.

    Args:
        mpile_frame (pl.LazyFrame): The input LazyFrame containing coverage data.
        null_model (pl.LazyFrame): The null model LazyFrame containing error counts.

    Returns:
        pl.LazyFrame: Adjusted LazyFrame with sequence errors accounted for.
    """
    bases = ["A", "T", "C", "G"]

    # Total coverage at each position is the sum over the four base counts.
    with_cov = mpile_frame.with_columns(pl.sum_horizontal(bases).alias("cov"))

    # Counts below the per-coverage error ceiling are treated as noise and zeroed.
    zero_noise = [
        pl.when(pl.col(base) >= pl.col("max_error_count"))
        .then(pl.col(base))
        .otherwise(0)
        .alias(base)
        for base in bases
    ]
    return with_cov.join(null_model, on="cov", how="left").with_columns(zero_noise).drop("max_error_count")

build_gene_loc_table(fasta_file, scaffold)

Build a gene location table from a FASTA file.

Parameters: fasta_file (pathlib.Path): Path to the FASTA file.

Returns: pl.DataFrame: A Polars DataFrame containing gene locations.

Source code in zipstrain/src/zipstrain/profile.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def build_gene_loc_table(fasta_file:pathlib.Path,scaffold:set)->pl.DataFrame:
    """
    Build a per-position gene location table from a FASTA file.

    Parameters:
    fasta_file (pathlib.Path): Path to the FASTA file.
    scaffold (set): Set of scaffold names to include; genes on other scaffolds are skipped.

    Returns:
    pl.DataFrame: A Polars DataFrame with one row per gene position (columns: scaffold, gene, pos).
    """
    scaffolds = []
    gene_ids = []
    pos = []
    for gene_id, scaf, start, end in parse_gene_loc_table(fasta_file):
        if scaf not in scaffold:
            continue
        # Convert once per gene instead of once per use.
        start_i = int(start)
        end_i = int(end)
        span = end_i - start_i + 1  # inclusive range length
        scaffolds.extend([scaf] * span)
        gene_ids.extend([gene_id] * span)
        pos.extend(range(start_i, end_i + 1))
    return pl.DataFrame({
        "scaffold": scaffolds,
        "gene": gene_ids,
        "pos": pos
    })

build_gene_range_table(fasta_file)

Build a gene location table in the form of <gene scaffold start end> from a FASTA file. Parameters: fasta_file (pathlib.Path): Path to the FASTA file.

Returns: pl.DataFrame: A Polars DataFrame containing gene locations.

Source code in zipstrain/src/zipstrain/profile.py
85
86
87
88
89
90
91
92
93
94
95
96
97
def build_gene_range_table(fasta_file:pathlib.Path)->pl.DataFrame:
    """
    Build a gene location table in the form of <gene scaffold start end> from a FASTA file.

    Parameters:
    fasta_file (pathlib.Path): Path to the FASTA file.

    Returns:
    pl.DataFrame: A Polars DataFrame containing gene locations.
    """
    # One row per gene header parsed from the FASTA.
    rows = list(parse_gene_loc_table(fasta_file))
    return pl.DataFrame(rows, schema=["gene", "scaffold", "start", "end"], orient='row')

get_strain_hetrogeneity(profile, stb, min_cov=5, freq_threshold=0.8)

Calculate strain heterogeneity for each genome based on nucleotide frequencies. The definition of strain heterogeneity here is the fraction of sites that have enough coverage (min_cov) and have a dominant nucleotide with frequency less than freq_threshold.

Parameters:

Name Type Description Default
profile LazyFrame

The profile LazyFrame containing nucleotide counts.

required
stb LazyFrame

The scaffold-to-bin mapping LazyFrame. First column is 'scaffold', second column is 'bin'.

required
min_cov int

The minimum coverage threshold.

5
freq_threshold float

The frequency threshold for dominant nucleotides.

0.8

Returns: pl.LazyFrame: A LazyFrame containing strain heterogeneity information grouped by genome.

Source code in zipstrain/src/zipstrain/profile.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def get_strain_hetrogeneity(profile:pl.LazyFrame,
                            stb:pl.LazyFrame, 
                            min_cov=5,
                            freq_threshold=0.8)->pl.LazyFrame:
    """
    Calculate strain heterogeneity for each genome based on nucleotide frequencies.
    A site counts as heterogeneous when it has at least min_cov coverage and its
    dominant nucleotide frequency is below freq_threshold; the reported value is
    the fraction of such sites per genome.

    Args:
        profile (pl.LazyFrame): The profile LazyFrame containing nucleotide counts.
        stb (pl.LazyFrame): The scaffold-to-bin mapping LazyFrame. First column is 'scaffold', second column is 'bin'.
        min_cov (int): The minimum coverage threshold.
        freq_threshold (float): The frequency threshold for dominant nucleotides.

    Returns:
    pl.LazyFrame: A LazyFrame containing strain heterogeneity information grouped by genome.
    """
    total_col = f"total_sites_at_{min_cov}_coverage"

    # Keep only positions with sufficient read depth.
    covered = profile.with_columns(
        (pl.col("A")+pl.col("T")+pl.col("C")+pl.col("G")).alias("coverage")
    ).filter(pl.col("coverage") >= min_cov)

    # Flag sites whose dominant base frequency falls below the threshold.
    flagged = covered.with_columns(
        (pl.max_horizontal(["A", "T", "C", "G"])/pl.col("coverage") < freq_threshold)
        .cast(pl.Int8)
        .alias("heterogeneous_site")
    )

    # Map scaffolds to genomes, then aggregate counts per genome.
    per_genome = flagged.join(stb, left_on="chrom", right_on="scaffold", how="left").group_by("genome").agg([
        pl.len().alias(total_col),
        pl.sum("heterogeneous_site").alias("heterogeneous_sites")
    ])

    return per_genome.with_columns(
        (pl.col("heterogeneous_sites")/pl.col(total_col)).alias("strain_heterogeneity")
    )

parse_gene_loc_table(fasta_file)

Extract gene locations from a FASTA file, assuming prodigal-style headers, and yield gene info.

Parameters: fasta_file (pathlib.Path): Path to the FASTA file.

Tuple: A tuple containing: - gene_ID - scaffold - start - end

Source code in zipstrain/src/zipstrain/profile.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def parse_gene_loc_table(fasta_file:pathlib.Path) -> Generator[tuple,None,None]:
    """
    Extract gene locations from a FASTA file, assuming prodigal-style headers, and yield gene info.

    Parameters:
    fasta_file (pathlib.Path): Path to the FASTA file.

    Returns:
    Tuple: A tuple containing:
        - gene_ID
        - scaffold
        - start
        - end
    """
    with open(fasta_file, 'r') as handle:
        for raw_line in handle:
            if not raw_line.startswith('>'):
                continue
            fields = raw_line[1:].strip().split()
            gene_id = fields[0]
            # Prodigal gene IDs are "<scaffold>_<gene number>"; drop the trailing index.
            scaffold = "_".join(gene_id.split('_')[:-1])
            # Header layout: <id> # <start> # <end> # <strand> # <attributes>
            yield gene_id, scaffold, fields[2], fields[4]

profile_bam(bed_file, bam_file, gene_range_table, stb, null_model, output_dir, num_workers=4)

Profile a BAM file in chunks using provided BED files.

Parameters: bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled. bam_file (pathlib.Path): Path to the BAM file. gene_range_table (pathlib.Path): Path to the gene range table. stb (pl.LazyFrame): Scaffold-to-genome mapping table. null_model (pl.LazyFrame): The null model to be used for adjusting for sequence errors. output_dir (pathlib.Path): Directory to save output files. num_workers (int): Number of concurrent workers to use.

Source code in zipstrain/src/zipstrain/profile.py
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
def profile_bam(
    bed_file:str,
    bam_file:str,
    gene_range_table:str,
    stb:pl.LazyFrame,
    null_model:pl.LazyFrame,
    output_dir:str,
    num_workers:int=4,
)->None:
    """
    Synchronous entry point: profile a BAM file in chunks using provided BED files.

    Parameters:
    bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled.
    bam_file (pathlib.Path): Path to the BAM file.
    gene_range_table (pathlib.Path): Path to the gene range table.
    stb (pl.LazyFrame): Scaffold-to-genome mapping table.
    null_model (pl.LazyFrame): The null model to be used for adjusting for sequence errors.
    output_dir (pathlib.Path): Directory to save output files.
    num_workers (int): Number of concurrent workers to use.
    """
    coroutine = profile_bam_in_chunks(
        bed_file=bed_file,
        bam_file=bam_file,
        gene_range_table=gene_range_table,
        stb=stb,
        null_model=null_model,
        output_dir=output_dir,
        num_workers=num_workers,
    )
    # Drive the async chunked profiler to completion on a fresh event loop.
    asyncio.run(coroutine)

profile_bam_in_chunks(bed_file, bam_file, gene_range_table, stb, null_model, output_dir, num_workers=4) async

Profile a BAM file in chunks using provided BED files.

Parameters: bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled. bam_file (pathlib.Path): Path to the BAM file. gene_range_table (pathlib.Path): Path to the gene range table. stb (pl.LazyFrame): The scaffold-to-genome mapping table. null_model (pl.LazyFrame): The null model to be used for adjusting for sequence errors. output_dir (pathlib.Path): Directory to save output files. num_workers (int): Number of concurrent workers to use.

Source code in zipstrain/src/zipstrain/profile.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
async def profile_bam_in_chunks(
    bed_file:str,
    bam_file:str,
    gene_range_table:str,
    stb:pl.LazyFrame,
    null_model:pl.LazyFrame,
    output_dir:str,
    num_workers:int=4,
)->None:
    """
    Profile a BAM file in chunks using provided BED files.

    Splits the BED file into num_workers chunks, profiles each chunk
    concurrently via _profile_chunk_task, then merges the per-chunk parquet
    outputs into the final profile, gene-stats, and genome-stats files.

    Parameters:
    bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled.
    bam_file (pathlib.Path): Path to the BAM file.
    gene_range_table (pathlib.Path): Path to the gene range table.
    stb (pl.LazyFrame): The scaffold-to-genome mapping table.
    null_model (pl.LazyFrame): The null model to be used for adjusting for sequence errors.
    output_dir (pathlib.Path): Directory to save output files.
    num_workers (int): Number of concurrent workers to use.
    """

    # Normalize all path-like arguments to pathlib.Path.
    output_dir=pathlib.Path(output_dir)
    bam_file=pathlib.Path(bam_file)
    bed_file=pathlib.Path(bed_file)
    gene_range_table=pathlib.Path(gene_range_table)

    # Scratch space for per-chunk BED and parquet files; removed at the end.
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir/"tmp").mkdir(exist_ok=True)
    bed_lf=pl.scan_csv(bed_file,has_header=False,separator="\t")
    # Headerless TSV: polars auto-names columns column_1..column_4.
    gene_range_lf = pl.scan_csv(
        gene_range_table,
        has_header=False,
        separator="\t",
    ).rename({
        "column_1": "gene",
        "column_2": "scaffold",
        "column_3": "start",
        "column_4": "end",
    })
    # Split the BED regions into one chunk per worker.
    bed_chunks=utils.split_lf_to_chunks(bed_lf,num_workers)
    bed_chunk_files=[]
    # NOTE(review): the loop variable shadows the bed_file parameter (already
    # consumed above); harmless here, but a rename would be clearer.
    for chunk_id, bed_file in enumerate(bed_chunks):
        bed_file.sink_csv(output_dir/"tmp"/f"bed_chunk_{chunk_id}.bed",include_header=False,separator="\t")
        bed_chunk_files.append(output_dir/"tmp"/f"bed_chunk_{chunk_id}.bed")
    # Profile every chunk concurrently.
    tasks = []
    for chunk_id, bed_chunk_file in enumerate(bed_chunk_files):
        tasks.append(_profile_chunk_task(
            bed_file=bed_chunk_file,
            bam_file=bam_file,
            gene_range_table=gene_range_table,
            stb=stb,
            null_model=null_model,
            output_dir=output_dir/"tmp",
            chunk_id=chunk_id
        ))
    await asyncio.gather(*tasks) 
    # Pair each chunk's profile parquet with its read-locations parquet,
    # keeping only chunks whose profile file was actually produced.
    pfs=[(output_dir/"tmp"/f"{bam_file.stem}_{chunk_id}.parquet", output_dir/"tmp"/f"{bam_file.stem}_read_locs_{chunk_id}.parquet" ) for chunk_id in range(len(bed_chunk_files)) if (output_dir/"tmp"/f"{bam_file.stem}_{chunk_id}.parquet").exists()]

    mpile_container: list[pl.LazyFrame] = []
    read_loc_pfs: list[pl.LazyFrame] = []

    for pf, read_loc_pf in pfs:

        if pf.exists():
            mpile_container.append(pl.scan_parquet(pf).lazy())

        if read_loc_pf.exists():
            read_loc_pfs.append(pl.scan_parquet(read_loc_pf).lazy())
    if mpile_container:
        # Merge chunk profiles, order them, and compress repeated labels
        # as categoricals before writing the final profile parquet.
        mpileup_df = pl.concat(mpile_container)
        mpileup_df=mpileup_df.sort(["genome", "chrom", "pos"],descending=[False, False, False])
        with pl.StringCache():
            mpileup_df = mpileup_df.with_columns([
                pl.col("chrom").cast(pl.Categorical),
                pl.col("genome").cast(pl.Categorical),
                pl.col("gene").cast(pl.Categorical),
            ])
        mpileup_df.sink_parquet(output_dir/f"{bam_file.stem}_profile.parquet", compression='zstd', engine='streaming')
        utils.get_gene_stats(
            profile=mpileup_df,
            gene_bed=gene_range_lf,
            stb=stb,
        ).sink_parquet(
            output_dir/f"{bam_file.stem}_gene_stats.parquet",
            compression='zstd',
            engine='streaming',
        )

    if read_loc_pfs:
        read_loc_df = pl.concat(read_loc_pfs).rename(
            {
                "chrom":"scaffold",
            "pos":"loc",
        }
    )

    # Genome-level stats need both the merged profile and the read locations.
    if mpile_container and read_loc_pfs:
        utils.get_genome_stats(
            profile=mpileup_df,
            read_loc_table=read_loc_df,
            stb=stb,
            bed=bed_lf.rename({"column_1":"scaffold","column_2":"start","column_3":"end"}),
        ).sink_parquet(output_dir/f"{bam_file.stem}_genome_stats.parquet", compression='zstd', engine='streaming')


    # NOTE(review): shells out to rm -r; shutil.rmtree would be safer and portable.
    os.system(f"rm -r {output_dir}/tmp")

Compare

zipstrain.compare

This module provides all comparison functions for zipstrain.

PolarsANIExpressions

Any kind of ANI calculation based on two profiles should be implemented as a method of this class. In defining this method, the following rules should be followed:

  • The method returns a Polars expression (pl.Expr).

  • When applied to a row, the method returns a zero if that position is a SNV. Otherwise it should return a number greater than zero.

  • A, T, C, G columns in the first profile are named "A", "T", "C", "G" and in the second profile they are named "A_2", "T_2", "C_2", "G_2".

  • popani: Population ANI based on the shared alleles between two profiles.

  • conani: Consensus ANI based on the consensus alleles between two profiles.
  • cosani_: Generalized cosine similarity ANI where threshold is a float value between 0 and 1. Once the similarity is below the threshold, it is considered a SNV.
Source code in zipstrain/src/zipstrain/compare.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class PolarsANIExpressions:
    """ 
    Registry of ANI expressions evaluated over two merged profiles.

    Any kind of ANI calculation based on two profiles should be implemented as a method of this class.
    In defining this method, the following rules should be followed:

    -   The method returns a Polars expression (pl.Expr).

    -   When applied to a row, the method returns a zero if that position is a SNV. Otherwise it should return a number greater than zero.

    -   A, T, C, G columns in the first profile are named "A", "T", "C", "G" and in the second profile they are named "A_2", "T_2", "C_2", "G_2".

    1. popani: Population ANI based on the shared alleles between two profiles.
    2. conani: Consensus ANI based on the consensus alleles between two profiles.
    3. cosani_<threshold>: Generalized cosine similarity ANI where threshold is a float value between 0 and 1. Once the similarity is below the threshold, it is considered a SNV.
    """
    # Base-count column names for the first and second profile respectively.
    MPILE_1_BASES = ["A", "T", "C", "G"]
    MPILE_2_BASES = ["A_2", "T_2", "C_2", "G_2"]

    # Prefix recognised by the dynamic cosani_<threshold> attribute lookup.
    _COSANI_PREFIX = "cosani_"

    def popani(self) -> "pl.Expr":
        """Population ANI term: positive iff the two profiles share at least one allele at the position."""
        return pl.col("A")*pl.col("A_2") + pl.col("C")*pl.col("C_2") + pl.col("G")*pl.col("G_2") + pl.col("T")*pl.col("T_2")

    def conani(self) -> "pl.Expr":
        """Consensus ANI term: 1 iff some base is simultaneously a most-abundant base in both profiles, else 0."""
        max_base_1 = pl.max_horizontal(*[pl.col(base) for base in self.MPILE_1_BASES])
        max_base_2 = pl.max_horizontal(*[pl.col(base) for base in self.MPILE_2_BASES])
        return pl.when((pl.col("A")==max_base_1) & (pl.col("A_2")==max_base_2) | 
                       (pl.col("T")==max_base_1) & (pl.col("T_2")==max_base_2) | 
                       (pl.col("C")==max_base_1) & (pl.col("C_2")==max_base_2) | 
                       (pl.col("G")==max_base_1) & (pl.col("G_2")==max_base_2)).then(1).otherwise(0)

    def generalized_cos_ani(self, threshold: float = 0.4) -> "pl.Expr":
        """Cosine-similarity ANI term: 1 if the cosine similarity between the two base-count vectors is >= threshold, else 0."""
        dot_product = pl.col("A")*pl.col("A_2") + pl.col("C")*pl.col("C_2") + pl.col("G")*pl.col("G_2") + pl.col("T")*pl.col("T_2")
        magnitude_1 = (pl.col("A")**2 + pl.col("C")**2 + pl.col("G")**2 + pl.col("T")**2)**0.5
        magnitude_2 = (pl.col("A_2")**2 + pl.col("C_2")**2 + pl.col("G_2")**2 + pl.col("T_2")**2)**0.5
        cos_sim = dot_product / (magnitude_1 * magnitude_2)
        return pl.when(cos_sim >= threshold).then(1).otherwise(0)

    def __getattr__(self, name: str):
        # Resolve dynamic "cosani_<threshold>" names (e.g. "cosani_0.4") to a
        # zero-argument callable wrapping generalized_cos_ani(threshold).
        # Unlike __getattribute__ (used previously), __getattr__ is invoked
        # only when normal lookup fails, so ordinary attribute access pays no
        # interception cost. Parsing everything after the prefix (instead of
        # split("_")[1]) also rejects malformed names such as "cosani_1_2"
        # rather than silently truncating the threshold.
        if name.startswith(self._COSANI_PREFIX):
            raw_threshold = name[len(self._COSANI_PREFIX):]
            try:
                threshold = float(raw_threshold)
            except ValueError:
                raise AttributeError(f"Invalid threshold in method name: {name}") from None
            return lambda: self.generalized_cos_ani(threshold)
        # Any other missing attribute behaves like a normal lookup failure.
        raise AttributeError(name)

add_contiguity_info(mpile_contig)

Adds group id information to the lazy frame. Consecutive positions belong to the same group as long as they lie on the same scaffold (and genome, when present) and no position in the run is a SNV (surr == 0); a SNV or a scaffold change starts a new group.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with group id information added.

Source code in zipstrain/src/zipstrain/compare.py
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
def add_contiguity_info(mpile_contig:pl.LazyFrame) -> pl.LazyFrame:
    """Annotate every position with a contiguity group id.

    Consecutive sorted positions share a group id while they stay on the same
    scaffold (and genome, when a genome column exists) and each position has a
    shared allele; a SNV (surr == 0) or a scaffold/genome change opens a new
    group.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.

    Returns:
        pl.LazyFrame: Updated LazyFrame with group id information added.
    """

    def _changed(col: str) -> pl.Expr:
        # True when this row's value differs from the previous row's value;
        # the first row compares against itself via fill_null, so no break.
        return pl.col(col) != pl.col(col).shift(1).fill_null(pl.col(col))

    has_genome = "genome" in mpile_contig.collect_schema().names()
    if has_genome:
        sort_cols = ["genome", "scaffold", "pos"]
        break_expr = _changed("genome") | _changed("scaffold") | (pl.col("surr") == 0)
    else:
        sort_cols = ["scaffold", "pos"]
        break_expr = _changed("scaffold") | (pl.col("surr") == 0)

    # Running count of break points yields a monotonically increasing group id.
    return (
        mpile_contig
        .sort(sort_cols)
        .with_columns(break_expr.cast(pl.Int64).cum_sum().alias("group_id"))
    )

add_genome_info(mpile_contig, scaffold_to_genome)

Adds genome information to the mpileup LazyFrame based on scaffold to genome mapping.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required
scaffold_to_genome LazyFrame

The LazyFrame mapping scaffolds to genomes.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with genome information added.

Source code in zipstrain/src/zipstrain/compare.py
910
911
912
913
914
915
916
917
918
919
920
921
922
923
def add_genome_info(mpile_contig:pl.LazyFrame, scaffold_to_genome:pl.LazyFrame) -> pl.LazyFrame:
    """
    Attach a genome label to each row via the scaffold-to-genome mapping.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.
        scaffold_to_genome (pl.LazyFrame): The LazyFrame mapping scaffolds to genomes.

    Returns:
        pl.LazyFrame: Updated LazyFrame with genome information added; scaffolds
        absent from the mapping are labelled "NA".
    """
    annotated = mpile_contig.join(scaffold_to_genome, on="scaffold", how="left")
    # Unmapped scaffolds come back as null after the left join; normalise them.
    return annotated.fill_null("NA")

calculate_pop_ani(mpile_contig)

Calculates the population ANI (Average Nucleotide Identity) for the given mpileup LazyFrame. NOTE: Remember that this function should be applied to the merged mpileup using get_shared_locs.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with population ANI information added.

Source code in zipstrain/src/zipstrain/compare.py
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
def calculate_pop_ani(mpile_contig:pl.LazyFrame) -> pl.LazyFrame:
    """
    Compute per-genome population ANI over the shared positions.
    NOTE: Remember that this function should be applied to the merged mpileup using get_shared_locs.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.

    Returns:
        pl.LazyFrame: Updated LazyFrame with population ANI information added.
    """
    # Count, per genome, total compared positions and those with a shared allele.
    per_genome_counts = mpile_contig.group_by("genome").agg(
        total_positions=pl.len(),
        share_allele_pos=(pl.col("surr") > 0).sum(),
    )
    # popANI = percentage of compared positions whose alleles are shared.
    return per_genome_counts.with_columns(
        genome_pop_ani=pl.col("share_allele_pos") / pl.col("total_positions") * 100,
    )

compare_genes(mpile_contig_1, mpile_contig_2, min_cov=5, min_gene_compare_len=100, genome_scope='all', gene_scope='all', ani_method='popani', duckdb_memory_limit=None, duckdb_temp_directory=None, duckdb_threads=None, engine='polars')

Compare two profiles at gene level with selectable execution engine.

Source code in zipstrain/src/zipstrain/compare.py
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
def compare_genes(
    mpile_contig_1: Union[str, Path, pl.LazyFrame],
    mpile_contig_2: Union[str, Path, pl.LazyFrame],
    min_cov: int = 5,
    min_gene_compare_len: int = 100,
    genome_scope: str = "all",
    gene_scope: str = "all",
    ani_method: str = "popani",
    duckdb_memory_limit: Optional[str] = None,
    duckdb_temp_directory: Optional[Union[str, Path]] = None,
    duckdb_threads: Optional[int] = None,
    engine: Literal["polars", "duckdb"] = "polars",
) -> pl.LazyFrame:
    """Compare two profiles at gene level with selectable execution engine.

    Dispatches to the pure-Polars implementation or the DuckDB-backed one;
    the duckdb_* resource options only apply to the "duckdb" engine.

    Raises:
        ValueError: If `engine` is neither "polars" nor "duckdb".
    """
    if engine == "duckdb":
        return _compare_genes_mixed(
            mpile_contig_1=mpile_contig_1,
            mpile_contig_2=mpile_contig_2,
            min_cov=min_cov,
            min_gene_compare_len=min_gene_compare_len,
            genome_scope=genome_scope,
            gene_scope=gene_scope,
            ani_method=ani_method,
            duckdb_memory_limit=duckdb_memory_limit,
            duckdb_temp_directory=duckdb_temp_directory,
            duckdb_threads=duckdb_threads,
        )
    if engine == "polars":
        return compare_genes_polars(
            mpile_contig_1=mpile_contig_1,
            mpile_contig_2=mpile_contig_2,
            min_cov=min_cov,
            min_gene_compare_len=min_gene_compare_len,
            genome_scope=genome_scope,
            gene_scope=gene_scope,
            ani_method=ani_method,
        )
    raise ValueError(f"Unsupported engine: {engine}")

compare_genes_polars(mpile_contig_1, mpile_contig_2, min_cov=5, min_gene_compare_len=100, genome_scope='all', gene_scope='all', ani_method='popani')

Compare two profiles fully in Polars and return gene-level ANI statistics.

Source code in zipstrain/src/zipstrain/compare.py
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
def compare_genes_polars(
    mpile_contig_1: Union[str, Path, pl.LazyFrame],
    mpile_contig_2: Union[str, Path, pl.LazyFrame],
    min_cov: int = 5,
    min_gene_compare_len: int = 100,
    genome_scope: str = "all",
    gene_scope: str = "all",
    ani_method: str = "popani",
) -> pl.LazyFrame:
    """Compare two profiles fully in Polars and return gene-level ANI statistics.

    Shared loci are aggregated per (genome, gene); genes with fewer shared
    positions than `min_gene_compare_len` are dropped, and `ani` is the
    percentage of positions with a shared allele (surr > 0).
    """
    shared_loci = _shared_loci_polars(
        mpile1=mpile_contig_1,
        mpile2=mpile_contig_2,
        min_cov=min_cov,
        genome_scope=genome_scope,
        gene_scope=gene_scope,
        ani_method=ani_method,
    )
    # Per-gene position counts: all compared positions vs shared-allele positions.
    per_gene = (
        shared_loci
        .select(["genome", "gene", "surr"])
        .group_by(["genome", "gene"])
        .agg(
            total_positions=pl.len(),
            share_allele_pos=(pl.col("surr") > 0).sum(),
        )
    )
    # Keep only genes long enough to compare, then derive the ANI percentage.
    return (
        per_gene
        .filter(pl.col("total_positions") >= min_gene_compare_len)
        .with_columns(
            ani=pl.col("share_allele_pos") / pl.col("total_positions") * 100,
        )
    )

compare_genomes(mpile_contig_1, mpile_contig_2, min_cov=5, min_gene_compare_len=100, genome_scope='all', ani_method='popani', duckdb_memory_limit=None, duckdb_temp_directory=None, duckdb_threads=None, engine='polars', stb_file=None, calculate=None)

Compare two profiles with selectable execution engine.

Source code in zipstrain/src/zipstrain/compare.py
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
def compare_genomes(
    mpile_contig_1: Union[str, Path, pl.LazyFrame],
    mpile_contig_2: Union[str, Path, pl.LazyFrame],
    min_cov: int = 5,
    min_gene_compare_len: int = 100,
    genome_scope: str = "all",
    ani_method: str = "popani",
    duckdb_memory_limit: Optional[str] = None,
    duckdb_temp_directory: Optional[Union[str, Path]] = None,
    duckdb_threads: Optional[int] = None,
    engine: Literal["polars", "duckdb"] = "polars",
    stb_file: Optional[Union[str, Path]] = None,
    calculate: Optional[Union[str, Iterable[str]]] = None,
) -> pl.LazyFrame:
    """Compare two profiles with selectable execution engine.

    Normalises the requested metric set once, then dispatches to either the
    pure-Polars or the DuckDB implementation of the genome-level comparison.

    Raises:
        ValueError: If `engine` is neither "polars" nor "duckdb".
    """
    calculations = parse_genome_calculations(calculate)
    if engine == "duckdb":
        return duckdb_compare_genomes(
            mpile1=mpile_contig_1,
            mpile2=mpile_contig_2,
            min_cov=min_cov,
            min_gene_compare_len=min_gene_compare_len,
            genome_scope=genome_scope,
            ani_method=ani_method,
            calculate=calculations,
            stb_file=stb_file,
            memory_limit=duckdb_memory_limit,
            temp_directory=duckdb_temp_directory,
            threads=duckdb_threads,
        )
    if engine == "polars":
        return compare_genomes_polars(
            mpile_contig_1=mpile_contig_1,
            mpile_contig_2=mpile_contig_2,
            min_cov=min_cov,
            min_gene_compare_len=min_gene_compare_len,
            genome_scope=genome_scope,
            ani_method=ani_method,
            stb_file=stb_file,
            calculate=calculations,
        )
    raise ValueError(f"Unsupported engine: {engine}")

compare_genomes_polars(mpile_contig_1, mpile_contig_2, min_cov=5, min_gene_compare_len=100, genome_scope='all', ani_method='popani', stb_file=None, calculate=None)

Compare two profiles fully in Polars and return genome-level statistics.

Source code in zipstrain/src/zipstrain/compare.py
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
def compare_genomes_polars(
    mpile_contig_1: Union[str, Path, pl.LazyFrame],
    mpile_contig_2: Union[str, Path, pl.LazyFrame],
    min_cov: int = 5,
    min_gene_compare_len: int = 100,
    genome_scope: str = "all",
    ani_method: str = "popani",
    stb_file: Optional[Union[str, Path]] = None,
    calculate: Optional[Union[str, Iterable[str]]] = None,
) -> pl.LazyFrame:
    """Compare two profiles fully in Polars and return genome-level statistics.

    Args:
        mpile_contig_1: First profile (path or LazyFrame).
        mpile_contig_2: Second profile (path or LazyFrame).
        min_cov: Minimum coverage for a position to be compared.
        min_gene_compare_len: Minimum shared positions for a gene to count
            towards the identical-genes metrics.
        genome_scope: Restrict the comparison to one genome, or "all".
        ani_method: ANI expression name (e.g. "popani", "conani", "cosani_0.4").
        stb_file: Optional scaffold-to-genome TSV; when given, results are
            left-joined onto the full genome list so genomes with no shared
            loci still appear (with zeroed metrics).
        calculate: Metric selection parsed by parse_genome_calculations
            ("ani", "ibs", "identical_genes").

    Returns:
        pl.LazyFrame: One row per genome with the columns chosen by
        genome_metric_output_columns(calculations).
    """
    calculations = parse_genome_calculations(calculate)
    shared = _shared_loci_polars(
        mpile1=mpile_contig_1,
        mpile2=mpile_contig_2,
        min_cov=min_cov,
        genome_scope=genome_scope,
        ani_method=ani_method,
    )
    # Build one lazy frame per requested metric family, each keyed by "genome".
    genome_comp_parts: list[pl.LazyFrame] = []
    if "ani" in calculations:
        genome_comp_parts.append(calculate_pop_ani(shared))
    if "ibs" in calculations:
        genome_comp_parts.append(get_longest_consecutive_blocks(add_contiguity_info(shared)))
    if "identical_genes" in calculations:
        genome_comp_parts.append(get_gene_ani(shared, min_gene_compare_len))

    # Fold the metric frames together with left joins; if no metrics were
    # requested, fall back to the bare set of genomes seen in the shared loci.
    if genome_comp_parts:
        genome_comp = genome_comp_parts[0]
        for part in genome_comp_parts[1:]:
            genome_comp = genome_comp.join(part, on="genome", how="left")
    else:
        genome_comp = shared.select("genome").unique()

    if stb_file is not None:
        # Column 2 of the stb file holds the genome name; take the distinct set
        # so every known genome appears in the output.
        genomes_utf8 = (
            pl.scan_csv(stb_file, separator="\t", has_header=False)
            .select(pl.col("column_2").cast(pl.Utf8).alias("genome"))
            .unique()
        )
        genome_dtype = genome_comp.collect_schema().get("genome")
        if genome_dtype == pl.Categorical:
            # Use a fixed category domain to safely align categorical join keys.
            # NOTE(review): the Enum domain is built from the stb file only;
            # this assumes every genome present in genome_comp also appears in
            # the stb file — confirm, otherwise the cast below may fail.
            categories = sorted(
                set(
                    genomes_utf8.select("genome")
                    .collect(engine="streaming")["genome"]
                    .to_list()
                )
            )
            enum_dtype = pl.Enum(categories)
            genomes = genomes_utf8.with_columns(pl.col("genome").cast(enum_dtype))
            genome_comp = genome_comp.with_columns(pl.col("genome").cast(enum_dtype))
        else:
            genomes = genomes_utf8.with_columns(pl.col("genome").cast(genome_dtype))
        if genome_scope != "all":
            genomes = genomes.filter(pl.col("genome") == genome_scope)
        genome_comp = genomes.join(genome_comp, on="genome", how="left")

    # Genomes introduced by the stb join have null metrics; normalise each
    # selected metric to a typed zero so the output schema is stable.
    casts: list[pl.Expr] = []
    if "ani" in calculations:
        casts.extend(
            [
                pl.col("total_positions").fill_null(0).cast(pl.Int64),
                pl.col("share_allele_pos").fill_null(0).cast(pl.Int64),
                pl.col("genome_pop_ani").fill_null(0.0).cast(pl.Float64),
            ]
        )
    if "ibs" in calculations:
        casts.append(pl.col("max_consecutive_length").fill_null(0).cast(pl.Int64))
    if "identical_genes" in calculations:
        casts.extend(
            [
                pl.col("shared_genes_count").fill_null(0).cast(pl.Int64),
                pl.col("identical_gene_count").fill_null(0).cast(pl.Int64),
                pl.col("perc_id_genes").fill_null(0.0).cast(pl.Float64),
            ]
        )
    if casts:
        genome_comp = genome_comp.with_columns(casts)

    return genome_comp.select(genome_metric_output_columns(calculations))

duckdb_compare_genes_to_parquet(mpile1, mpile2, output_file, sample_1_name, sample_2_name, min_cov=5, min_gene_compare_len=100, genome_scope='all', gene_scope='all', ani_method='popani', memory_limit=None, temp_directory=None, threads=None)

Run gene comparison in DuckDB and write final output directly to parquet.

Source code in zipstrain/src/zipstrain/compare.py
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
def duckdb_compare_genes_to_parquet(
    mpile1: Union[str, Path, pl.LazyFrame],
    mpile2: Union[str, Path, pl.LazyFrame],
    output_file: Union[str, Path],
    sample_1_name: str,
    sample_2_name: str,
    min_cov: int = 5,
    min_gene_compare_len: int = 100,
    genome_scope: str = "all",
    gene_scope: str = "all",
    ani_method: str = "popani",
    memory_limit: Optional[str] = None,
    temp_directory: Optional[Union[str, Path]] = None,
    threads: Optional[int] = None,
) -> None:
    """Run gene comparison in DuckDB and write final output directly to parquet.

    Args:
        mpile1: First profile (parquet path or Polars LazyFrame).
        mpile2: Second profile (parquet path or Polars LazyFrame).
        output_file: Destination parquet file for the per-gene results.
        sample_1_name: Label written to the `sample_1` output column.
        sample_2_name: Label written to the `sample_2` output column.
        min_cov: Minimum coverage for a position to be compared.
        min_gene_compare_len: Minimum shared positions for a gene to be reported.
        genome_scope: Restrict comparison to one genome, or "all".
        gene_scope: Restrict comparison to one gene, or "all".
        ani_method: ANI expression name (e.g. "popani", "conani", "cosani_0.4").
        memory_limit: Optional DuckDB memory limit setting.
        temp_directory: Optional DuckDB spill directory.
        threads: Optional DuckDB thread count.

    Returns:
        None. Results are streamed to `output_file` inside DuckDB.
    """
    con = duckdb.connect()
    try:
        # Apply resource limits before any work is queued on the connection.
        _duckdb_configure_connection(
            con,
            memory_limit=memory_limit,
            temp_directory=temp_directory,
            threads=threads,
        )
        p1_source = _duckdb_from_source(con, mpile1, "mpile1_source")
        p2_source = _duckdb_from_source(con, mpile2, "mpile2_source")
        shared_query = _duckdb_shared_query(
            p1_source=p1_source,
            p2_source=p2_source,
            min_cov=min_cov,
            genome_scope=genome_scope,
            gene_scope=gene_scope,
            ani_method=ani_method,
        )
        # Sample names pass through _duckdb_quote_sql_string before being
        # inlined as SQL string literals below.
        sample_1_sql = _duckdb_quote_sql_string(sample_1_name)
        sample_2_sql = _duckdb_quote_sql_string(sample_2_name)
        # Aggregate shared loci per (genome, gene): `ani` is the percentage of
        # positions with a shared allele (surr > 0); genes with fewer than
        # min_gene_compare_len shared positions are dropped by HAVING.
        query = f"""
        WITH shared AS (
          {shared_query}
        )
        SELECT
          genome,
          gene,
          COUNT(*)::BIGINT AS total_positions,
          SUM(CASE WHEN surr > 0 THEN 1 ELSE 0 END)::BIGINT AS share_allele_pos,
          SUM(CASE WHEN surr > 0 THEN 1 ELSE 0 END) * 100.0 / NULLIF(COUNT(*), 0) AS ani,
          '{sample_1_sql}' AS sample_1,
          '{sample_2_sql}' AS sample_2
        FROM shared
        GROUP BY genome, gene
        HAVING COUNT(*) >= {min_gene_compare_len}
        ORDER BY genome, gene
        """
        # COPY the query result straight to parquet inside DuckDB — no Arrow
        # materialization in Python memory.
        _duckdb_copy_query_to_parquet(con, query, output_file)
    finally:
        con.close()

duckdb_compare_genomes(mpile1, mpile2, min_cov=5, min_gene_compare_len=100, genome_scope='all', ani_method='popani', calculate=None, stb_file=None, memory_limit=None, temp_directory=None, threads=None)

Run genome comparison in DuckDB and return selected metrics as a LazyFrame.

Source code in zipstrain/src/zipstrain/compare.py
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
def duckdb_compare_genomes(
    mpile1: Union[str, Path, pl.LazyFrame],
    mpile2: Union[str, Path, pl.LazyFrame],
    min_cov: int = 5,
    min_gene_compare_len: int = 100,
    genome_scope: str = "all",
    ani_method: str = "popani",
    calculate: Optional[Union[str, Iterable[str]]] = None,
    stb_file: Optional[Union[str, Path]] = None,
    memory_limit: Optional[str] = None,
    temp_directory: Optional[Union[str, Path]] = None,
    threads: Optional[int] = None,
) -> pl.LazyFrame:
    """Run genome comparison in DuckDB and return selected metrics as a LazyFrame.

    Builds the shared-loci query, wraps it in the genome-level comparison
    query, executes it on a dedicated connection, and hands the Arrow result
    back to Polars. The connection is always closed, even on error.
    """
    connection = duckdb.connect()
    try:
        _duckdb_configure_connection(
            connection,
            memory_limit=memory_limit,
            temp_directory=temp_directory,
            threads=threads,
        )
        source_1 = _duckdb_from_source(connection, mpile1, "mpile1_source")
        source_2 = _duckdb_from_source(connection, mpile2, "mpile2_source")
        # Compose the final SQL: shared loci feed the genome comparison query.
        final_query = _duckdb_genome_compare_query(
            shared_query=_duckdb_shared_query(
                p1_source=source_1,
                p2_source=source_2,
                min_cov=min_cov,
                genome_scope=genome_scope,
                ani_method=ani_method,
            ),
            min_gene_compare_len=min_gene_compare_len,
            genome_scope=genome_scope,
            calculate=calculate,
            stb_file=stb_file,
        )
        arrow_result = connection.execute(final_query).fetch_arrow_table()
        return pl.from_arrow(arrow_result).lazy()
    finally:
        connection.close()

duckdb_compare_genomes_to_parquet(mpile1, mpile2, output_file, stb_file, sample_1_name, sample_2_name, min_cov=5, min_gene_compare_len=100, genome_scope='all', ani_method='popani', calculate=None, memory_limit=None, temp_directory=None, threads=None)

Run genome comparison in DuckDB and write final output directly to parquet.

This path avoids materializing large intermediate tables in Python memory.

Source code in zipstrain/src/zipstrain/compare.py
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
def duckdb_compare_genomes_to_parquet(
    mpile1: Union[str, Path, pl.LazyFrame],
    mpile2: Union[str, Path, pl.LazyFrame],
    output_file: Union[str, Path],
    stb_file: Union[str, Path],
    sample_1_name: str,
    sample_2_name: str,
    min_cov: int = 5,
    min_gene_compare_len: int = 100,
    genome_scope: str = "all",
    ani_method: str = "popani",
    calculate: Optional[Union[str, Iterable[str]]] = None,
    memory_limit: Optional[str] = None,
    temp_directory: Optional[Union[str, Path]] = None,
    threads: Optional[int] = None,
) -> None:
    """Run genome comparison in DuckDB and write final output directly to parquet.

    This path avoids materializing large intermediate tables in Python memory.

    Args:
        mpile1: First profile (parquet path or Polars LazyFrame).
        mpile2: Second profile (parquet path or Polars LazyFrame).
        output_file: Destination parquet file for the per-genome results.
        stb_file: Scaffold-to-genome mapping file forwarded to the query builder.
        sample_1_name: Label embedded in the output for the first sample.
        sample_2_name: Label embedded in the output for the second sample.
        min_cov: Minimum coverage for a position to be compared.
        min_gene_compare_len: Minimum shared positions for a gene to count.
        genome_scope: Restrict comparison to one genome, or "all".
        ani_method: ANI expression name (e.g. "popani", "conani", "cosani_0.4").
        calculate: Metric selection forwarded to the query builder.
        memory_limit: Optional DuckDB memory limit setting.
        temp_directory: Optional DuckDB spill directory.
        threads: Optional DuckDB thread count.

    Returns:
        None. Results are streamed to `output_file` inside DuckDB.
    """
    con = duckdb.connect()
    try:
        # Apply resource limits before any work is queued on the connection.
        _duckdb_configure_connection(
            con,
            memory_limit=memory_limit,
            temp_directory=temp_directory,
            threads=threads,
        )
        p1_source = _duckdb_from_source(con, mpile1, "mpile1_source")
        p2_source = _duckdb_from_source(con, mpile2, "mpile2_source")
        shared_query = _duckdb_shared_query(
            p1_source=p1_source,
            p2_source=p2_source,
            min_cov=min_cov,
            genome_scope=genome_scope,
            ani_method=ani_method,
        )
        # Unlike duckdb_compare_genomes, the sample names are passed through so
        # the query builder can record them in the parquet output.
        query = _duckdb_genome_compare_query(
            shared_query=shared_query,
            min_gene_compare_len=min_gene_compare_len,
            genome_scope=genome_scope,
            calculate=calculate,
            stb_file=stb_file,
            sample_1_name=sample_1_name,
            sample_2_name=sample_2_name,
        )
        # COPY the result straight to parquet inside DuckDB.
        _duckdb_copy_query_to_parquet(con, query, output_file)
    finally:
        con.close()

duckdb_filter_join(mpile1, mpile2, min_cov, genome_scope='all', ani_method='popani', gene_scope='all', memory_limit=None, temp_directory=None, threads=None)

Filter two profile sources and inner-join shared loci in DuckDB.

Inputs can be parquet paths or Polars LazyFrames. Coverage, genome scope, and optional gene scope are pushed down into DuckDB. The returned lazy frame contains: surr, scaffold, pos, gene, and genome.

Source code in zipstrain/src/zipstrain/compare.py
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
def duckdb_filter_join(
    mpile1: Union[str, Path, pl.LazyFrame],
    mpile2: Union[str, Path, pl.LazyFrame],
    min_cov: int,
    genome_scope: str = "all",   # or specific genome
    ani_method: str = "popani",  # "popani" / "conani" / "cosani_0.4"
    gene_scope: str = "all",
    memory_limit: Optional[str] = None,
    temp_directory: Optional[Union[str, Path]] = None,
    threads: Optional[int] = None,
) -> pl.LazyFrame:
    """Filter two profile sources and inner-join shared loci in DuckDB.

    Inputs can be parquet paths or Polars LazyFrames. Coverage, genome scope,
    and optional gene scope are pushed down into DuckDB. The returned lazy frame
    contains: `surr`, `scaffold`, `pos`, `gene`, and `genome`.
    """
    connection = duckdb.connect()
    try:
        _duckdb_configure_connection(
            connection,
            memory_limit=memory_limit,
            temp_directory=temp_directory,
            threads=threads,
        )
        source_1 = _duckdb_from_source(connection, mpile1, "mpile1_source")
        source_2 = _duckdb_from_source(connection, mpile2, "mpile2_source")
        # All filters are expressed inside the SQL so DuckDB can push them down.
        shared_sql = _duckdb_shared_query(
            p1_source=source_1,
            p2_source=source_2,
            min_cov=min_cov,
            genome_scope=genome_scope,
            ani_method=ani_method,
            gene_scope=gene_scope,
        )
        arrow_table = connection.execute(shared_sql).fetch_arrow_table()
        return pl.from_arrow(arrow_table).lazy()
    finally:
        connection.close()

duckdb_filter_join_with_genome_stats(mpile1, mpile2, min_cov, genome_scope='all', ani_method='popani', gene_scope='all', memory_limit=None, temp_directory=None, threads=None)

Build shared loci and genome-level popANI/contiguity stats in DuckDB.

Returns:

Type Description
tuple[LazyFrame, LazyFrame]

tuple[pl.LazyFrame, pl.LazyFrame]: - shared loci projected for gene aggregation (surr, gene, genome) - genome stats (genome, total_positions, share_allele_pos, genome_pop_ani, max_consecutive_length)

Source code in zipstrain/src/zipstrain/compare.py
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
def duckdb_filter_join_with_genome_stats(
    mpile1: Union[str, Path, pl.LazyFrame],
    mpile2: Union[str, Path, pl.LazyFrame],
    min_cov: int,
    genome_scope: str = "all",
    ani_method: str = "popani",
    gene_scope: str = "all",
    memory_limit: Optional[str] = None,
    temp_directory: Optional[Union[str, Path]] = None,
    threads: Optional[int] = None,
) -> tuple[pl.LazyFrame, pl.LazyFrame]:
    """Build shared loci and genome-level popANI/contiguity stats in DuckDB.

    Args:
        mpile1: First profile (parquet path or Polars LazyFrame).
        mpile2: Second profile (parquet path or Polars LazyFrame).
        min_cov: Minimum coverage for a position to be compared.
        genome_scope: Restrict comparison to one genome, or "all".
        ani_method: ANI expression name (e.g. "popani", "conani", "cosani_0.4").
        gene_scope: Restrict comparison to one gene, or "all".
        memory_limit: Optional DuckDB memory limit setting.
        temp_directory: Optional DuckDB spill directory.
        threads: Optional DuckDB thread count.

    Returns:
        tuple[pl.LazyFrame, pl.LazyFrame]:
            - shared loci projected for gene aggregation (`surr`, `gene`, `genome`)
            - genome stats (`genome`, `total_positions`, `share_allele_pos`,
              `genome_pop_ani`, `max_consecutive_length`)
    """
    con = duckdb.connect()
    try:
        # Apply resource limits before any work is queued on the connection.
        _duckdb_configure_connection(
            con,
            memory_limit=memory_limit,
            temp_directory=temp_directory,
            threads=threads,
        )
        p1_source = _duckdb_from_source(con, mpile1, "mpile1_source")
        p2_source = _duckdb_from_source(con, mpile2, "mpile2_source")
        shared_query = _duckdb_shared_query(
            p1_source=p1_source,
            p2_source=p2_source,
            min_cov=min_cov,
            genome_scope=genome_scope,
            ani_method=ani_method,
            gene_scope=gene_scope,
        )
        # Materialize the shared loci once as a temp table so both result sets
        # below read the same computed data instead of re-running the join.
        con.execute(f"CREATE TEMP TABLE shared AS {shared_query}")

        # Only transfer columns needed for downstream Polars gene aggregation.
        shared_table = con.execute(
            "SELECT surr, gene, genome FROM shared"
        ).fetch_arrow_table()

        # Genome stats (popANI + longest contiguous block) computed via CTE
        # helpers over the same temp table.
        genome_stats_query = _duckdb_build_query_with_ctes(
            _duckdb_contig_pop_max_ctes(shared_source="shared"),
            _duckdb_genome_stats_select(),
        )
        genome_stats_table = con.execute(genome_stats_query).fetch_arrow_table()
        return pl.from_arrow(shared_table).lazy(), pl.from_arrow(genome_stats_table).lazy()
    finally:
        con.close()

duckdb_prefilter_by_scope(mpile1, mpile2, genome_scope='all', gene_scope='all', memory_limit=None, temp_directory=None, threads=None)

Scope-filter both profiles in DuckDB and return in-memory LazyFrames.

Source code in zipstrain/src/zipstrain/compare.py
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
def duckdb_prefilter_by_scope(
    mpile1: Union[str, Path, pl.LazyFrame],
    mpile2: Union[str, Path, pl.LazyFrame],
    genome_scope: str = "all",
    gene_scope: str = "all",
    memory_limit: Optional[str] = None,
    temp_directory: Optional[Union[str, Path]] = None,
    threads: Optional[int] = None,
) -> tuple[pl.LazyFrame, pl.LazyFrame]:
    """Scope-filter both profiles in DuckDB and return in-memory LazyFrames.

    Args:
        mpile1: First profile (parquet path or Polars LazyFrame).
        mpile2: Second profile (parquet path or Polars LazyFrame).
        genome_scope: Keep only rows of this genome, or "all" for no filter.
        gene_scope: Keep only rows of this gene, or "all" for no filter.
        memory_limit: Optional DuckDB memory limit setting.
        temp_directory: Optional DuckDB spill directory.
        threads: Optional DuckDB thread count.

    Returns:
        tuple[pl.LazyFrame, pl.LazyFrame]: The two profiles, restricted to the
        requested scope and reduced to (chrom, pos, gene, genome, A, T, C, G).
    """
    # Fast path: no scoping requested — avoid the DuckDB round-trip entirely.
    scope_requested = genome_scope != "all" or gene_scope != "all"
    if not scope_requested:
        return _as_lazy_profile(mpile1), _as_lazy_profile(mpile2)

    con = duckdb.connect()
    try:
        _duckdb_configure_connection(
            con,
            memory_limit=memory_limit,
            temp_directory=temp_directory,
            threads=threads,
        )
        p1_source = _duckdb_from_source(con, mpile1, "mpile1_prefilter")
        p2_source = _duckdb_from_source(con, mpile2, "mpile2_prefilter")
        # Scope values are escaped by the quote helper before being inlined;
        # the 'all' comparison happens inside SQL so one WHERE clause serves
        # every combination of genome/gene scoping.
        genome_scope_sql = _duckdb_quote_sql_string(genome_scope)
        gene_scope_sql = _duckdb_quote_sql_string(gene_scope)
        where_sql = (
            f"('{genome_scope_sql}' = 'all' OR genome = '{genome_scope_sql}') "
            f"AND ('{gene_scope_sql}' = 'all' OR gene = '{gene_scope_sql}')"
        )
        p1_table = con.execute(
            f"""
            SELECT chrom, pos, gene, genome, A, T, C, G
            FROM {p1_source}
            WHERE {where_sql}
            """
        ).fetch_arrow_table()
        p2_table = con.execute(
            f"""
            SELECT chrom, pos, gene, genome, A, T, C, G
            FROM {p2_source}
            WHERE {where_sql}
            """
        ).fetch_arrow_table()
        return pl.from_arrow(p1_table).lazy(), pl.from_arrow(p2_table).lazy()
    finally:
        con.close()

genome_metric_output_columns(calculate=None)

Return ordered output columns for selected genome-level calculations.

Source code in zipstrain/src/zipstrain/compare.py
65
66
67
68
69
70
71
72
73
74
75
def genome_metric_output_columns(calculate: Optional[Union[str, Iterable[str]]] = None) -> list[str]:
    """Return ordered output columns for selected genome-level calculations.

    The "genome" key column always comes first, followed by the columns of
    each requested metric family in a fixed order (ani, ibs, identical_genes).
    """
    selected = parse_genome_calculations(calculate)
    # Metric family -> output columns, in their canonical order.
    metric_columns = {
        "ani": ["total_positions", "share_allele_pos", "genome_pop_ani"],
        "ibs": ["max_consecutive_length"],
        "identical_genes": ["shared_genes_count", "identical_gene_count", "perc_id_genes"],
    }
    ordered = ["genome"]
    for family in ("ani", "ibs", "identical_genes"):
        if family in selected:
            ordered.extend(metric_columns[family])
    return ordered

get_gene_ani(mpile_contig, min_gene_compare_len)

Calculates gene ANI (Average Nucleotide Identity) for each gene in each genome.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required
min_gene_compare_len int

Minimum length of the gene to consider for comparison.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with gene ANI information added.

Source code in zipstrain/src/zipstrain/compare.py
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
def get_gene_ani(mpile_contig:pl.LazyFrame, min_gene_compare_len:int) -> pl.LazyFrame:
    """
    Calculates gene ANI (Average Nucleotide Identity) for each gene in each genome.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.
        min_gene_compare_len (int): Minimum length of the gene to consider for comparison.

    Returns:
        pl.LazyFrame: Per-genome gene identity summary (shared genes, identical genes, percent identical).
    """
    # Per-gene tallies: positions compared and positions sharing an allele.
    per_gene = (
        mpile_contig
        .select(["genome", "gene", "surr"])
        .filter(pl.col("gene") != "NA")
        .group_by(["genome", "gene"])
        .agg(
            total_positions=pl.len(),
            share_allele_pos=(pl.col("surr") > 0).sum(),
        )
        .filter(pl.col("total_positions") >= min_gene_compare_len)
    )
    # A gene is "identical" when every compared position shares an allele.
    flagged = per_gene.with_columns(
        identical=(pl.col("share_allele_pos") == pl.col("total_positions")),
    )
    per_genome = flagged.group_by("genome").agg(
        shared_genes_count=pl.len(),
        identical_gene_count=pl.col("identical").sum(),
    )
    return per_genome.with_columns(
        perc_id_genes=pl.col("identical_gene_count") / pl.col("shared_genes_count") * 100
    )

get_longest_consecutive_blocks(mpile_contig)

Calculates the longest consecutive blocks for each genome in the mpileup LazyFrame for any genome.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with longest consecutive blocks information added.

Source code in zipstrain/src/zipstrain/compare.py
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
def get_longest_consecutive_blocks(mpile_contig:pl.LazyFrame) -> pl.LazyFrame:
    """
    Calculates the longest consecutive blocks for each genome in the mpileup LazyFrame for any genome.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.

    Returns:
        pl.LazyFrame: One row per genome with its maximum consecutive block length.
    """
    # Length of every (genome, scaffold, group_id) run, then the per-genome maximum.
    per_block = mpile_contig.group_by(["genome", "scaffold", "group_id"]).agg(
        pl.len().alias("length")
    )
    return per_block.group_by("genome").agg(
        pl.col("length").max().alias("max_consecutive_length")
    )

get_shared_locs(mpile_contig_1, mpile_contig_2, ani_method='popani')

Returns a lazyframe with ATCG information for shared scaffolds and positions between two mpileup files.

Parameters:

Name Type Description Default
mpile_contig_1 LazyFrame

The first mpileup LazyFrame.

required
mpile_contig_2 LazyFrame

The second mpileup LazyFrame.

required
ani_method str

The ANI calculation method to use. Default is "popani".

'popani'

Returns:

Type Description
LazyFrame

pl.LazyFrame: Merged LazyFrame containing shared scaffolds and positions with ATCG information.

Source code in zipstrain/src/zipstrain/compare.py
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
def get_shared_locs(mpile_contig_1:pl.LazyFrame, mpile_contig_2:pl.LazyFrame,ani_method:str="popani") -> pl.LazyFrame:
    """
    Returns a lazyframe with ATCG information for shared scaffolds and positions between two mpileup files.

    Args:
        mpile_contig_1 (pl.LazyFrame): The first mpileup LazyFrame.
        mpile_contig_2 (pl.LazyFrame): The second mpileup LazyFrame.
        ani_method (str): The ANI calculation method to use. Default is "popani".

    Returns:
        pl.LazyFrame: Merged LazyFrame containing shared scaffolds and positions with ATCG information.
    """
    # Look up the requested ANI expression by name on the expression factory.
    surr_expr = getattr(PolarsANIExpressions(), ani_method)()

    shared = mpile_contig_1.join(
        mpile_contig_2,
        on=["chrom", "pos"],
        how="inner",
        suffix="_2",  # To distinguish lf2 columns
    )
    return shared.with_columns(
        surr_expr.alias("surr")
    ).select(
        pl.col("surr"),
        scaffold=pl.col("chrom"),
        pos=pl.col("pos"),
        gene=pl.col("gene"),
    )

get_unique_scaffolds(mpile_contig, batch_size=10000)

Retrieves unique scaffolds from the mpileup LazyFrame.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required
batch_size int

The number of rows to process in each batch. Default is 10000.

10000

Returns: set: A set of unique scaffold names.

Source code in zipstrain/src/zipstrain/compare.py
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
def get_unique_scaffolds(mpile_contig:pl.LazyFrame,batch_size:int=10000) -> set:
    """
    Retrieves unique scaffolds from the mpileup LazyFrame.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.
        batch_size (int): Retained for backward compatibility; no longer used.
    Returns:
        set: A set of unique scaffold names.
    """
    # Deduplicate inside the query engine. The previous implementation called
    # slice(start, batch_size).collect() in a loop, which re-executed the full
    # lazy plan for every batch (quadratic in the number of batches).
    unique_chroms = mpile_contig.select(pl.col("chrom").unique()).collect()
    return set(unique_chroms["chrom"].to_list())

parse_genome_calculations(calculate=None)

Parse and normalize genome metric selection tokens.

Accepted input formats
  • None -> default ("ani", "ibs", "identical_genes")
  • "ani+ibs+identical_genes"
  • "all"
  • iterable of token strings
Source code in zipstrain/src/zipstrain/compare.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def parse_genome_calculations(calculate: Optional[Union[str, Iterable[str]]] = None) -> tuple[str, ...]:
    """Parse and normalize genome metric selection tokens.

    Accepted input formats:
      - None -> default ("ani", "ibs", "identical_genes")
      - "ani+ibs+identical_genes"
      - "all"
      - iterable of token strings
    """
    if calculate is None:
        return GENOME_COMPARISON_DEFAULT_CALCULATIONS

    # Split on '+' and ',' for strings; stringify and strip iterable items.
    if isinstance(calculate, str):
        tokens = [
            piece.strip().lower()
            for plus_part in calculate.split("+")
            for piece in plus_part.split(",")
            if piece.strip()
        ]
    else:
        tokens = [str(item).strip().lower() for item in calculate if str(item).strip()]

    if not tokens:
        return tuple()
    if "all" in tokens:
        return GENOME_COMPARISON_CALCULATIONS

    selected: set[str] = set()
    for token in tokens:
        canonical = GENOME_COMPARISON_CALCULATION_ALIASES.get(token)
        if canonical is None:
            supported = "all," + ",".join(GENOME_COMPARISON_CALCULATIONS)
            raise ValueError(f"Unsupported calculation '{token}'. Supported values: {supported}")
        selected.add(canonical)
    # Preserve canonical ordering regardless of input order.
    return tuple(metric for metric in GENOME_COMPARISON_CALCULATIONS if metric in selected)

Utils

zipstrain.utils

This module provides utility functions for profiling and compare operations.

CallPresence

This class provides methods to call genome presence/absence from per-genome summary statistics (coverage, breadth, ber, fug).

Source code in zipstrain/src/zipstrain/utils.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class CallPresence:
    """Call genome presence/absence from per-genome summary statistics.

    Each public method consumes a LazyFrame with columns ``genome``,
    ``coverage``, ``breadth``, ``ber`` and ``fug`` and returns those columns
    plus a boolean ``is_present`` column.
    """

    def validate_input(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Raise ValueError if `lf` lacks any required column; return `lf` unchanged."""
        required_columns = {"genome", "coverage", "breadth", "ber", "fug"}
        # Schema.names() is a method: without the parentheses the bound method
        # itself (not the column names) would be passed to set(), raising TypeError.
        missing_columns = required_columns - set(lf.collect_schema().names())
        if missing_columns:
            raise ValueError(f"Input LazyFrame is missing required columns: {missing_columns}")
        return lf

    def metapresence(self,
                       lf: pl.LazyFrame,
                       ber: float = 0.5,
                       fug: float = 0.5,
                       min_cov_use_fug: float = 0.1
                       ) -> pl.LazyFrame:
        """
        Call presence/absence of genomes based on breadth, coverage, ber, and fug.

        When coverage exceeds `min_cov_use_fug`, only the ber criterion is used;
        otherwise the fug criterion must hold as well.

        Parameters:
        lf (pl.LazyFrame): Input LazyFrame with genome statistics.
        ber (float): Breadth error rate threshold.
        fug (float): Fragmented unassembled genome threshold.
        min_cov_use_fug (float): Coverage above which the ber-only call is used.
        Returns:
        pl.LazyFrame: LazyFrame with presence/absence calls.
        """
        lf = lf.with_columns(
            pl.when(pl.col("coverage") > min_cov_use_fug)
            .then(
                pl.col("ber") > ber
            ).otherwise(
                # 0.632 normalizes fug — NOTE(review): confirm the derivation of
                # this constant against the metapresence method reference.
                (pl.col("fug") / 0.632 < fug) &
                (pl.col("ber") > ber)
            ).fill_null(False).alias("is_present"))
        return lf.select(
            pl.col("genome"),
            pl.col("coverage"),
            pl.col("breadth"),
            pl.col("ber"),
            pl.col("fug"),
            pl.col("is_present")
        )

    def breadth_only(
        self,
        lf: pl.LazyFrame,
        breadth: float = 0.5
        ) -> pl.LazyFrame:
        """
        Call presence/absence of genomes based on breadth only.

        Parameters:
        lf (pl.LazyFrame): Input LazyFrame with genome statistics.
        breadth (float): Breadth threshold (exclusive).
        Returns:
        pl.LazyFrame: LazyFrame with presence/absence calls.
        """

        lf = lf.with_columns(
            (pl.col("breadth") > breadth).fill_null(False).alias("is_present"))
        return lf.select(
            pl.col("genome"),
            pl.col("coverage"),
            pl.col("breadth"),
            pl.col("ber"),
            pl.col("fug"),
            pl.col("is_present")
        )

    def coverage_only(
        self,
        lf: pl.LazyFrame,
        coverage: float = 0.1
        ) -> pl.LazyFrame:
        """
        Call presence/absence of genomes based on coverage only.

        Parameters:
        lf (pl.LazyFrame): Input LazyFrame with genome statistics.
        coverage (float): Coverage threshold (exclusive).
        Returns:
        pl.LazyFrame: LazyFrame with presence/absence calls.
        """
        lf = lf.with_columns(
            (pl.col("coverage") > coverage).fill_null(False).alias("is_present"))
        return lf.select(
            pl.col("genome"),
            pl.col("coverage"),
            pl.col("breadth"),
            pl.col("ber"),
            pl.col("fug"),
            pl.col("is_present")
        )

    def __call__(self, method: str, lf: pl.LazyFrame, **kwargs) -> pl.LazyFrame:
        """Validate `lf`, then dispatch to the presence-calling method named `method`."""
        self.validate_input(lf)
        return getattr(self, method)(lf, **kwargs)
breadth_only(lf, breadth=0.5)

Call presence/absence of genomes based on breadth only. Parameters: lf (pl.LazyFrame): Input LazyFrame with genome statistics. breadth (float): Breadth threshold. Returns: pl.LazyFrame: LazyFrame with presence/absence calls.

Source code in zipstrain/src/zipstrain/utils.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def breadth_only(
    self,
    lf:pl.LazyFrame,
    breadth:float=0.5
    )->pl.LazyFrame:
    """
    Call presence/absence of genomes based on breadth only.

    Parameters:
    lf (pl.LazyFrame): Input LazyFrame with genome statistics.
    breadth (float): Breadth threshold (exclusive) for calling a genome present.
    Returns:
    pl.LazyFrame: Input columns plus a boolean 'is_present' column.
    """

    # Null breadth values are treated as absent via fill_null(False).
    lf=lf.with_columns(
        (pl.col("breadth") > breadth).fill_null(False).alias("is_present"))
    return lf.select(
        pl.col("genome"),
        pl.col("coverage"),
        pl.col("breadth"),
        pl.col("ber"),
        pl.col("fug"),
        pl.col("is_present")
    )
coverage_only(lf, coverage=0.1)

Call presence/absence of genomes based on coverage only. Parameters: lf (pl.LazyFrame): Input LazyFrame with genome statistics. coverage (float): Coverage threshold. Returns: pl.LazyFrame: LazyFrame

Source code in zipstrain/src/zipstrain/utils.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def coverage_only(
    self,
    lf:pl.LazyFrame,
    coverage:float=0.1
    )->pl.LazyFrame:
    """
    Call presence/absence of genomes based on coverage only.

    Parameters:
    lf (pl.LazyFrame): Input LazyFrame with genome statistics.
    coverage (float): Coverage threshold (exclusive) for calling a genome present.
    Returns:
    pl.LazyFrame: Input columns plus a boolean 'is_present' column.
    """
    # Null coverage values are treated as absent via fill_null(False).
    lf=lf.with_columns(
        (pl.col("coverage") > coverage).fill_null(False).alias("is_present"))
    return lf.select(
        pl.col("genome"),
        pl.col("coverage"),
        pl.col("breadth"),
        pl.col("ber"),
        pl.col("fug"),
        pl.col("is_present")
    )
metapresence(lf, ber=0.5, fug=0.5, min_cov_use_fug=0.1)

Call presence/absence of genomes based on breadth, coverage, ber, and fug. Parameters: lf (pl.LazyFrame): Input LazyFrame with genome statistics. ber (float): Breadth error rate threshold. fug (float): Fragmented unassembled genome threshold. min_cov_use_fug (int): Minimum coverage to use fug for presence call. Returns: pl.LazyFrame: LazyFrame with presence/absence calls.

Source code in zipstrain/src/zipstrain/utils.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def metapresence(self,
                   lf:pl.LazyFrame,
                   ber:float=0.5,
                   fug:float=0.5,
                   min_cov_use_fug:float=0.1
                   )->pl.LazyFrame:
    """
    Call presence/absence of genomes based on breadth, coverage, ber, and fug.

    When coverage exceeds `min_cov_use_fug`, only the ber criterion is applied;
    otherwise the fug criterion must hold as well.

    Parameters:
    lf (pl.LazyFrame): Input LazyFrame with genome statistics.
    ber (float): Breadth error rate threshold.
    fug (float): Fragmented unassembled genome threshold.
    min_cov_use_fug (float): Coverage above which the ber-only call is used.
    Returns:
    pl.LazyFrame: Input columns plus a boolean 'is_present' column.
    """
    lf=lf.with_columns(
        pl.when(pl.col("coverage") > min_cov_use_fug)
        .then(
            pl.col("ber") > ber
            ).otherwise(
                # NOTE(review): 0.632 appears to be a normalization constant for
                # fug — confirm its derivation against the method's reference.
                (pl.col("fug")/0.632 < fug) &
                (pl.col("ber") > ber)
            ).fill_null(False).alias("is_present"))
    return lf.select(
        pl.col("genome"),
        pl.col("coverage"),
        pl.col("breadth"),
        pl.col("ber"),
        pl.col("fug"),
        pl.col("is_present")
    )

EstimateAbundance

This class provides methods to estimate abundance of genomes based on coverage.

Source code in zipstrain/src/zipstrain/utils.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
class EstimateAbundance:
    """This class provides methods to estimate abundance of genomes based on coverage."""

    def validate_input(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Raise ValueError if `lf` lacks any required column; return `lf` unchanged."""
        required_columns = {"genome", "coverage", "is_present", "Rn"}
        # Schema.names() is a method: without the parentheses the bound method
        # itself (not the column names) would be passed to set(), raising TypeError.
        missing_columns = required_columns - set(lf.collect_schema().names())
        if missing_columns:
            raise ValueError(f"Input LazyFrame is missing required columns: {missing_columns}")
        return lf

    def coverage_ratio(
        self,
        lf: pl.LazyFrame
        ) -> pl.LazyFrame:
        """
        Estimate abundance based on coverage ratio.

        Parameters:
        lf (pl.LazyFrame): Input LazyFrame with genome statistics.
        Returns:
        pl.LazyFrame: LazyFrame with an 'abundance' column (0.0 for absent genomes).
        """
        # NOTE(review): the denominator sums coverage over all rows, including
        # genomes not called present — confirm this is intended.
        lf = lf.with_columns(
            abundance=pl.when(pl.col("is_present"))
            .then(
                pl.col("coverage") / pl.col("coverage").sum()
            ).otherwise(pl.lit(0.0))
        )
        return lf

    def reads_ratio(
        self,
        lf: pl.LazyFrame
        ) -> pl.LazyFrame:
        """
        Estimate abundance based on reads ratio.

        Parameters:
        lf (pl.LazyFrame): Input LazyFrame with genome statistics.
        Returns:
        pl.LazyFrame: LazyFrame with an 'abundance' column (0.0 for absent genomes).
        """
        # NOTE(review): requires a 'total_reads' column that validate_input does
        # not check for — confirm upstream callers always provide it.
        lf = lf.with_columns(
            abundance=pl.when(pl.col("is_present"))
            .then(
                pl.col("Rn") / pl.col("total_reads").sum()
            ).otherwise(pl.lit(0.0))
        )
        return lf
coverage_ratio(lf)

Estimate abundance based on coverage ratio. Parameters: lf (pl.LazyFrame): Input LazyFrame with genome statistics. Returns: pl.LazyFrame: LazyFrame with estimated abundance.

Source code in zipstrain/src/zipstrain/utils.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def coverage_ratio(
    self,
    lf:pl.LazyFrame
    )->pl.LazyFrame:
    """
    Estimate abundance based on coverage ratio.

    Parameters:
    lf (pl.LazyFrame): Input LazyFrame with genome statistics.
    Returns:
    pl.LazyFrame: LazyFrame with an 'abundance' column (0.0 for absent genomes).
    """
    # NOTE(review): the denominator sums coverage over all rows, including
    # genomes not called present — confirm this is intended.
    lf=lf.with_columns(
        abundance=pl.when(pl.col("is_present"))
        .then(
            pl.col("coverage") /pl.col("coverage").sum()
        ).otherwise(pl.lit(0.0))
    )
    return lf
reads_ratio(lf)

Estimate abundance based on reads ratio. Parameters: lf (pl.LazyFrame): Input LazyFrame with genome statistics. Returns: pl.LazyFrame: LazyFrame with estimated abundance.

Source code in zipstrain/src/zipstrain/utils.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def reads_ratio(
    self,
    lf:pl.LazyFrame
    )->pl.LazyFrame:
    """
    Estimate abundance based on reads ratio.

    Parameters:
    lf (pl.LazyFrame): Input LazyFrame with genome statistics.
    Returns:
    pl.LazyFrame: LazyFrame with an 'abundance' column (0.0 for absent genomes).
    """
    # NOTE(review): requires a 'total_reads' column in addition to 'Rn' —
    # confirm upstream callers always provide it.
    lf=lf.with_columns(
        abundance=pl.when(pl.col("is_present"))
        .then(
            pl.col("Rn") /pl.col("total_reads").sum()
        ).otherwise(pl.lit(0.0))
    )
    return lf

build_null_poisson(error_rate=0.001, max_total_reads=10000, p_threshold=0.05)

Build a null model to correct for sequencing errors based on the Poisson distribution.

Parameters: error_rate (float): Error rate for the sequencing technology. max_total_reads (int): Maximum total reads to consider. p_threshold (float): Significance threshold for the Poisson distribution.

Returns: pl.DataFrame: DataFrame containing total reads and maximum error count thresholds.

Source code in zipstrain/src/zipstrain/utils.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def build_null_poisson(error_rate: float = 0.001,
                       max_total_reads: int = 10000,
                       p_threshold: float = 0.05) -> list[tuple[int, int]]:
    """
    Build a null model to correct for sequencing errors based on the Poisson distribution.

    For each read depth n (1..max_total_reads), the count of a specific erroneous
    base is modeled as Poisson with rate n * error_rate / 3 (errors split evenly
    over the three alternative bases). The largest count still consistent with
    the null model at significance p_threshold is recorded.

    Parameters:
    error_rate (float): Error rate for the sequencing technology.
    max_total_reads (int): Maximum total reads to consider.
    p_threshold (float): Significance threshold for the Poisson distribution.

    Returns:
    list[tuple[int, int]]: (total_reads, max_error_count) pairs, one per depth.
    """
    records = []
    for n in range(1, max_total_reads + 1):
        lam = n * (error_rate / 3)
        # Find the smallest k whose upper-tail probability P(X >= k) drops to
        # p_threshold or below; k - 1 is then the largest non-rejected count.
        k = 0
        while poisson.sf(k - 1, lam) > p_threshold:
            k += 1
        records.append((n, k - 1))
    return records

clean_bases(bases, indel_re)

Remove read start/end markers and indels from bases string using regex. Returns cleaned uppercase string of bases only. Args: bases (str): The bases string from mpileup. indel_re (re.Pattern): Compiled regex pattern to match indels and markers.

Source code in zipstrain/src/zipstrain/utils.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def clean_bases(bases: str, indel_re: re.Pattern) -> str:
    """
    Remove read start/end markers and indels from bases string using regex.
    Returns cleaned uppercase string of bases only.
    Args:
        bases (str): The bases string from mpileup.
        indel_re (re.Pattern): Compiled regex pattern to match indels and markers.

    """
    kept = []
    pos = 0
    total = len(bases)
    while pos < total:
        match = indel_re.match(bases, pos)
        if match is None:
            # Plain base character: keep it, normalized to upper case.
            kept.append(bases[pos].upper())
            pos += 1
            continue
        token = match.group(0)
        if token[0] in "+-":
            # Indel: skip the '+N'/'-N' prefix plus the N inserted/deleted bases.
            pos = match.end() + int(match.group(1))
        else:
            # Read start ('^' + mapping quality) or end ('$') marker: drop it.
            pos = match.end()
    return ''.join(kept)

collect_breadth_tables(breadth_tables)

Collect multiple genome breadth tables into a single LazyFrame.

Parameters: breadth_tables (list[pl.LazyFrame]): List of LazyFrames containing genome breadth data.

Returns: pl.LazyFrame: A LazyFrame containing the combined genome breadth data.

Source code in zipstrain/src/zipstrain/utils.py
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
def collect_breadth_tables(
    breadth_tables: list[pl.LazyFrame],
) -> pl.LazyFrame:
    """
    Collect multiple genome breadth tables into a single LazyFrame.

    Parameters:
    breadth_tables (list[pl.LazyFrame]): List of LazyFrames containing genome breadth data.

    Returns:
    pl.LazyFrame: A LazyFrame containing the combined genome breadth data.

    Raises:
    ValueError: If `breadth_tables` is empty.
    """
    if not breadth_tables:
        raise ValueError("No breadth tables provided.")

    # how="outer" was renamed to how="full" in polars; coalesce=True keeps a
    # single 'genome' key column across the chained joins.
    return reduce(lambda acc, table: acc.join(table, on="genome", how="full", coalesce=True), breadth_tables)

count_bases(bases)

Count occurrences of A, C, G, T in the cleaned bases string. Args: bases (str): Cleaned bases string. Returns: dict: Dictionary with counts of A, C, G, T.

Source code in zipstrain/src/zipstrain/utils.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def count_bases(bases: str):
    """
    Count occurrences of A, C, G, T in the cleaned bases string.
    Args:
        bases (str): Cleaned bases string.
    Returns:
        dict: Dictionary with counts of A, C, G, T.
    """
    # str.count runs at C speed; non-ACGT characters are simply ignored.
    return {nucleotide: bases.count(nucleotide) for nucleotide in ('A', 'C', 'G', 'T')}

extract_genome_length(stb, bed_table)

Extract the genome length information from the scaffold-to-genome mapping table.

Parameters: stb (pl.LazyFrame): Scaffold-to-bin mapping table. bed_table (pl.LazyFrame): BED table containing genomic regions.

Returns: pl.LazyFrame: A LazyFrame containing the genome lengths.

Source code in zipstrain/src/zipstrain/utils.py
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
def extract_genome_length(stb: pl.LazyFrame, bed_table: pl.LazyFrame) -> pl.LazyFrame:
    """
    Extract the genome length information from the scaffold-to-genome mapping table.

    Parameters:
    stb (pl.LazyFrame): Scaffold-to-bin mapping table.
    bed_table (pl.LazyFrame): BED table containing genomic regions.

    Returns:
    pl.LazyFrame: A LazyFrame containing the genome lengths.
    """
    # Total length of each scaffold from its BED intervals (end - start).
    scaffold_lengths = (
        bed_table
        .select(
            pl.col("scaffold"),
            (pl.col("end") - pl.col("start")).alias("scaffold_length"),
        )
        .group_by("scaffold")
        .agg(scaffold_length=pl.sum("scaffold_length"))
    )
    # Attach the genome each scaffold belongs to.
    with_genome = scaffold_lengths.join(
        stb.select(
            pl.col("scaffold"),
            pl.col("genome"),
        ),
        on="scaffold",
        how="left",
    )
    # Genome length = sum of its scaffolds' lengths.
    return (
        with_genome
        .group_by("genome")
        .agg(genome_length=pl.sum("scaffold_length"))
        .select(
            pl.col("genome"),
            pl.col("genome_length"),
        )
    )

get_genome_breadth_matrix(profile, name, genome_length, stb, min_cov=1)

Compute one column of the genome breadth matrix from a single profile and a scaffold-to-genome mapping. Parameters: profile (pl.LazyFrame): Profile with per-position base counts. name (str): Output column name for this profile. genome_length (pl.LazyFrame): Genome lengths. stb (pl.LazyFrame): Scaffold-to-genome mapping table. min_cov (int): Minimum coverage to consider a position. Returns: pl.LazyFrame: A LazyFrame with 'genome' and the named breadth column.

Source code in zipstrain/src/zipstrain/utils.py
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
def get_genome_breadth_matrix(
                              profile:pl.LazyFrame,
                              name:str,
                              genome_length: pl.LazyFrame,
                              stb: pl.LazyFrame,
                              min_cov: int = 1)-> pl.LazyFrame:
    """
    Compute per-genome breadth for one profile, as a single column of the breadth matrix.

    Parameters:
    profile (pl.LazyFrame): Profile with per-position base counts (chrom, A, C, G, T).
    name (str): Output column name for this profile's breadth values.
    genome_length (pl.LazyFrame): Genome lengths (as produced by extract_genome_length).
    stb (pl.LazyFrame): Scaffold-to-genome mapping table.
    min_cov (int): Minimum coverage to consider a position.
    Returns:
    pl.LazyFrame: Columns 'genome' and `name` (fraction of the genome covered).
    """
    # Keep only positions with total depth >= min_cov.
    profile = profile.filter((pl.col("A") + pl.col("C") + pl.col("G") + pl.col("T")) >= min_cov)
    # Covered positions per scaffold; pl.len() replaces the deprecated pl.count().
    profile = profile.group_by("chrom").agg(
        breadth=pl.len()
    ).select(
        pl.col("chrom").alias("scaffold"),
        pl.col("breadth")
    ).join(
        stb,
        on="scaffold",
        how="left"
    )
    profile = profile.join(genome_length, on="genome", how="left")

    # Sum covered positions per genome and normalize by genome length.
    profile = profile.group_by("genome").agg(
        genome_length=pl.first("genome_length"),
        breadth=pl.col("breadth").sum())
    profile = profile.with_columns(
        (pl.col("breadth") / pl.col("genome_length")).alias("breadth")
    )
    return profile.select(
            pl.col("genome"),
            pl.col("breadth").alias(name)
        )

make_the_bed(db_fasta_dir, max_scaffold_length=500000)

Create a BED file from the database in fasta format.

Parameters: db_fasta_dir (Union[str, pathlib.Path]): Path to the fasta file. max_scaffold_length (int): Splits scaffolds longer than this into multiple entries of length <= max_scaffold_length.

Returns: pl.LazyFrame: A LazyFrame containing the BED data.

Source code in zipstrain/src/zipstrain/utils.py
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
def make_the_bed(db_fasta_dir: str | pathlib.Path, max_scaffold_length: int = 500_000) -> pl.DataFrame:
    """
    Create a BED table from the database in fasta format.

    Parameters:
    db_fasta_dir (Union[str, pathlib.Path]): Path to the fasta file.
    max_scaffold_length (int): Splits scaffolds longer than this into multiple entries of length <= max_scaffold_length.

    Returns:
    pl.DataFrame: A DataFrame with columns 'scaffold', 'start', 'end'.

    Raises:
    FileNotFoundError: If `db_fasta_dir` is not an existing file.
    """
    db_fasta_dir = pathlib.Path(db_fasta_dir)
    if not db_fasta_dir.is_file():
        raise FileNotFoundError(f"{db_fasta_dir} is not a valid fasta file.")

    records = []

    def _emit(scaffold: str, seq_chunks: list) -> None:
        # Split one scaffold's sequence into windows of at most max_scaffold_length.
        seq = ''.join(seq_chunks)
        for start in range(0, len(seq), max_scaffold_length):
            end = min(start + max_scaffold_length, len(seq))
            records.append((scaffold, start, end))

    with db_fasta_dir.open() as f:
        scaffold = None
        seq_chunks = []

        for line in f:
            line = line.strip()
            if line.startswith(">"):
                # Flush the previous scaffold before starting a new one.
                if scaffold is not None:
                    _emit(scaffold, seq_chunks)
                scaffold = line[1:].split()[0]  # ID only (up to first whitespace)
                seq_chunks = []
            else:
                seq_chunks.append(line)

        # Flush the final scaffold (previously duplicated inline).
        if scaffold is not None:
            _emit(scaffold, seq_chunks)

    return pl.DataFrame(records, schema=["scaffold", "start", "end"], orient="row")

process_mpileup_function(batch_size, output_file)

Process mpileup files and save the results in a Parquet file.

Parameters: gene_range_table_loc (str): Path to the gene range table in TSV format. batch_bed (str): Path to the batch BED file. batch_size (int): Buffer size for processing stdin from samtools. output_file (str): Path to save the output Parquet file.

Source code in zipstrain/src/zipstrain/utils.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
def process_mpileup_function(batch_size: int, output_file: str):
    """
    Stream samtools mpileup output from stdin into a Parquet file of base counts.

    Each stdin line is a tab-separated mpileup record. The base column is
    cleaned of read start/end markers and indels, A/C/G/T counts are tallied
    per position, and rows are written out in zstd-compressed Parquet batches.

    Parameters:
    batch_size (int): Number of rows to buffer before writing a Parquet batch.
    output_file (str): Path to save the output Parquet file.
    """
    # Matches read-start markers ('^' + mapping quality char), read-end markers
    # ('$'), and indel prefixes ('+N'/'-N'); group(1) captures the indel length.
    indel_re = re.compile(r'\^.|[\$]|[+-](\d+)')


    schema = pa.schema([
        ('chrom', pa.string()),
        ('pos', pa.int32()),
        ('A', pa.uint16()),
        ('C', pa.uint16()),
        ('G', pa.uint16()),
        ('T', pa.uint16()),
    ])

    # Column buffers, flushed to Parquet every batch_size rows.
    chroms = []
    positions = []
    As = []
    Cs = []
    Gs = []
    Ts = []

    writer = None
    def flush_batch():
        # Write buffered rows as one Parquet batch; no-op when buffers are empty.
        nonlocal writer
        if not chroms:
            return
        batch = pa.RecordBatch.from_arrays([
            pa.array(chroms, type=pa.string()),
            pa.array(positions, type=pa.int32()),
            pa.array(As, type=pa.uint16()),
            pa.array(Cs, type=pa.uint16()),
            pa.array(Gs, type=pa.uint16()),
            pa.array(Ts, type=pa.uint16()),
        ], schema=schema)

        if writer is None:
            # Open writer for the first time
            writer = pq.ParquetWriter(output_file, schema, compression='zstd')
        writer.write_table(pa.Table.from_batches([batch]))

        # Clear buffers
        chroms.clear()
        positions.clear()
        As.clear()
        Cs.clear()
        Gs.clear()
        Ts.clear()
    for line in sys.stdin:
        if not line.strip():
            continue
        fields = line.strip().split('\t')
        if len(fields) < 5:
            continue
        # mpileup columns: chrom, pos, ref, depth, bases (ref and depth unused).
        chrom, pos, _, _, bases = fields[:5]

        cleaned = clean_bases(bases, indel_re)
        counts = count_bases(cleaned)

        chroms.append(chrom)
        positions.append(int(pos))
        As.append(counts['A'])
        Cs.append(counts['C'])
        Gs.append(counts['G'])
        Ts.append(counts['T'])

        if len(chroms) >= batch_size:
            flush_batch()

    # Flush remaining data
    flush_batch()

    if writer:
        writer.close()

process_read_location(output_file, batch_size=10000)

This function takes the output of samtools view -F 132 and processes it to extract read locations in a parquet file.

Source code in zipstrain/src/zipstrain/utils.py
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
def process_read_location(output_file:str, batch_size:int=10000)->None:
    """
    This function takes the output of samtools view -F 132 and processes it to extract read locations in a parquet file.

    Reads tab-separated SAM records from stdin, keeps the reference name
    (field 3) and mapping position (field 4), and streams them to
    ``output_file`` as a zstd-compressed parquet file with columns
    ``chrom`` (string) and ``pos`` (int32), flushing every ``batch_size``
    rows.
    """
    table_schema = pa.schema([
        ('chrom', pa.string()),
        ('pos', pa.int32()),
    ])
    parquet_writer = None
    chrom_buffer = []
    pos_buffer = []

    def write_pending():
        nonlocal parquet_writer
        if not chrom_buffer:
            return
        record_batch = pa.RecordBatch.from_arrays(
            [
                pa.array(chrom_buffer, type=pa.string()),
                pa.array(pos_buffer, type=pa.int32()),
            ],
            schema=table_schema,
        )
        if parquet_writer is None:
            # Open the writer lazily so no output file appears for empty input.
            parquet_writer = pq.ParquetWriter(output_file, table_schema, compression='zstd')
        parquet_writer.write_table(pa.Table.from_batches([record_batch]))
        # Reset buffers for the next batch.
        chrom_buffer.clear()
        pos_buffer.clear()

    for raw_line in sys.stdin:
        stripped = raw_line.strip()
        if not stripped:
            continue
        columns = stripped.split('\t')
        if len(columns) < 4:
            # Malformed / truncated record: skip rather than fail.
            continue
        chrom_buffer.append(columns[2])
        pos_buffer.append(int(columns[3]))
        if len(chrom_buffer) >= batch_size:
            write_pending()
    # Flush whatever is left over after the final full batch.
    write_pending()
    if parquet_writer:
        parquet_writer.close()

split_lf_to_chunks(lf, num_chunks)

Split a Polars LazyFrame into smaller chunks.

Parameters: lf (pl.LazyFrame): The input LazyFrame to be split. num_chunks (int): The number of chunks to split the LazyFrame into.

Returns: list[pl.LazyFrame]: A list of smaller LazyFrames.

Source code in zipstrain/src/zipstrain/utils.py
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
def split_lf_to_chunks(lf:pl.LazyFrame,num_chunks:int)->list[pl.LazyFrame]:
    """
    Split a Polars LazyFrame into smaller chunks.

    Parameters:
    lf (pl.LazyFrame): The input LazyFrame to be split.
    num_chunks (int): The number of chunks to split the LazyFrame into. Must be >= 1.

    Returns:
    list[pl.LazyFrame]: A list of ``num_chunks`` LazyFrames that together cover every
        row of ``lf``. All chunks hold ``total_rows // num_chunks`` rows except the
        last, which absorbs the remainder.

    Raises:
    ValueError: If ``num_chunks`` is less than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    # pl.len() is the modern replacement for the deprecated pl.count() expression.
    total_rows = lf.select(pl.len()).collect().item()
    chunk_size = total_rows // num_chunks
    chunks = []
    for i in range(num_chunks):
        start = i * chunk_size
        # Every chunk spans chunk_size rows except the last, which takes the remainder.
        end = (i + 1) * chunk_size if i < num_chunks - 1 else total_rows
        chunks.append(lf.slice(start, end - start))
    return chunks

Visualize

zipstrain.visualize

This module provides statistical analysis and visualization functions for profiling and compare operations.

calculate_ibs(sample_to_population, comps_lf, max_perc_id_genes=15, min_total_positions=10000)

Calculate the Identity By State (IBS) between two populations for a given genome. The IBS is defined as the percentage of genes that are identical between two populations for a given genome. Args: sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping. comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples. max_perc_id_genes (float, optional): Maximum percentage of identical genes to consider. Defaults to 15. min_total_positions (int, optional): Minimum number of compared positions required to keep a comparison. Defaults to 10000. Returns: pl.DataFrame: DataFrame containing the IBS information for the given genome and populations.

Source code in zipstrain/src/zipstrain/visualize.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def calculate_ibs(
    sample_to_population:pl.LazyFrame, 
    comps_lf:pl.LazyFrame,
    max_perc_id_genes:float=15,
    min_total_positions:int=10000,
)->pl.DataFrame:
    """
    Calculate the Identity By State (IBS) between two populations for a given genome.
    The IBS is defined as the percentage of genes that are identical between two populations for a given genome.
    Args:
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping (columns: sample, population).
        comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples.
        max_perc_id_genes (float, optional): Maximum percentage of identical genes to consider. Defaults to 15.
        min_total_positions (int, optional): Minimum number of compared positions a comparison must exceed to be kept. Defaults to 10000.
    Returns:
        pl.DataFrame: One row per genome, pivoted so each comparison type
        ("within_population_<p>|<p>" / "between_population_<p1>|<p2>") is a column
        holding the list of max_consecutive_length values; missing cells are filled with [-1].
    """
    # Keep only comparisons that pass the identity and coverage thresholds.
    comps_lf_filtered = comps_lf.filter(
        (pl.col('perc_id_genes') <= max_perc_id_genes) &
        (pl.col('total_positions')>min_total_positions)
    )
    # Attach a population label for each of the two samples in a comparison.
    # Inner joins drop comparisons whose samples lack a population mapping.
    comps_lf_filtered=comps_lf_filtered.join(
        sample_to_population,
        left_on='sample_1',
        right_on='sample',
        how='inner',
    ).rename(
        {"population":"population_1"}
    ).join(
        sample_to_population,
        left_on='sample_2',
        right_on='sample',
        how='inner',
        suffix='_2'
    ).rename(
        {"population":"population_2"}
    )
    # Label each comparison. Between-population labels order the two names
    # with min/max so the label is symmetric regardless of sample order.
    comps_lf_filtered = comps_lf_filtered.with_columns(
    pl.when(pl.col("population_1") == pl.col("population_2"))
    .then(
        pl.lit("within_population_")
        + pl.col("population_1")
        + pl.lit("|")
        + pl.col("population_2")
    )
    .otherwise(
        pl.concat_str(
            [
                pl.lit("between_population_"),
                pl.concat_str(
                    [
                        pl.min_horizontal("population_1", "population_2"),
                        pl.lit("|"),
                        pl.max_horizontal("population_1", "population_2"),
                    ]
                ),
            ]
        )
    )
    .alias("comparison_type")
    ).fill_null(-1)

    # Collect max_consecutive_length values per (genome, comparison_type),
    # pivot wide, and fill absent comparison types with the [-1] sentinel.
    return comps_lf_filtered.group_by(["genome","comparison_type"]).agg(
        pl.col("max_consecutive_length"),
    ).collect(engine="streaming").pivot(
        index="genome",
        columns="comparison_type",
        values="max_consecutive_length",
    ).with_columns(
        pl.col("*").exclude("genome").fill_null([-1])
    )

calculate_identical_frac_vs_popani(genome, population_1, population_2, sample_to_population, comps_lf, min_shared_genes_count=100, min_total_positions=10000)

Calculate the fraction of identical genes vs popANI for a given genome and two samples in any possible combination of populations. Args: genome (str): The genome to calculate the fraction of identical genes vs popANI for. population_1 (str): The first population to compare. population_2 (str): The second population to compare. sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping. comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples. Returns: pl.DataFrame: DataFrame containing the fraction of identical genes vs popANI information for the given genome and populations.

Source code in zipstrain/src/zipstrain/visualize.py
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
def calculate_identical_frac_vs_popani(
    genome:str,
    population_1:str,
    population_2:str,
    sample_to_population:pl.LazyFrame,
    comps_lf:pl.LazyFrame,
    min_shared_genes_count:int=100,
    min_total_positions:int=10000
    ):
    """
    Calculate the fraction of identical genes vs popANI for a given genome and two samples in any possible combination of populations.
    Args:
        genome (str): The genome to calculate the fraction of identical genes vs popANI for.
        population_1 (str): The first population to compare.
        population_2 (str): The second population to compare.
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
        comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples.
        min_shared_genes_count (int, optional): Minimum shared gene count a comparison must exceed. Defaults to 100.
        min_total_positions (int, optional): Minimum total positions a comparison must exceed. Defaults to 10000.
    Returns:
        pl.DataFrame: One row per relationship group ("<p1>-<p1>", "<p2>-<p2>", "<p1>-<p2>")
        with perc_id_genes and genome_pop_ani aggregated into lists.
    """
    # Restrict to the requested genome and sufficiently supported comparisons;
    # collect eagerly so later steps operate on the (much smaller) result.
    comps_lf_filtered=comps_lf.filter(
        (pl.col('genome') == genome) &
        (pl.col("shared_genes_count")>min_shared_genes_count) &
        (pl.col("total_positions")>min_total_positions)
    ).collect(engine="streaming").lazy()

    # Attach a population label to each side of the comparison.
    comps_lf_filtered=comps_lf_filtered.join(
        sample_to_population,
        left_on='sample_1',
        right_on='sample',
        how='left',
    ).rename(
        {"population":"population_1"}
    ).join(
        sample_to_population,
        left_on='sample_2',
        right_on='sample',
        how='left',
        suffix='_2'
    ).rename(
        {"population":"population_2"}
    )
    # Keep only comparisons where both samples belong to one of the two populations.
    comps_lf_filtered = comps_lf_filtered.filter(
        (pl.col("population_1").is_in({population_1, population_2})) &
        (pl.col("population_2").is_in({population_1, population_2}))
    ).collect(engine="streaming").lazy()
    # Human-readable labels for the three possible pairings.
    groups={
        "same_1":f"{population_1}-{population_1}",
        "same_2":f"{population_2}-{population_2}",
        "diff":f"{population_1}-{population_2}",
    }
    comps_lf_filtered=comps_lf_filtered.with_columns(
        pl.when((pl.col("population_1")==population_1) & (pl.col("population_2")==population_1))
        .then(pl.lit(groups["same_1"]))
        .when((pl.col("population_1")==population_2) & (pl.col("population_2")==population_2))
        .then(pl.lit(groups["same_2"]))
        .otherwise(pl.lit(groups["diff"]))
        .alias("relationship")
    )
    # Aggregate both metrics into per-relationship lists for plotting.
    return comps_lf_filtered.group_by("relationship").agg(
        pl.col("perc_id_genes"),
        pl.col("genome_pop_ani")
    ).collect(engine="streaming")

calculate_strainsharing(comps_lf, breadth_lf, sample_to_population, min_breadth=0.5, strain_similarity_threshold=99.9, min_total_positions=10000)

Calculate strain sharing between populations based on popANI between genomes in their profiles. Strain sharing between two samples is defined as the ratio of genomes passing a strain similarity threshold over the total number of genomes in each sample. So, for two samples A and B, the strain sharing is defined as (Note the asymmetric nature of the calculation): strain_sharing(A, B) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in A) strain_sharing(B, A) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in B)

Parameters:

Name Type Description Default
comps_lf LazyFrame

LazyFrame containing the gene profiles of the samples.

required
breadth_lf LazyFrame

LazyFrame containing the genome breadth information.

required
sample_to_population LazyFrame

LazyFrame containing the sample to population mapping.

required
min_breadth float

Minimum genome breadth to consider a genome for strain sharing. Defaults to 0.5.

0.5
strain_similarity_threshold float

Threshold for strain similarity (popANI percentage). Defaults to 99.9.

99.9
min_total_positions int

Minimum total positions to consider a genome for strain sharing. Defaults to 10000.

10000

Returns: dict[str, list[float]]: Dictionary mapping each directed population pair key ("pop_a_pop_b") to the list of per-sample-pair strain sharing rates in that direction.

Source code in zipstrain/src/zipstrain/visualize.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def calculate_strainsharing(
                            comps_lf:pl.LazyFrame,
                            breadth_lf:pl.LazyFrame,
                            sample_to_population:pl.LazyFrame,
                            min_breadth:float=0.5,
                            strain_similarity_threshold:float=99.9,
                            min_total_positions:int=10000
                            )->dict[str, list[float]]:


    """
    Calculate strain sharing between populations based on popANI between genomes in their profiles.
    Strain sharing between two samples is defined as the ratio of genomes passing a strain similarity threshold over the total number of genomes in each sample.
    So, for two samples A and B, the strain sharing is defined as (Note the asymmetric nature of the calculation):
    strain_sharing(A, B) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in A)
    strain_sharing(B, A) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in B)

    Args:
        comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples.
        breadth_lf (pl.LazyFrame): LazyFrame containing the genome breadth information (wide format: one row per genome, one column per sample).
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
        min_breadth (float, optional): Minimum genome breadth to consider a genome for strain sharing. Defaults to 0.5.
        strain_similarity_threshold (float, optional): popANI threshold for strain similarity. Defaults to 99.9.
        min_total_positions (int, optional): Minimum total positions to consider a genome for strain sharing. Defaults to 10000.
    Returns:
        dict[str, list[float]]: Mapping from a directed population pair key
        ("<pop_a>_<pop_b>") to the list of per-sample-pair strain sharing rates
        in that direction.
    """
    # Keep only comparisons with enough covered positions; collect eagerly and
    # continue lazily on the smaller materialized result.
    comps_lf=comps_lf.filter(
        (pl.col("total_positions")>min_total_positions)
    ).collect(engine="streaming").lazy()
    breadth_lf=breadth_lf.fill_null(0.0)
    # Long format: one (genome, sample, breadth) row per cell of the wide table.
    breadth_lf_long=(
        breadth_lf.unpivot(
            index=["genome"],
            variable_name="sample",
            value_name="breadth"
        )
    )
    # Per-sample count of genomes passing the breadth cutoff — the denominator
    # of each directed strain-sharing rate.
    breadth_lf=breadth_lf_long.group_by("sample").agg(num_genomes=(pl.col("breadth")>=min_breadth).sum())
    comps_lf=comps_lf.join(breadth_lf,
        left_on='sample_1',
        right_on='sample',
        how='left',
    ).rename(
        {"num_genomes":"num_genomes_1"}
    ).join(
        breadth_lf,
        left_on='sample_2',
        right_on='sample',
        how='left',
    ).rename(
        {"num_genomes":"num_genomes_2"}
    )
    # Attach population labels for both samples of each comparison.
    comps_lf = comps_lf.join(
        sample_to_population,
        left_on='sample_1',
        right_on='sample',
        how='left',
    ).rename(
        {"population":"population_1"}
    ).join(
        sample_to_population,
        left_on='sample_2',
        right_on='sample',
        how='left',
    ).rename(
        {"population":"population_2"}
    )
    # Attach each genome's breadth in both samples of the comparison.
    comps_lf=comps_lf.join(
        breadth_lf_long,
        left_on=["genome","sample_1"],
        right_on=['genome','sample'],
        how='left',
    ).rename(
        {"breadth":"breadth_1"}
    ).join(
        breadth_lf_long,
        left_on=["genome","sample_2"],
        right_on=['genome','sample'],
        how='left',
    ).rename(
        {"breadth":"breadth_2"}
    )
    # A genome counts as a shared strain only if it is well covered in both
    # samples and its popANI passes the similarity threshold.
    comps_lf=comps_lf.filter(
        (pl.col("breadth_1") >= min_breadth) &
        (pl.col("breadth_2") >= min_breadth) &
        (pl.col("genome_pop_ani") >= strain_similarity_threshold)
    )

    comps_lf=comps_lf.group_by(
        ["sample_1", "sample_2"]
    ).agg(
        pl.col("genome").count().alias("shared_strain_count"),
        pl.col("num_genomes_1").first().alias("num_genomes_1"),
        pl.col("num_genomes_2").first().alias("num_genomes_2"),
        pl.col("population_1").first().alias("population_1"),
        pl.col("population_2").first().alias("population_2"),
    ).collect(engine="streaming")
    # Emit both directions for every sample pair: the numerator is shared but
    # the denominator (genome count) differs per direction.
    strainsharingrates=defaultdict(list)
    for row in comps_lf.iter_rows(named=True):
        strainsharingrates[row["population_1"]+"_"+ row["population_2"]].append(row["shared_strain_count"] / row["num_genomes_1"])
        strainsharingrates[row["population_2"]+"_"+ row["population_1"]].append(row["shared_strain_count"] / row["num_genomes_2"])
    return strainsharingrates

get_cdf(data, num_bins=10000)

Calculate the cumulative distribution function (CDF) of the given data.

Source code in zipstrain/src/zipstrain/visualize.py
17
18
19
20
21
22
23
24
25
26
def get_cdf(data, num_bins=10000):
    """Calculate the cumulative distribution function (CDF) of the given data.

    A leading sentinel of -1 in ``data`` marks "no data" and yields ([-1], [-1]).
    Values are histogrammed over [0, 50000]; counts are then accumulated from the
    top bin downward, so the returned CDF pairs with the reversed bin edges.
    """
    if data[0] == -1:
        return [-1], [-1]
    bins = np.linspace(0, 50000, num_bins)
    hist_counts, edges = np.histogram(data, bins=bins)
    flipped_counts = hist_counts[::-1]
    flipped_edges = edges[::-1]
    running_total = np.cumsum(flipped_counts)
    return flipped_edges, running_total / running_total[-1]

plot_clustermap(comps_lf, genome, sample_to_population, min_comp_len=10000, impute_method=97.0, max_null_samples=200, color_map=None)

Plot a clustermap for the given genome and its associated samples. Args: comps_lf (pl.LazyFrame): LazyFrame containing the comparison data. genome (str): The genome to plot. sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping. Returns: go.Figure: Plotly figure containing the clustermap.

Source code in zipstrain/src/zipstrain/visualize.py
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
def plot_clustermap(
    comps_lf:pl.LazyFrame,
    genome:str,
    sample_to_population:pl.LazyFrame,
    min_comp_len:int=10000,
    impute_method:str|float=97.0,
    max_null_samples:int=200,
    color_map:dict|None=None,
):
    """
    Plot a clustermap for the given genome and its associated samples.
    Args:
        comps_lf (pl.LazyFrame): LazyFrame containing the comparison data.
        genome (str): The genome to plot.
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
        min_comp_len (int, optional): Minimum total_positions for a comparison to be included. Defaults to 10000.
        impute_method (str | float, optional): Value used to fill missing pairwise ANI cells; string-based imputation is not implemented yet. Defaults to 97.0.
        max_null_samples (int, optional): Samples with more than this many missing comparisons are dropped from the matrix. Defaults to 200.
        color_map (dict | None, optional): Optional mapping from population name to color. When None, an HLS palette is generated.
    Returns:
        The object returned by sns.clustermap (a seaborn ClusterGrid) showing the
        pairwise genome_pop_ani matrix with population row/column colors.
    """
    # Filter the comparison data for the specific genome
    comps_lf_filtered = comps_lf.filter(
        (pl.col("genome") == genome) & (pl.col("total_positions") > min_comp_len)
    ).select(
        pl.col("sample_1"),
        pl.col("sample_2"),
        pl.col("genome_pop_ani"),
    )
    # Mirror each pair (A,B) -> (B,A) so the pivoted matrix is symmetric.
    comps_lf_filtered_oposite = comps_lf_filtered.select(
        pl.col("sample_2").alias("sample_1"),
        pl.col("sample_1").alias("sample_2"),
        pl.col("genome_pop_ani"),
    )
    # Combine the filtered data with its opposite pairs
    comps_lf_filtered = pl.concat([comps_lf_filtered, comps_lf_filtered_oposite])
    # Make a synthetic table for similarity of samples with themselves of all samples in sample_1 and sample_2 but each sample exists only once
    self_similarity =(
    pl.concat([
        comps_lf_filtered.select(pl.col("sample_1").alias("sample_1")),
        comps_lf_filtered.select(pl.col("sample_2").alias("sample_1"))
    ])
    .unique()
    .sort("sample_1").with_columns(
        pl.col("sample_1").alias("sample_2"),
        pl.lit(100.0).alias("genome_pop_ani"),
    )
    )

    # Combine the self similarity with the filtered data
    comps_lf_filtered = pl.concat([self_similarity, comps_lf_filtered]).collect(engine="streaming")
    # Pivot the data for the clustermap
    clustermap_data = comps_lf_filtered.pivot(
        index="sample_1",
        columns="sample_2",
        values="genome_pop_ani"
    )
    # Drop samples with too many missing comparisons so the similarity matrix
    # stays dense enough to cluster.
    exclude_samples=clustermap_data.null_count().transpose(include_header=True, header_name="column", column_names=["null_count"]).filter(pl.col("null_count")>max_null_samples)["column"].to_list()
    # Only include rows and cols not in exclude_samples
    clustermap_data = clustermap_data.filter(~pl.col("sample_1").is_in(exclude_samples))
    clustermap_data = clustermap_data.select(*[col for col in clustermap_data.columns if col not in exclude_samples])
    if isinstance(impute_method, str):
        pass # To be implemented later
    elif isinstance(impute_method, (int, float)):
        clustermap_data = clustermap_data.fill_null(impute_method)
    # Restrict the sample->population mapping to the samples left in the matrix.
    sample_to_population = clustermap_data.select(pl.col("sample_1")).join(
        sample_to_population.collect(engine="streaming"),
        left_on="sample_1",
        right_on="sample",
        how="left")
    sample_to_population_dict = dict(zip(sample_to_population["sample_1"], sample_to_population["population"]))
    if color_map is None:
        # Auto-generate one distinct HLS color per population.
        num_categories = sample_to_population["population"].n_unique()
        groups= sample_to_population["population"].unique().sort().to_list()
        qualitative_palette = sns.color_palette("hls", num_categories)
        row_colors = [qualitative_palette[groups.index(sample_to_population_dict[sample])] for sample in clustermap_data["sample_1"]]
        col_colors = [qualitative_palette[groups.index(sample_to_population_dict[sample])] for sample in clustermap_data.columns if sample != "sample_1"]
    else:
        groups= list(color_map.keys())
        qualitative_palette= list(color_map.values())
        row_colors = [color_map[sample_to_population_dict[sample]] for sample in clustermap_data["sample_1"]]
        col_colors = [color_map[sample_to_population_dict[sample]] for sample in clustermap_data.columns if sample != "sample_1"]
    fig = sns.clustermap(
        clustermap_data.to_pandas().set_index("sample_1"),
        figsize=(30, 30),
        xticklabels=True, 
        yticklabels=True,
        row_colors=row_colors,
        col_colors=col_colors
    )
    # Tiny tick labels: with hundreds of samples the names would otherwise overlap.
    fig.ax_heatmap.set_xticklabels(fig.ax_heatmap.get_xmajorticklabels(), fontsize=0.1)
    fig.ax_heatmap.set_yticklabels(fig.ax_heatmap.get_ymajorticklabels(), fontsize=0.1)
    legend_handles = [mpatches.Patch(color=color, label=label)
                  for label, color in zip(groups, qualitative_palette)]
    fig.ax_heatmap.legend(handles=legend_handles,
                          title='Population',
                        title_fontsize=16,   # bigger title
                        fontsize=14,         # bigger labels
                        handlelength=2.5,    # wider color boxes
                        handleheight=2,
                    bbox_to_anchor=(-0.15, 1), loc="lower left")
    return fig

plot_ibs(df, genome, population_1, population_2, vert_thresh_hor_distance=0.001, num_bins=10000, title='IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>', xaxis_title='Max Consecutive Length', yaxis_title='CDF')

Plot the Identity By State (IBS) for a given genome and two populations. Args: df (pl.DataFrame): DataFrame containing the IBS information. genome (str): The genome to plot the IBS for. population_1 (str): The first population to plot the IBS for. population_2 (str): The second population to plot the IBS for. title (str, optional): Title of the plot. Defaults to "IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>". xaxis_title (str, optional): Title of the x-axis. Defaults to "Max Consecutive Length". yaxis_title (str, optional): Title of the y-axis. Defaults to "CDF". Returns: go.Figure: Plotly figure containing the IBS plot.

Source code in zipstrain/src/zipstrain/visualize.py
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
def plot_ibs(df:pl.DataFrame,
            genome:str,
            population_1:str,
            population_2:str,
            vert_thresh_hor_distance:float=0.001,
            num_bins:int=10000,
            title:str="IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>",
            xaxis_title:str="Max Consecutive Length",
            yaxis_title:str="CDF"
            ):
    """
    Plot the Identity By State (IBS) for a given genome and two populations.
    Args:
        df (pl.DataFrame): DataFrame containing the IBS information (output of calculate_ibs).
        genome (str): The genome to plot the IBS for.
        population_1 (str): The first population to plot the IBS for.
        population_2 (str): The second population to plot the IBS for.
        vert_thresh_hor_distance (float, optional): CDF level at which the horizontal within/between distance is measured. Defaults to 0.001.
        num_bins (int, optional): Number of histogram bins passed to get_cdf. Defaults to 10000.
        title (str, optional): Title template; <GENOME>, <POPULATION_1> and <POPULATION_2> placeholders are substituted. Defaults to "IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Max Consecutive Length".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "CDF".
    Raises:
        ValueError: If the genome is absent from df, or either within/between column holds no data.
    Returns:
        go.Figure: Plotly figure containing the IBS plot (log-log CDF curves plus the annotated distance between their threshold crossings).
    """
    df_filtered = df.filter(pl.col("genome") == genome)
    if df_filtered.is_empty():
        raise ValueError(f"Genome {genome} not found in the dataframe.")
    plot_data = {}
    # Column names follow the labeling scheme used by calculate_ibs; the
    # "between" key orders the two population names with min/max.
    key_within_1=f"within_population_{population_1}|{population_1}"
    key_within_2=f"within_population_{population_2}|{population_2}"
    key_between=f"between_population_{min(population_1,population_2)}|{max(population_1,population_2)}"
    if df_filtered.get_column(key_within_1).list.len()[0]==0 or df_filtered.get_column(key_within_2).list.len()[0]==0 or df_filtered.get_column(key_between).list.len()[0]==0:
        raise ValueError(f"Not enough data for populations {population_1} and {population_2} in genome {genome}.")
    # Pool both within-population lists into a single distribution.
    plot_data["within_population"]=df_filtered.get_column(key_within_1)[0].to_list()+df_filtered.get_column(key_within_2)[0].to_list()
    plot_data["between_population"]=df_filtered.get_column(key_between)[0].to_list()
    fig = go.Figure()
    # get_cdf returns (reversed bin edges, CDF); skip the first element of each
    # when plotting.
    between_pop_cdf=get_cdf(plot_data["between_population"], num_bins=num_bins)
    fig.add_trace(go.Scatter(
        x=between_pop_cdf[0][1:].copy(),
        y=between_pop_cdf[1][1:].copy(),
        mode='lines',
        name='between_population',
        line=dict(color='blue')
    ))
    within_pop_cdf=get_cdf(plot_data["within_population"], num_bins=num_bins)
    fig.add_trace(go.Scatter(
        x=within_pop_cdf[0][1:].copy(),
        y=within_pop_cdf[1][1:].copy(),
        mode='lines',
        name='within_population',
        line=dict(color='green')
    ))

    # Find where each CDF first reaches the threshold level; the horizontal gap
    # between the two crossing points is the reported distance.
    bin_edges=within_pop_cdf[0]
    cdf=within_pop_cdf[1]
    within_intersect=bin_edges[np.where(cdf>=vert_thresh_hor_distance)[0][0]]
    bin_edges=between_pop_cdf[0]
    cdf=between_pop_cdf[1]
    between_intersect=bin_edges[np.where(cdf>=vert_thresh_hor_distance)[0][0]]  
    distance=within_intersect-between_intersect

    fig.update_layout(
        title={"text": title.replace("<GENOME>", genome).replace("<POPULATION_1>", population_1).replace("<POPULATION_2>", population_2), "x": 0.5},
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,

    )
    ###Add a horizontal line from (between_intersect, vert_thresh_hor_distance) to (within_intersect, vert_thresh_hor_distance)
    fig.add_trace(go.Scatter(
        x=[between_intersect, within_intersect],
        y=[vert_thresh_hor_distance, vert_thresh_hor_distance],
        mode='lines+markers',
        line=dict(color='black'),
        showlegend=False
    ))
    ###Add a text annotation at the middle of the horizontal line with the distance
    fig.add_trace(go.Scatter(
        x=[(between_intersect+within_intersect)/2],
        y=[vert_thresh_hor_distance],
        mode="text",
        text=int(distance),
        textposition="top center",
        showlegend=False
    ))
    ##make both axes logarithmic
    fig.update_xaxes(type='log')
    fig.update_yaxes(type='log')

    return fig

plot_ibs_heatmap(df, vert_thresh=0.001, populations=None, num_bins=10000, min_member=50, title='IBS Heatmap', xaxis_title='Population Pair', yaxis_title='Genome')

Plot the Identity By State (IBS) heatmap for a given genome and two populations. Args: df (pl.DataFrame): DataFrame containing the IBS information. title (str, optional): Title of the plot. Defaults to "IBS Heatmap". xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair". yaxis_title (str, optional): Title of the y-axis. Defaults to "Genome". Returns: go.Figure: Plotly figure containing the IBS heatmap.

Source code in zipstrain/src/zipstrain/visualize.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def plot_ibs_heatmap(
    df:pl.DataFrame,
    vert_thresh:float=0.001,
    populations:list[str]|None=None,
    num_bins:int=10000,
    min_member:int=50,
    title:str="IBS Heatmap",
    xaxis_title:str="Population Pair",
    yaxis_title:str="Genome",

):
    """
    Plot the Identity By State (IBS) heatmap for all genomes across population pairs.
    Args:
        df (pl.DataFrame): DataFrame containing the IBS information (output of calculate_ibs).
        vert_thresh (float, optional): CDF level at which the within/between distance is measured. Defaults to 0.001.
        populations (list[str] | None, optional): Populations to pair up; derived from the column names when None.
        num_bins (int, optional): Number of histogram bins passed to get_cdf. Defaults to 10000.
        min_member (int, optional): Minimum list length for a cell to be used; smaller cells are replaced by the [-1] sentinel. Defaults to 50.
        title (str, optional): Title of the plot. Defaults to "IBS Heatmap".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Genome".
    Returns:
        go.Figure: Plotly figure containing the IBS heatmap (rows sorted by row sum; all-zero rows and columns removed).
    """
    # Replace under-populated cells with the [-1] sentinel so they are treated
    # as "no data" below.
    df = df.with_columns(
    [
        pl.when(pl.col(c).list.len() < min_member)
        .then(pl.lit([-1]))
        .otherwise(pl.col(c))
        .alias(c)
        for c in df.columns if c != "genome"
    ]
)
    if populations is None:
        # Recover the population names from the "within_..."/"between_..." column labels.
        populations=set(chain.from_iterable(i.replace("within_population_","").replace("between_population_","").split("|") for i in df.columns if i!="genome"))
        populations=sorted(populations)
    heatmap_data = df.rows_by_key("genome", unique=True,include_key=False,named=True)
    fig_data={}
    for genome, genome_data in heatmap_data.items():
        fig_data[genome]={}
        for pop1,pop2 in combinations(populations,2):
            # Column labels order the pair with min/max (same scheme as calculate_ibs).
            key_between=f"between_population_{min(pop1,pop2)}|{max(pop1,pop2)}"
            key_within_1=f"within_population_{pop1}|{pop1}"
            key_within_2=f"within_population_{pop2}|{pop2}"
            if genome_data.get(key_between, [-1])==[-1] or genome_data.get(key_within_1, [-1])==[-1] or genome_data.get(key_within_2, [-1])==[-1]:
                # Any missing side makes the distance undefined for this pair.
                fig_data[genome][f"{min(pop1,pop2)}-{max(pop1,pop2)}"]=-1
                continue
            between=get_cdf(genome_data[key_between], num_bins=num_bins)
            within=get_cdf(genome_data[key_within_1]+genome_data[key_within_2], num_bins=num_bins)

            # Distance between the points where each CDF first reaches vert_thresh.
            between_intersect=between[0][np.where(between[1]>=vert_thresh)[0][0]]
            within_intersect=within[0][np.where(within[1]>=vert_thresh)[0][0]]
            distance=within_intersect-between_intersect
            fig_data[genome][f"{min(pop1,pop2)}-{max(pop1,pop2)}"]=distance
    ###Filter the dataframe to only have useful information
    heatmap_df = pd.DataFrame(fig_data).T
    heatmap_df=heatmap_df.mask(heatmap_df < 0, 0)
    heatmap_df=heatmap_df[heatmap_df.sum(axis=1)>0]
    heatmap_df=heatmap_df[[col for col in heatmap_df.columns if heatmap_df[col].sum()>0]]
    # Sort genomes by total signal so the strongest rows appear together.
    heatmap_df_sorted = heatmap_df.assign(row_sum=heatmap_df.sum(axis=1)).sort_values("row_sum", ascending=True).drop(columns="row_sum")

    fig = go.Figure(data=go.Heatmap(
        z=heatmap_df_sorted.values,
        x=heatmap_df_sorted.columns,
        y=heatmap_df_sorted.index
    ))
    return fig

plot_identical_frac_vs_popani(df, genome, title='Fraction of Identical Genes vs popANI for <GENOME>', xaxis_title='Genome-Wide popANI', yaxis_title='Fraction of Identical Genes')

Plot the fraction of identical genes vs popANI for a given genome and two samples in any possible combination of populations. Args: df (pl.DataFrame): DataFrame containing the fraction of identical genes vs popANI information. title (str, optional): Title of the plot. Defaults to "Fraction of Identical Genes vs popANI". xaxis_title (str, optional): Title of the x-axis. Defaults to "popANI". yaxis_title (str, optional): Title of the y-axis. Defaults to "Fraction of Identical Genes". Returns: go.Figure: Plotly figure containing the fraction of identical genes vs popANI plot.

Source code in zipstrain/src/zipstrain/visualize.py
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
def plot_identical_frac_vs_popani(df:pl.DataFrame,
                                  genome:str,
                                  title:str="Fraction of Identical Genes vs popANI for <GENOME>",
                                  xaxis_title:str="Genome-Wide popANI",
                                  yaxis_title:str="Fraction of Identical Genes",
                                  ):
    """
    Plot the fraction of identical genes vs popANI for a given genome and two samples in any possible combination of populations.
    One scatter trace is added per row of the DataFrame, named after its "relationship" value.
    Args:
        df (pl.DataFrame): DataFrame containing the fraction of identical genes vs popANI information
            (expects "relationship", "perc_id_genes" and "genome_pop_ani" columns).
        genome (str): Genome name substituted for the "<GENOME>" placeholder in the title.
        title (str, optional): Title of the plot. Defaults to "Fraction of Identical Genes vs popANI for <GENOME>".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Genome-Wide popANI".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Fraction of Identical Genes".
    Returns:
        go.Figure: Plotly figure containing the fraction of identical genes vs popANI plot.
    """
    figure = go.Figure()
    columns = (df["relationship"], df["perc_id_genes"], df["genome_pop_ani"])
    for relationship, id_gene_frac, pop_ani in zip(*columns):
        trace = go.Scatter(
            x=pop_ani,
            y=id_gene_frac,
            mode='markers',
            name=relationship,
        )
        figure.add_trace(trace)
    figure.update_layout(
        title=title.replace("<GENOME>", genome),
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
    )
    return figure

plot_strainsharing(strainsharingrates, sample_frac=1, title='Strain Sharing Rates', xaxis_title='Population Pair', yaxis_title='Strain Sharing Rate')

Plot the strain sharing rates between populations. Args: strainsharingrates (dict[str, list[float]]): Dictionary containing the strain sharing rates between populations. title (str, optional): Title of the plot. Defaults to "Strain Sharing". xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair". yaxis_title (str, optional): Title of the y-axis. Defaults to "Strain Sharing Rate". Returns: go.Figure: Plotly figure containing the strain sharing plot.

Source code in zipstrain/src/zipstrain/visualize.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def plot_strainsharing(
    strainsharingrates:dict[str, list[float]],
    sample_frac:float=1,
    title:str="Strain Sharing Rates",
    xaxis_title:str="Population Pair",
    yaxis_title:str="Strain Sharing Rate",
):
    """
    Plot the strain sharing rates between populations.
    Args:
        strainsharingrates (dict[str, list[float]]): Dictionary containing the strain sharing rates between populations.
        sample_frac (float, optional): Fraction of each pair's rates to plot, sampled without replacement. Defaults to 1.
        title (str, optional): Title of the plot. Defaults to "Strain Sharing Rates".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Strain Sharing Rate".
    Returns:
        go.Figure: Plotly figure containing the strain sharing plot.
    """
    # Subsample into a new dict instead of mutating the caller's data in place.
    sampled_rates = {
        pair: np.random.choice(rates, size=int(len(rates) * sample_frac), replace=False)
        for pair, rates in strainsharingrates.items()
    }
    fig = go.Figure()
    for pair, rates in sampled_rates.items():
        fig.add_trace(go.Box(
            y=rates,
            name=pair,
            boxpoints='all',
            jitter=0.3,
            pointpos=0
        ))
    fig.update_layout(
        title={"text": title, "x": 0.5},
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title
    )
    return fig

For more advanced users, new workflows can be developed using the task_manager module to create custom pipelines tailored to specific research needs.


Task Manager

zipstrain.task_manager

Lightweight, asyncio-driven orchestration primitives for building and running scientific data-processing pipelines. This module provides a small, composable framework for defining Tasks with explicit inputs/outputs, bundling Tasks into Batches (local or Slurm), and coordinating their execution with a live terminal UI. It is designed to be easy to extend for new Task types and execution environments. For most users, this module is not directly used. However, it can be used to define new pipelines that chain together multiple steps with clear input/outputs. The unit of execution is a batch, which is a collection of tasks to be executed together. Each batch can have an optional finalization step that runs after all tasks are complete.

Key concepts
  • Inputs and Outputs: These classes encapsulate task inputs and outputs with validation logic. By default, input and output classes for files, strings, and integers are provided. If needed, new types can be defined by subclassing Input or Output.

  • Engines: Any task object can use a container engine (Docker or Apptainer) or run natively (LocalEngine).

  • Tasks: Each task runs a unit of bash script with defined inputs and expected outputs. If an engine is provided, the command will be wrapped accordingly to run inside the container.

  • Batches: A batch is a collection of tasks to be executed together. Batches can be run locally or submitted to Slurm. Each batch monitors the status of its tasks and updates its own status accordingly. A batch can also have expected outputs that are checked after all tasks are complete. Additionally, a batch can have a finalization step that runs after all tasks are complete.

  • Runner: The Runner class orchestrates task generation, batching, and execution. It manages concurrent batch execution, monitors progress, and provides a live terminal UI using the rich library.

Batch

Bases: ABC

Batch is a collection of tasks to be executed as a group. This is a base class and should not be instantiated directly. A batch is the unit of execution, meaning that the entire batch is either run locally or submitted to a job scheduler like Slurm.

Parameters:

Name Type Description Default
tasks list[Task]

List of Task objects to be included in the batch.

required
id str

Unique identifier for the batch.

required
run_dir Path

Directory where the batch will be executed.

required
expected_outputs list[Output]

List of expected outputs for the batch.

required
Source code in zipstrain/src/zipstrain/task_manager.py
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
class Batch(ABC):
    """Batch is a collection of tasks to be executed as a group. This is a base class and should not be instantiated directly.
    A batch is the unit of execution meaning that the entire batch is either run locally or submitted to a job scheduler like Slurm.

    Args:
        tasks (list[Task]): List of Task objects to be included in the batch.
        id (str): Unique identifier for the batch.
        run_dir (pathlib.Path): Directory where the batch will be executed.
        expected_outputs (list[Output]): List of expected outputs for the batch.
        file_semaphore (asyncio.Semaphore | None): Optional semaphore limiting concurrent
            file writes; it is propagated to every task in the batch.
    """
    # Command template rendered by subclasses; the abstract base has none.
    TEMPLATE_CMD = ""
    # Run-wide event log (written under run_dir) and per-batch log (written under batch_dir).
    RUN_LOG_FILE = "batch_events.log"
    BATCH_LOG_FILE = "batch.log"

    def __init__(self, tasks: list[Task],
                 id: str,
                 run_dir: pathlib.Path,
                 expected_outputs: list[Output],
                 file_semaphore: asyncio.Semaphore | None = None
                 ) -> None:
        self.id = id
        self.tasks = tasks
        self.run_dir = pathlib.Path(run_dir)
        # Each batch gets its own subdirectory of the run directory, named after its id.
        self.batch_dir = self.run_dir / self.id
        self.retry_count = 0
        self.expected_outputs = expected_outputs
        self.file_semaphore = file_semaphore
        # Batch-level file outputs resolve their paths relative to this batch's directory.
        for output in self.expected_outputs:
            if isinstance(output, BatchFileOutput):
                output.register_batch(self)
        self._status = self._get_initial_status()
        # Wire every task back to this batch, share the file semaphore,
        # register task-level outputs, and let each task resolve its I/O mapping.
        for task in self.tasks:
            task._batch_obj = self
            task.file_semaphore = self.file_semaphore
            for output in task.expected_outputs.values():
                output.register_task(task)
            task._status= task._get_initial_status()
            task.map_io()

        self._runner_obj: Runner | None = None
        self._cleaned_up = False
        # Last (status, done, total, success, failed, retry_count) tuple written by
        # log_progress(); used to suppress duplicate PROGRESS lines.
        self._last_progress_snapshot: tuple[str, int, int, int, int, int] | None = None

    def _get_initial_status(self) -> str:
        """Returns the initial status of the batch based on the presence of the batch directory."""
        # No batch directory or no persisted status file means the batch never ran.
        if not self.batch_dir.exists():
            return Status.NOT_STARTED.value
        status_file = self.batch_dir / ".status"
        if not status_file.exists():
            return Status.NOT_STARTED.value
        with open(status_file, mode="r") as f:
            status_as_written = f.read().strip()

        # A persisted DONE/SUCCESS is trusted only if the expected outputs actually
        # exist on disk; otherwise the batch is treated as failed and may be retried.
        if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
            outputs_ready = self.outputs_ready()

            if outputs_ready:
                return Status.SUCCESS.value

            else:
                return Status.FAILED.value
        # Any other recognized status is resumed as-is.
        if status_as_written in (
            Status.NOT_STARTED.value,
            Status.RUNNING.value,
            Status.SUBMITTED.value,
            Status.PENDING.value,
            Status.FAILED.value,
        ):
            return status_as_written
        # Unrecognized/corrupted content falls back to a fresh start.
        return Status.NOT_STARTED.value

    def _progress_counts(self) -> tuple[int, int, int, int]:
        """Return (done, total, success, failed) counts for batch tasks."""
        total = len(self.tasks)
        success = sum(1 for task in self.tasks if task.status == Status.SUCCESS.value)
        failed = sum(1 for task in self.tasks if task.status == Status.FAILED.value)
        # A task counts as "done" once it has reached a terminal state, success or failure.
        done = success + failed
        return done, total, success, failed

    async def _append_batch_log(self, event: str, message: str = "") -> None:
        """Append a human-readable lifecycle/progress line to logs.

        Args:
            event (str): Short event tag (e.g. "START", "PROGRESS", "DONE", "FAILED").
            message (str): Optional free-text detail; newlines are escaped so the
                log stays one line per event.
        """
        timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        clean_message = message.replace("\n", "\\n").strip()
        done, total, success, failed = self._progress_counts()
        # attempt is 1-based for readability (retry_count starts at 0).
        line = (
            f"{timestamp} BATCH {self.id} {event} "
            f"state={self._status} progress={done}/{total} success={success} failed={failed} "
            f"attempt={self.retry_count + 1}"
        )
        if clean_message:
            line += f" message={clean_message}"
        line += "\n"
        # Always log to the run-wide file; mirror to the batch file only once it exists.
        await append_file(self.run_dir / self.RUN_LOG_FILE, line, self.file_semaphore)
        if self.batch_dir.exists():
            await append_file(self.batch_dir / self.BATCH_LOG_FILE, line, self.file_semaphore)

    async def log_progress(self, message: str = "", force: bool = False) -> None:
        """Log progress if the batch snapshot changed, unless forced."""
        done, total, success, failed = self._progress_counts()
        snapshot = (self._status, done, total, success, failed, self.retry_count)
        # Skip duplicate snapshots to avoid flooding the logs during polling.
        if not force and snapshot == self._last_progress_snapshot:
            return
        self._last_progress_snapshot = snapshot
        await self._append_batch_log("PROGRESS", message)

    async def _set_status(self, status: str, message: str = "", persist_status: str | None = None) -> None:
        """Set batch status, persist it, and write a log entry.

        Args:
            status (str): New in-memory status value.
            message (str): Optional detail appended to the log line.
            persist_status (str | None): Value written to the on-disk .status file
                when it should differ from the in-memory status; defaults to ``status``.
        """
        previous = self._status
        self._status = status
        persisted = persist_status if persist_status is not None else status
        if self.batch_dir.exists():
            await write_file(self.batch_dir / ".status", persisted, self.file_semaphore)
        if previous != status:
            transition = f"{previous}->{status}"
            if message:
                transition = f"{transition}; {message}"
            # Pick the log event tag from the kind of transition.
            if status == Status.SUCCESS.value:
                event = "DONE"
            elif status == Status.FAILED.value:
                event = "FAILED"
            elif previous == Status.NOT_STARTED.value:
                event = "START"
            else:
                event = "STATE"
            await self._append_batch_log(event, transition)
        elif message:
            # Status unchanged but a message was supplied: record it as a NOTE.
            await self._append_batch_log("NOTE", message)

    def cleanup(self) -> None:
        """The base class defines if any cleanup is needed after batch success. By default, it does nothing."""
        self._cleaned_up = True
        return None

    @abstractmethod
    async def cancel(self) -> None:
        """Cancels the batch. This method should be implemented by subclasses."""
        ...

    def outputs_ready(self) -> bool:
        """Check if all BATCH-LEVEL expected outputs are ready."""

        return all([output.ready() for output in self.expected_outputs])

    async def _collect_task_status(self) -> list[str]:
        """Collects the status of all tasks asynchronously."""
        return await asyncio.gather(*[task.get_status() for task in self.tasks])

    @abstractmethod
    async def run(self) -> None:
        """Runs the batch. This method should be implemented by subclasses."""
        ...

    @abstractmethod
    def _parse_job_id(self, sbatch_output: str) -> str:
        """Parses the job ID from the sbatch output. This method should be implemented by subclasses."""
        ...

    @property
    def status(self) -> str:
        """Returns the current status of the batch."""
        return self._status

    @property
    def stats(self) -> dict[str, str]:
        """Returns a dictionary of task IDs and their statuses."""
        return {task.id: task.status for task in self.tasks}

    async def update_status(self) -> str:
        """Updates the status of the batch by collecting the status of all tasks."""
        # Refreshing task statuses has the side effect of updating self._status.
        await self._collect_task_status()
        return self._status

    def _set_file_semaphore(self, file_semaphore: asyncio.Semaphore) -> None:
        # Replace the semaphore on the batch and propagate it to every task.
        self.file_semaphore = file_semaphore
        for task in self.tasks:
            task.file_semaphore = file_semaphore
stats property

Returns a dictionary of task IDs and their statuses.

status property

Returns the current status of the batch.

cancel() abstractmethod async

Cancels the batch. This method should be implemented by subclasses.

Source code in zipstrain/src/zipstrain/task_manager.py
769
770
771
772
@abstractmethod
async def cancel(self) -> None:
    """Cancels the batch. This method should be implemented by subclasses."""
    ...
cleanup()

The base class defines if any cleanup is needed after batch success. By default, it does nothing.

Source code in zipstrain/src/zipstrain/task_manager.py
764
765
766
767
def cleanup(self) -> None:
    """The base class defines if any cleanup is needed after batch success. By default, it does nothing."""
    self._cleaned_up = True
    return None
log_progress(message='', force=False) async

Log progress if the batch snapshot changed, unless forced.

Source code in zipstrain/src/zipstrain/task_manager.py
732
733
734
735
736
737
738
739
async def log_progress(self, message: str = "", force: bool = False) -> None:
    """Log progress if the batch snapshot changed, unless forced."""
    done, total, success, failed = self._progress_counts()
    snapshot = (self._status, done, total, success, failed, self.retry_count)
    if not force and snapshot == self._last_progress_snapshot:
        return
    self._last_progress_snapshot = snapshot
    await self._append_batch_log("PROGRESS", message)
outputs_ready()

Check if all BATCH-LEVEL expected outputs are ready.

Source code in zipstrain/src/zipstrain/task_manager.py
774
775
776
777
def outputs_ready(self) -> bool:
    """Check if all BATCH-LEVEL expected outputs are ready."""

    return all([output.ready() for output in self.expected_outputs])
run() abstractmethod async

Runs the batch. This method should be implemented by subclasses.

Source code in zipstrain/src/zipstrain/task_manager.py
783
784
785
786
@abstractmethod
async def run(self) -> None:
    """Runs the batch. This method should be implemented by subclasses."""
    ...
update_status() async

Updates the status of the batch by collecting the status of all tasks.

Source code in zipstrain/src/zipstrain/task_manager.py
803
804
805
806
async def update_status(self) -> str:
    """Updates the status of the batch by collecting the status of all tasks."""
    await self._collect_task_status()
    return self._status

BatchFileOutput

Bases: Output

This is used when the output is a file path relative to the batch directory. Also it will be registered to the batch instead of the task.

Source code in zipstrain/src/zipstrain/task_manager.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
class BatchFileOutput(Output):
    """This is used when the output is a file path relative to the batch directory.
    Also it will be registered to the batch instead of the task.

    Args:
        expected_file (str): File path relative to the batch directory; resolved to an
            absolute location when register_batch() is called.
    """
    def __init__(self, expected_file:str) -> None:
        self._expected_file_name = expected_file

    def ready(self) -> bool:
        """Check if the expected output file exists.

        Note: register_batch() must have been called first; otherwise
        ``expected_file`` is unset and an AttributeError is raised.
        """
        # exists() already returns a bool; no ternary needed.
        return self.expected_file.absolute().exists()

    def register_batch(self, batch: Batch) -> None:
        """Registers the batch that produces this output and sets the expected file path.

        Args:
        batch (Batch): The batch that produces this output and sets the expected file path.
        """
        self.expected_file = batch.batch_dir / self._expected_file_name
ready()

Check if the expected output file exists.

Source code in zipstrain/src/zipstrain/task_manager.py
265
266
267
def ready(self) -> bool:
    """Check if the expected output file exists."""
    return True if self.expected_file.absolute().exists() else False
register_batch(batch)

Registers the batch that produces this output and sets the expected file path.

Args: batch (Batch): The batch that produces this output and sets the expected file path.

Source code in zipstrain/src/zipstrain/task_manager.py
269
270
271
272
273
274
275
def register_batch(self, batch: Batch) -> None:
    """Registers the batch that produces this output and sets the expected file path.

    Args:
    batch (Batch): The batch that produces this output and sets the expected file path.
    """
    self.expected_file = batch.batch_dir / self._expected_file_name

CollectComps

Bases: Task

A Task that collects and merges comparison parquet files from multiple FastCompareTask tasks into a single parquet file.

Parameters:

Name Type Description Default
id str

Unique identifier for the task.

required
inputs dict[str, Input]

Dictionary of input parameters for the task.

required
expected_outputs dict[str, Output]

Dictionary of expected outputs for the task.

required
engine Engine

Container engine to wrap the command.

required
Source code in zipstrain/src/zipstrain/task_manager.py
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
class CollectComps(Task):
    """A Task that collects and merges comparison parquet files from multiple FastCompareTask
    tasks into a single parquet file.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
    """
    # Copies every per-task comparison parquet into a scratch folder, merges them
    # with the zipstrain CLI, then removes the scratch folder.
    TEMPLATE_CMD="""
    rm -rf comps
    mkdir -p comps
    find . -maxdepth 2 -type f -name "*_comparison.parquet" ! -path "./comps/*" -exec cp {} comps/ \\;
    zipstrain utilities merge_parquet --input-dir comps --output-file <output-file>
    rm -rf comps
    """

    @property
    def pre_run(self) -> str:
        """Shell snippet that marks this task as running before the main command executes."""
        task_dir = self.task_dir.absolute()
        return f"echo {Status.RUNNING.value} > {task_dir}/.status"

CollectGeneComps

Bases: Task

A Task that collects and merges gene comparison parquet files from multiple FastGeneCompareTask tasks into a single parquet file.

Parameters:

Name Type Description Default
id str

Unique identifier for the task.

required
inputs dict[str, Input]

Dictionary of input parameters for the task.

required
expected_outputs dict[str, Output]

Dictionary of expected outputs for the task.

required
engine Engine

Container engine to wrap the command.

required
Source code in zipstrain/src/zipstrain/task_manager.py
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
class CollectGeneComps(Task):
    """A Task that collects and merges gene comparison parquet files from multiple
    FastGeneCompareTask tasks into a single parquet file.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
    """
    # Copies every per-task gene-comparison parquet into a scratch folder, merges them
    # with the zipstrain CLI, then removes the scratch folder.
    TEMPLATE_CMD="""
    rm -rf gene_comps
    mkdir -p gene_comps
    find . -maxdepth 2 -type f -name "*_gene_comparison.parquet" ! -path "./gene_comps/*" -exec cp {} gene_comps/ \\;
    zipstrain utilities merge_parquet --input-dir gene_comps --output-file <output-file>
    rm -rf gene_comps
    """

    @property
    def pre_run(self) -> str:
        """Shell snippet that marks this task as running before the main command executes."""
        task_dir = self.task_dir.absolute()
        return f"echo {Status.RUNNING.value} > {task_dir}/.status"

CompareRunner

Bases: Runner

Creates and schedules batches of FastCompareTask tasks using either local or Slurm batches.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
task_generator TaskGenerator

An instance of TaskGenerator to produce tasks.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 1.0.

1.0
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
batch_type str

Type of batch to use ("local" or "slurm"). Default is "local".

'local'
slurm_config SlurmConfig | None

Configuration for Slurm batches if batch_type is "slurm". Default is None.

None
Source code in zipstrain/src/zipstrain/task_manager.py
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
class CompareRunner(Runner):
    """
    Creates and schedules batches of FastCompareTask tasks using either local or Slurm batches.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.
    """

    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
    ) -> None:
        # Select the concrete batch classes up front so the rest of the runner
        # stays agnostic to the execution backend.
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = FastCompareSlurmBatch
            final_batch_factory = PrepareCompareGenomeRunOutputsSlurmBatch
        else:
            batch_factory = FastCompareLocalBatch
            final_batch_factory = PrepareCompareGenomeRunOutputsLocalBatch
        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=final_batch_factory,
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
        )

    def _instantiate_batch(self, factory, tasks: list[Task], batch_id: str, expected_outputs: list[Output]) -> Batch:
        """Instantiate a batch via *factory*, adding the Slurm config only for slurm runs."""
        kwargs = {
            "tasks": tasks,
            "id": batch_id,
            "run_dir": self.run_dir,
            "expected_outputs": expected_outputs,
        }
        if self.batch_type == "slurm":
            kwargs["slurm_config"] = self.slurm_config
        return factory(**kwargs)

    def _build_compare_batch(self, buffer: list[Task]) -> Batch:
        """Wrap buffered compare tasks plus a trailing CollectComps merge task into a new batch."""
        batch_id = f"batch_{self._batch_counter}"
        self._batch_counter += 1
        batch_tasks = buffer + [
            CollectComps(
                "concat_parquet",
                {},
                {"output-file": FileOutput(f"Merged_batch_{batch_id}.parquet")},
                engine=self.container_engine,
            )
        ]
        expected_outputs = [BatchFileOutput(f"concat_parquet/Merged_batch_{batch_id}.parquet")]
        return self._instantiate_batch(self.batch_factory, batch_tasks, batch_id, expected_outputs)

    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue. Each batch includes a CollectComps task to merge the outputs of the tasks in the batch.
        A ``None`` sentinel on tasks_queue flushes any partial batch, forwards the sentinel, and stops the coroutine.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                # Producer finished: flush the partial batch (if any) and signal downstream.
                if buffer:
                    await self.batches_queue.put(self._build_compare_batch(buffer))
                await self.batches_queue.put(None)
                self._batcher_done = True
                break
            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                await self.batches_queue.put(self._build_compare_batch(buffer))
                buffer = []

    def _create_final_batch(self) -> Batch:
        """Creates the final batch that prepares the overall outputs after all comparison batches are done."""
        final_task = PrepareCompareGenomeRunOutputs(
            id="prepare_outputs",
            inputs={"output-dir": StringInput("Outputs")},
            expected_outputs={},
            engine=self.container_engine,
        )
        expected_outputs = [BatchFileOutput("all_comparisons.parquet")]
        final_batch = self._instantiate_batch(self.final_batch_factory, [final_task], "Outputs", expected_outputs)
        final_batch._runner_obj = self
        return final_batch

CompareTaskGenerator

Bases: TaskGenerator

This TaskGenerator generates FastCompareTask objects from a polars DataFrame. Each task compares two profiles using compare_genomes functionality in zipstrain.compare module.

Parameters:

Name Type Description Default
data LazyFrame

Polars LazyFrame containing the data for generating tasks.

required
yield_size int

Number of tasks to yield at a time.

required
comp_config GenomeComparisonConfig

Configuration for genome comparison.

required
duckdb_memory_limit str | None

Optional DuckDB memory limit (for example "2GB").

None
duckdb_threads int | None

Optional DuckDB thread cap (for example 8).

None
compare_engine str

Compare engine passed to single compare tasks ("polars" or "duckdb").

'polars'
Source code in zipstrain/src/zipstrain/task_manager.py
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
class CompareTaskGenerator(TaskGenerator):
    """This TaskGenerator generates FastCompareTask objects from a polars LazyFrame. Each task compares two profiles using compare_genomes functionality in
    zipstrain.compare module.

    Args:
        data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
        yield_size (int): Number of tasks to yield at a time.
        container_engine (Engine): Container engine used to wrap task commands.
        comp_config (database.GenomeComparisonConfig): Configuration for genome comparison.
        duckdb_memory_limit (str | None): Optional DuckDB memory limit (for example "2GB").
        duckdb_threads (int | None): Optional DuckDB thread cap (for example 8).
        compare_engine (str): Compare engine passed to single compare tasks ("polars" or "duckdb").
        calculate (str): Value forwarded to each task's --calculate flag. Default is "all".

    Raises:
        ValueError: If data is not a polars LazyFrame or compare_engine is invalid.
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        comp_config: database.GenomeComparisonConfig,
        duckdb_memory_limit: str | None = None,
        duckdb_threads: int | None = None,
        compare_engine: str = "polars",
        calculate: str = "all",
    ) -> None:
        super().__init__(data, yield_size)
        self.comp_config = comp_config
        self.engine = container_engine
        self.duckdb_memory_limit = duckdb_memory_limit
        self.duckdb_threads = duckdb_threads
        self.compare_engine = compare_engine
        self.calculate = calculate
        # isinstance is the idiomatic type check (original compared type() identity).
        if not isinstance(self.data, pl.LazyFrame):
            raise ValueError("data must be a polars LazyFrame.")
        if self.compare_engine not in {"polars", "duckdb"}:
            raise ValueError("compare_engine must be one of {'polars', 'duckdb'}.")

    def get_total_tasks(self) -> int:
        """Returns total number of pairwise comparisons to be made."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yields lists of FastCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
        while polars is collecting data to avoid blocking.
        """
        # These CLI fragments do not depend on the row, so build them once
        # instead of once per comparison pair.
        duckdb_memory_limit_arg = (
            f"--duckdb-memory-limit {self.duckdb_memory_limit}"
            if self.duckdb_memory_limit
            else ""
        )
        duckdb_threads_arg = (
            f"--duckdb-threads {self.duckdb_threads}"
            if self.duckdb_threads is not None
            else ""
        )
        compare_engine_arg = f"--engine {self.compare_engine}"
        for offset in range(0, self._total_tasks, self.yield_size):
            # collect_async keeps the event loop responsive during collection.
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                pair_id = row["sample_name_1"] + "_" + row["sample_name_2"]
                inputs = {
                    "mpile_1_file": FileInput(row["profile_location_1"]),
                    "mpile_2_file": FileInput(row["profile_location_2"]),
                    "stb_file": FileInput(self.comp_config.stb_file_loc),
                    "min_cov": IntInput(self.comp_config.min_cov),
                    "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
                    "duckdb-memory-limit-arg": StringInput(duckdb_memory_limit_arg),
                    "duckdb-threads-arg": StringInput(duckdb_threads_arg),
                    "compare-engine-arg": StringInput(compare_engine_arg),
                    "calculate-arg": StringInput(f"--calculate {self.calculate}"),
                    "genome-name": StringInput(self.comp_config.scope),
                }
                expected_outputs = {
                    "output-file": FileOutput(pair_id + "_comparison.parquet"),
                }
                task = FastCompareTask(id=pair_id, inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                tasks.append(task)
            yield tasks
generate_tasks() async

Yields lists of FastCompareTask objects based on the data in batches of yield_size. This method yields control back to the event loop while polars is collecting data to avoid blocking.

Source code in zipstrain/src/zipstrain/task_manager.py
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
async def generate_tasks(self) -> list[Task]:
    """Yields lists of FastCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
    while polars is collecting data to avoid blocking.
    """
    # Walk the LazyFrame in yield_size-sized windows; each window becomes one list of tasks.
    for offset in range(0, self._total_tasks, self.yield_size):
        # collect_async keeps the event loop responsive while polars materializes the slice.
        batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
        tasks = []
        for row in batch_df.iter_rows(named=True):
            # Optional CLI fragments are rendered as empty strings when unset so the
            # command template stays valid without them.
            duckdb_memory_limit_arg = (
                f"--duckdb-memory-limit {self.duckdb_memory_limit}"
                if self.duckdb_memory_limit
                else ""
            )
            duckdb_threads_arg = (
                f"--duckdb-threads {self.duckdb_threads}"
                if self.duckdb_threads is not None
                else ""
            )
            compare_engine_arg = f"--engine {self.compare_engine}"
            # Keys correspond to <placeholders> in FastCompareTask.TEMPLATE_CMD.
            inputs = {
            "mpile_1_file": FileInput(row["profile_location_1"]),
            "mpile_2_file": FileInput(row["profile_location_2"]),
            "stb_file": FileInput(self.comp_config.stb_file_loc),
            "min_cov": IntInput(self.comp_config.min_cov),
            "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
            "duckdb-memory-limit-arg": StringInput(duckdb_memory_limit_arg),
            "duckdb-threads-arg": StringInput(duckdb_threads_arg),
            "compare-engine-arg": StringInput(compare_engine_arg),
            "calculate-arg": StringInput(f"--calculate {self.calculate}"),
            "genome-name": StringInput(self.comp_config.scope),
            }
            expected_outputs ={
            "output-file":  FileOutput(row["sample_name_1"]+"_"+row["sample_name_2"]+"_comparison.parquet" ),

            }
            # Task id is the sample-pair name, matching the output file naming.
            task = FastCompareTask(id=row["sample_name_1"]+"_"+row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
            tasks.append(task)
        yield tasks
get_total_tasks()

Returns total number of pairwise comparisons to be made.

Source code in zipstrain/src/zipstrain/task_manager.py
592
593
594
def get_total_tasks(self) -> int:
    """Count the pairwise comparisons encoded in the source LazyFrame."""
    counts = self.data.select(size=pl.len()).collect(engine="streaming")
    return counts["size"][0]

FastCompareLocalBatch

Bases: LocalBatch

A LocalBatch that runs FastCompareTask tasks locally.

Source code in zipstrain/src/zipstrain/task_manager.py
1685
1686
1687
1688
1689
1690
1691
1692
class FastCompareLocalBatch(LocalBatch):
    """A LocalBatch that runs FastCompareTask tasks locally."""
    def cleanup(self) -> None:
        """Drop every FastCompareTask from this batch and delete its task directory."""
        # Collect first, then remove, so we never mutate self.tasks while scanning it.
        compare_tasks = [t for t in self.tasks if isinstance(t, FastCompareTask)]
        for compare_task in compare_tasks:
            self.tasks.remove(compare_task)
            shutil.rmtree(compare_task.task_dir)
        self._cleaned_up = True

FastCompareSlurmBatch

Bases: SlurmBatch

A SlurmBatch that runs FastCompareTask tasks on a Slurm cluster. May be removed in a future release.

Source code in zipstrain/src/zipstrain/task_manager.py
1694
1695
1696
1697
1698
1699
1700
1701
1702
class FastCompareSlurmBatch(SlurmBatch):
    """A SlurmBatch that runs FastCompareTask tasks on a Slurm cluster.

    NOTE: may be removed in a future release.
    """
    def cleanup(self) -> None:
        """Drop every FastCompareTask from this batch and delete its task directory."""
        # Snapshot the matching tasks before removal to avoid mutating during iteration.
        finished = [t for t in self.tasks if isinstance(t, FastCompareTask)]
        for done_task in finished:
            self.tasks.remove(done_task)
            shutil.rmtree(done_task.task_dir)
        self._cleaned_up = True

FastCompareTask

Bases: Task

A Task that performs a fast genome comparison using the fast_profile compare single_compare_genome command.

Parameters:

Name Type Description Default
id str

Unique identifier for the task.

required
inputs dict[str, Input]

Dictionary of input parameters for the task.

required
expected_outputs dict[str, Output]

Dictionary of expected outputs for the task.

required
engine Engine

Container engine to wrap the command.

required
Source code in zipstrain/src/zipstrain/task_manager.py
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
class FastCompareTask(Task):
    """A Task that performs a fast genome comparison using the fast_profile compare single_compare_genome command.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
        """
    # Shell command template; <placeholders> are filled in from the task's
    # inputs/expected_outputs keys of the same name (see CompareTaskGenerator).
    TEMPLATE_CMD="""
    zipstrain utilities single_compare_genome --mpileup-contig-1 <mpile_1_file> \
    --mpileup-contig-2 <mpile_2_file> \
    --stb-file <stb_file> \
    --min-cov <min_cov> \
    --min-gene-compare-len <min-gene-compare-len> \
    <duckdb-memory-limit-arg> \
    <duckdb-threads-arg> \
    <compare-engine-arg> \
    <calculate-arg> \
    --output-file <output-file> \
    --genome <genome-name>
    """

FastGeneCompareLocalBatch

Bases: LocalBatch

A LocalBatch that runs FastGeneCompareTask tasks locally.

Source code in zipstrain/src/zipstrain/task_manager.py
2088
2089
2090
2091
2092
2093
2094
2095
2096
class FastGeneCompareLocalBatch(LocalBatch):
    """A LocalBatch that runs FastGeneCompareTask tasks locally."""
    def cleanup(self) -> None:
        """Mark gene-compare tasks as successful, drop them, and delete their directories."""
        gene_tasks = [t for t in self.tasks if isinstance(t, FastGeneCompareTask)]
        for gene_task in gene_tasks:
            # Status is forced to SUCCESS before removal (original behavior preserved).
            gene_task._status = Status.SUCCESS
            self.tasks.remove(gene_task)
            shutil.rmtree(gene_task.task_dir)
        self._cleaned_up = True

FastGeneCompareSlurmBatch

Bases: SlurmBatch

A SlurmBatch that runs FastGeneCompareTask tasks on a Slurm cluster.

Source code in zipstrain/src/zipstrain/task_manager.py
2099
2100
2101
2102
2103
2104
2105
2106
class FastGeneCompareSlurmBatch(SlurmBatch):
    """A SlurmBatch that runs FastGeneCompareTask tasks on a Slurm cluster."""
    def cleanup(self) -> None:
        """Drop every FastGeneCompareTask from this batch and delete its task directory."""
        gene_tasks = [t for t in self.tasks if isinstance(t, FastGeneCompareTask)]
        for gene_task in gene_tasks:
            self.tasks.remove(gene_task)
            shutil.rmtree(gene_task.task_dir)
        self._cleaned_up = True

FastGeneCompareTask

Bases: Task

A Task that performs a fast gene comparison using the compare single_compare_gene command.

Parameters:

Name Type Description Default
id str

Unique identifier for the task.

required
inputs dict[str, Input]

Dictionary of input parameters for the task.

required
expected_outputs dict[str, Output]

Dictionary of expected outputs for the task.

required
engine Engine

Container engine to wrap the command.

required
Source code in zipstrain/src/zipstrain/task_manager.py
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
class FastGeneCompareTask(Task):
    """A Task that performs a fast gene comparison using the compare single_compare_gene command.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
    """
    # Shell command template; <placeholders> are filled in from the task's
    # inputs/expected_outputs keys of the same name (see GeneCompareTaskGenerator).
    TEMPLATE_CMD="""
    zipstrain utilities single_compare_gene --mpileup-contig-1 <mpile_1_file> \
    --mpileup-contig-2 <mpile_2_file> \
    --stb-file <stb_file> \
    --min-cov <min_cov> \
    --min-gene-compare-len <min-gene-compare-len> \
    <duckdb-memory-limit-arg> \
    <duckdb-threads-arg> \
    <compare-engine-arg> \
    --output-file <output-file> \
    --ani-method <ani-method>
    """

FileInput

Bases: Input

This is used when the input is a file path. By default, the validate method checks for file existence.

Source code in zipstrain/src/zipstrain/task_manager.py
171
172
173
174
175
176
177
178
179
class FileInput(Input):
    """This is used when the input is a file path. By default, the validate method checks for file existence."""
    def validate(self, check_exists: bool = True) -> None:
        """Raise FileNotFoundError when check_exists is set and the path is missing."""
        path = pathlib.Path(self.value)
        if check_exists and not path.exists():
            raise FileNotFoundError(f"Input file {self.value} does not exist.")

    def get_value(self) -> str:
        """Returns the absolute path of the input file as a string."""
        resolved = pathlib.Path(self.value).absolute()
        return str(resolved)
get_value()

Returns the absolute path of the input file as a string.

Source code in zipstrain/src/zipstrain/task_manager.py
177
178
179
def get_value(self) -> str:
    """Returns the absolute path of the input file as a string."""
    resolved = pathlib.Path(self.value).absolute()
    return str(resolved)

FileOutput

Bases: Output

This is used when the output is a file path.

Parameters:

Name Type Description Default
expected_file str

The expected output file name relative to the task directory.

required
Source code in zipstrain/src/zipstrain/task_manager.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
class FileOutput(Output):
    """This is used when the output is a file path.

    Args:
        expected_file (str): The expected output file name relative to the task directory.
    """
    def __init__(self, expected_file: str) -> None:
        # When the task is finished, the expected file should be in
        # task.task_dir / expected_file, otherwise ready() will return False.
        self._expected_file_name = expected_file

    def ready(self) -> bool:
        """Check if the expected output file exists.

        Only meaningful after register_task() has set self.expected_file.
        """
        # exists() already returns a bool; the original wrapped it in a
        # redundant `True if ... else False`.
        return self.expected_file.absolute().exists()

    def register_task(self, task: Task) -> None:
        """Registers the task that produces this output and sets the expected file path.

        Args:
            task (Task): The task that produces this output.
        """
        super().register_task(task)
        # Resolve the expected file relative to the producing task's directory.
        self.expected_file = self.task.task_dir / self._expected_file_name
ready()

Check if the expected output file exists.

Source code in zipstrain/src/zipstrain/task_manager.py
244
245
246
def ready(self) -> bool:
    """Check if the expected output file exists."""
    # exists() already returns a bool; no need for `True if ... else False`.
    return self.expected_file.absolute().exists()
register_task(task)

Registers the task that produces this output and sets the expected file path.

Args: task (Task): The task that produces this output.

Source code in zipstrain/src/zipstrain/task_manager.py
248
249
250
251
252
253
254
255
def register_task(self, task: Task) -> None:
    """Registers the task that produces this output and sets the expected file path.

    Args:
        task (Task): The task that produces this output.
    """
    super().register_task(task)
    # Resolve the expected file relative to the producing task's directory.
    self.expected_file = self.task.task_dir / self._expected_file_name

GeneCompareRunner

Bases: Runner

Creates and schedules batches of FastGeneCompareTask tasks using either local or Slurm batches.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
task_generator TaskGenerator

An instance of TaskGenerator to produce tasks.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 1.0.

1.0
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
batch_type str

Type of batch to use ("local" or "slurm"). Default is "local".

'local'
slurm_config SlurmConfig | None

Configuration for Slurm batches if batch_type is "slurm". Default is None.

None
Source code in zipstrain/src/zipstrain/task_manager.py
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
class GeneCompareRunner(Runner):
    """
    Creates and schedules batches of FastGeneCompareTask tasks using either local or Slurm batches.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.

    Raises:
        ValueError: If batch_type is "slurm" and no slurm_config is provided.
    """

    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
    ) -> None:
        # Pick local vs Slurm batch implementations up front; everything else
        # is handled generically by the base Runner.
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = FastGeneCompareSlurmBatch
            final_batch_factory = PrepareGeneCompareRunOutputsSlurmBatch
        else:
            batch_factory = FastGeneCompareLocalBatch
            final_batch_factory = PrepareGeneCompareRunOutputsLocalBatch
        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=final_batch_factory,
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
        )

    def _build_gene_batch(self, buffer: list[Task]) -> Batch:
        """Append a CollectGeneComps merge task to *buffer* and wrap it in a batch.

        Mutates *buffer* in place and names the batch (and its merged output file)
        after the current value of self._batch_counter. Extracted from _batcher,
        which previously duplicated this logic for the full-buffer and shutdown paths.
        """
        collect_task = CollectGeneComps(
            id="collect_gene_comps",
            inputs={},
            expected_outputs={"output-file": FileOutput(f"Merged_gene_batch_{self._batch_counter}.parquet")},
            engine=self.container_engine,
        )
        buffer.append(collect_task)
        batch_kwargs = {
            "tasks": buffer,
            "id": f"gene_batch_{self._batch_counter}",
            "run_dir": self.run_dir,
            "expected_outputs": [],
        }
        if self.batch_type == "slurm":
            batch_kwargs["slurm_config"] = self.slurm_config
        return self.batch_factory(**batch_kwargs)

    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue. Each batch includes a CollectGeneComps task to merge the outputs of the tasks in the batch.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                # Sentinel received: flush the (possibly partial) final batch, then stop.
                if buffer:
                    await self.batches_queue.put(self._build_gene_batch(buffer))
                    self._batch_counter += 1
                self._batcher_done = True
                break

            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                await self.batches_queue.put(self._build_gene_batch(buffer))
                self._batch_counter += 1
                buffer = []

    def _create_final_batch(self) -> Batch:
        """Creates the final batch that prepares the overall outputs after all gene comparison batches are done."""
        final_task = PrepareGeneCompareRunOutputs(
            id="prepare_gene_outputs",
            inputs={"output-dir": StringInput("Outputs")},
            expected_outputs={},
            engine=self.container_engine,
        )
        batch_kwargs = {
            "tasks": [final_task],
            "id": "Outputs",
            "run_dir": self.run_dir,
            "expected_outputs": [BatchFileOutput("all_gene_comparisons.parquet")],
        }
        if self.batch_type == "slurm":
            batch_kwargs["slurm_config"] = self.slurm_config
        final_batch = self.final_batch_factory(**batch_kwargs)
        # Attach a back-reference to this runner on the final batch (preserved
        # from original behavior; presumably consumed by the batch — confirm).
        final_batch._runner_obj = self
        return final_batch

GeneCompareTaskGenerator

Bases: TaskGenerator

This TaskGenerator generates FastGeneCompareTask objects from a polars LazyFrame. Each task compares two profiles using the compare_genes functionality in the zipstrain.compare module.

Parameters:

Name Type Description Default
data LazyFrame

Polars LazyFrame containing the data for generating tasks.

required
yield_size int

Number of tasks to yield at a time.

required
comp_config GenomeComparisonConfig

Configuration for genome comparison.

required
ani_method str

ANI calculation method to use. Default is "popani".

'popani'
duckdb_threads int | None

Optional DuckDB thread cap (for example 8).

None
compare_engine str

Compare engine passed to single compare tasks ("polars" or "duckdb").

'polars'
Source code in zipstrain/src/zipstrain/task_manager.py
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
class GeneCompareTaskGenerator(TaskGenerator):
    """This TaskGenerator generates FastGeneCompareTask objects from a polars LazyFrame. Each task compares two profiles using compare_genes functionality in
    zipstrain.compare module.

    Args:
        data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
        yield_size (int): Number of tasks to yield at a time.
        container_engine (Engine): Container engine used to wrap task commands.
        comp_config (database.GeneComparisonConfig): Configuration for gene comparison.
        ani_method (str): ANI calculation method to use. Default is "popani".
        duckdb_memory_limit (str | None): Optional DuckDB memory limit (for example "2GB").
        duckdb_threads (int | None): Optional DuckDB thread cap (for example 8).
        compare_engine (str): Compare engine passed to single compare tasks ("polars" or "duckdb").

    Raises:
        ValueError: If data is not a polars LazyFrame or compare_engine is invalid.
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        comp_config: database.GeneComparisonConfig,
        ani_method: str = "popani",
        duckdb_memory_limit: str | None = None,
        duckdb_threads: int | None = None,
        compare_engine: str = "polars",
    ) -> None:
        super().__init__(data, yield_size)
        self.comp_config = comp_config
        self.engine = container_engine
        self.ani_method = ani_method
        self.duckdb_memory_limit = duckdb_memory_limit
        self.duckdb_threads = duckdb_threads
        self.compare_engine = compare_engine
        # isinstance is the idiomatic type check (original compared type() identity).
        if not isinstance(self.data, pl.LazyFrame):
            raise ValueError("data must be a polars LazyFrame.")
        if self.compare_engine not in {"polars", "duckdb"}:
            raise ValueError("compare_engine must be one of {'polars', 'duckdb'}.")

    def get_total_tasks(self) -> int:
        """Returns total number of pairwise comparisons to be made."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yields lists of FastGeneCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
        while polars is collecting data to avoid blocking.
        """
        # These CLI fragments do not depend on the row, so build them once
        # instead of once per comparison pair.
        duckdb_memory_limit_arg = (
            f"--duckdb-memory-limit {self.duckdb_memory_limit}"
            if self.duckdb_memory_limit
            else ""
        )
        duckdb_threads_arg = (
            f"--duckdb-threads {self.duckdb_threads}"
            if self.duckdb_threads is not None
            else ""
        )
        compare_engine_arg = f"--engine {self.compare_engine}"
        for offset in range(0, self._total_tasks, self.yield_size):
            # collect_async keeps the event loop responsive during collection.
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                pair_id = row["sample_name_1"] + "_" + row["sample_name_2"]
                inputs = {
                    "mpile_1_file": FileInput(row["profile_location_1"]),
                    "mpile_2_file": FileInput(row["profile_location_2"]),
                    "stb_file": FileInput(self.comp_config.stb_file_loc),
                    "min_cov": IntInput(self.comp_config.min_cov),
                    "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
                    "duckdb-memory-limit-arg": StringInput(duckdb_memory_limit_arg),
                    "duckdb-threads-arg": StringInput(duckdb_threads_arg),
                    "compare-engine-arg": StringInput(compare_engine_arg),
                    "ani-method": StringInput(self.ani_method),
                }
                expected_outputs = {
                    "output-file": FileOutput(pair_id + "_gene_comparison.parquet"),
                }
                task = FastGeneCompareTask(id=pair_id, inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                tasks.append(task)
            yield tasks
generate_tasks() async

Yields lists of FastGeneCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop while polars is collecting data to avoid blocking.

Source code in zipstrain/src/zipstrain/task_manager.py
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
async def generate_tasks(self) -> list[Task]:
    """Yields lists of FastGeneCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
    while polars is collecting data to avoid blocking.
    """
    # Walk the LazyFrame in yield_size-sized windows; each window becomes one list of tasks.
    for offset in range(0, self._total_tasks, self.yield_size):
        # collect_async keeps the event loop responsive while polars materializes the slice.
        batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
        tasks = []
        for row in batch_df.iter_rows(named=True):
            # Optional CLI fragments are rendered as empty strings when unset so the
            # command template stays valid without them.
            duckdb_memory_limit_arg = (
                f"--duckdb-memory-limit {self.duckdb_memory_limit}"
                if self.duckdb_memory_limit
                else ""
            )
            duckdb_threads_arg = (
                f"--duckdb-threads {self.duckdb_threads}"
                if self.duckdb_threads is not None
                else ""
            )
            compare_engine_arg = f"--engine {self.compare_engine}"
            # Keys correspond to <placeholders> in FastGeneCompareTask.TEMPLATE_CMD.
            inputs = {
            "mpile_1_file": FileInput(row["profile_location_1"]),
            "mpile_2_file": FileInput(row["profile_location_2"]),
            "stb_file": FileInput(self.comp_config.stb_file_loc),
            "min_cov": IntInput(self.comp_config.min_cov),
            "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
            "duckdb-memory-limit-arg": StringInput(duckdb_memory_limit_arg),
            "duckdb-threads-arg": StringInput(duckdb_threads_arg),
            "compare-engine-arg": StringInput(compare_engine_arg),
            "ani-method": StringInput(self.ani_method),
            }
            expected_outputs ={
            "output-file":  FileOutput(row["sample_name_1"]+"_"+row["sample_name_2"]+"_gene_comparison.parquet" ),
            }
            # Task id is the sample-pair name, matching the output file naming.
            task = FastGeneCompareTask(id=row["sample_name_1"]+"_"+row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
            tasks.append(task)
        yield tasks
get_total_tasks()

Returns total number of pairwise comparisons to be made.

Source code in zipstrain/src/zipstrain/task_manager.py
1870
1871
1872
def get_total_tasks(self) -> int:
    """Count the pairwise comparisons encoded in the source LazyFrame."""
    counts = self.data.select(size=pl.len()).collect(engine="streaming")
    return counts["size"][0]

Input

Bases: ABC

Abstract base class for task inputs. DO NOT INSTANTIATE DIRECTLY. Most commonly used Input types are provided but if you want to define a new one, subclass this and implement validate() and get_value().

Source code in zipstrain/src/zipstrain/task_manager.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
class Input(ABC):
    """Abstract base class for task inputs. DO NOT INSTANTIATE DIRECTLY.
    Most commonly used Input types are provided but if you want to define a new one,
    subclass this and implement validate() and get_value().
    """
    def __init__(self, value: str | int) -> None:
        # Validate eagerly so a bad input fails at construction time rather
        # than later when the task command is rendered.
        self.value = value
        self.validate()

    @abstractmethod
    def validate(self) -> None:
        """Raise if self.value is not acceptable for this input type."""
        ...

    @abstractmethod
    def get_value(self) -> str | int:
        """Return the value in the form used when rendering the task command."""
        ...

IntInput

Bases: Input

This is used when the input is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py
195
196
197
198
199
200
201
202
203
204
205
206
class IntInput(Input):
    """Input type for integer-valued task parameters.

    Note that ``bool`` is a subclass of ``int`` in Python, so it is
    rejected explicitly: otherwise ``IntInput(True)`` would validate and
    silently render as the string "True" in a command line.
    """
    def validate(self) -> None:
        """Validate that the input value is a plain integer (not a bool).

        Raises:
            ValueError: If the stored value is not an integer.
        """
        # isinstance(True, int) is True, so exclude bool first.
        if isinstance(self.value, bool) or not isinstance(self.value, int):
            raise ValueError(f"Input value {self.value!r} is not an integer.")

    def get_value(self) -> str:
        """
        Returns the integer value as a string, for command substitution.
        """
        return str(self.value)
get_value()

Returns the integer value as a string.

Source code in zipstrain/src/zipstrain/task_manager.py
202
203
204
205
206
def get_value(self) -> str:
    """Render the stored integer as a string."""
    return f"{self.value}"
validate()

Validate that the input value is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py
197
198
199
200
def validate(self) -> None:
    """Validate that the input value is an integer."""
    if isinstance(self.value, int):
        return
    raise ValueError(f"Input value {self.value!r} is not an integer.")

IntOutput

Bases: Output

This is used when the output is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py
290
291
292
293
294
295
296
297
298
299
class IntOutput(Output):
    """Output type for integer-valued task results."""

    def ready(self) -> bool:
        """Report whether an integer value has been produced.

        Raises:
            ValueError: If a value was set but it is not an integer.
        """
        if self._value is None:
            # No value produced yet.
            return False
        if not isinstance(self._value, int):
            raise ValueError(f"Output value for task {self.task.id} is not an integer.")
        return True
ready()

Check if the output value is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py
292
293
294
295
296
297
298
299
def ready(self) -> bool:
    """Report whether an integer output value has been produced.

    Raises:
        ValueError: If a value was set but it is not an integer.
    """
    if self._value is None:
        # No value produced yet.
        return False
    if not isinstance(self._value, int):
        raise ValueError(f"Output value for task {self.task.id} is not an integer.")
    return True

LocalBatch

Bases: Batch

Batch that runs tasks locally in a single shell script.

Source code in zipstrain/src/zipstrain/task_manager.py
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
class LocalBatch(Batch):
    """Batch that runs tasks locally in a single shell script.

    All pending tasks' commands are concatenated into one bash script,
    which is then executed as a single subprocess from the batch directory.
    """
    # Header of the generated script.
    TEMPLATE_CMD = "#!/bin/bash\n"

    def __init__(self, tasks, id, run_dir, expected_outputs) -> None:
        super().__init__(tasks, id, run_dir, expected_outputs)
        # set -euo pipefail: abort the script on the first failing command,
        # unset variable, or failed pipeline stage.
        self._script = self.TEMPLATE_CMD + "\nset -euo pipefail\n"
        self._proc: asyncio.subprocess.Process | None = None


    async def run(self) -> None:
        """This method runs all tasks in the batch locally by creating a shell script and executing it.

        Raises:
            RuntimeError: If the script is cancelled mid-run or exits non-zero.
            FileNotFoundError: If the script exits zero but an expected output is missing.
        """
        # NOTE(review): task statuses below compare against Status.SUCCESS.value,
        # while this compares against the enum itself — confirm self.status
        # holds the enum here, not its string value.
        if self.status != Status.SUCCESS:

            self.batch_dir.mkdir(parents=True, exist_ok=True)

            await self._set_status(Status.RUNNING.value, "batch execution started")

            script_path = self.batch_dir / f"{self.id}.sh" # Path to the shell script for the batch
            script = self._script # Initialize the script content

            # Only re-run tasks that have not already succeeded.
            for task in self.tasks:
                if task.status != Status.SUCCESS.value:
                    if task.task_dir.exists():
                        shutil.rmtree(task.task_dir) # Because it must have failed and we don't want those remnants
                    task.task_dir.mkdir(parents=True)  # Create task directory
                    await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)
                    script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"



            await write_file(script_path, script, self.file_semaphore)

            # Run the script from the batch directory so relative paths
            # created by task commands land next to the batch artifacts.
            self._proc = await asyncio.create_subprocess_exec(
                "bash", f"{self.id}.sh",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=self.batch_dir,
            )
            out_bytes, err_bytes = b"", b""

            try:
                out_bytes, err_bytes = await self._proc.communicate()

            except asyncio.CancelledError:
                # NOTE(review): err_bytes is still b"" here because
                # communicate() raised before assigning, so the .err file
                # written below is empty — confirm this is intended.
                if self._proc and self._proc.returncode is None:
                    self._proc.terminate()
                    await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(), self.file_semaphore)
                    await self._set_status(Status.FAILED.value, "batch execution cancelled")

                raise RuntimeError("Batch script execution was cancelled.")


            await write_file(self.batch_dir / f"{self.id}.out", out_bytes.decode(), self.file_semaphore)
            await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(), self.file_semaphore)
            await self._collect_task_status()

            if self._proc.returncode != 0:
                error=err_bytes.decode()
                await self._set_status(Status.FAILED.value, f"runtime error: {error.strip()}")
                raise RuntimeError(f"Batch {self.id} hit the following error at runtime:\n{error}")

            # Success requires both a zero exit code and all expected outputs present.
            if self._proc.returncode == 0 and self.outputs_ready():
                self.cleanup()
                await self._set_status(Status.SUCCESS.value, "batch outputs validated", persist_status=Status.DONE.value)

            else:
                await self._set_status(Status.FAILED.value, "missing expected outputs")
                raise FileNotFoundError(f"Batch {self.id} is done but at least one expected output is missing.")
        else:
            await self._append_batch_log("status_note", "batch already marked success; skipping run")




    def _parse_job_id(self, sbatch_output):
        # Local batches have no scheduler job id; defer to the base implementation.
        return super()._parse_job_id(sbatch_output)


    async def cancel(self) -> None:
        """Cancels the local batch by terminating the subprocess if it's running."""
        if self._proc and self._proc.returncode is None:
            self._proc.terminate()
            try:
                await asyncio.wait_for(self._proc.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # SIGTERM was ignored within 5s; force-kill and reap.
                self._proc.kill()
                await self._proc.wait()
            await self._set_status(Status.FAILED.value, Messages.CANCELLED_BY_USER.value)
cancel() async

Cancels the local batch by terminating the subprocess if it's running.

Source code in zipstrain/src/zipstrain/task_manager.py
892
893
894
895
896
897
898
899
900
901
async def cancel(self) -> None:
    """Cancels the local batch by terminating the subprocess if it's running.

    Sends SIGTERM first, escalates to SIGKILL if the process has not exited
    within 5 seconds, then marks the batch as failed.
    """
    if self._proc and self._proc.returncode is None:
        self._proc.terminate()
        try:
            await asyncio.wait_for(self._proc.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            # SIGTERM was ignored within 5s; force-kill and reap.
            self._proc.kill()
            await self._proc.wait()
        await self._set_status(Status.FAILED.value, Messages.CANCELLED_BY_USER.value)
run() async

This method runs all tasks in the batch locally by creating a shell script and executing it.

Source code in zipstrain/src/zipstrain/task_manager.py
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
async def run(self) -> None:
    """This method runs all tasks in the batch locally by creating a shell script and executing it.

    Raises:
        RuntimeError: If the script is cancelled mid-run or exits non-zero.
        FileNotFoundError: If the script exits zero but an expected output is missing.
    """
    # NOTE(review): task statuses below compare against Status.SUCCESS.value,
    # while this compares against the enum itself — confirm self.status holds
    # the enum here, not its string value.
    if self.status != Status.SUCCESS:

        self.batch_dir.mkdir(parents=True, exist_ok=True)

        await self._set_status(Status.RUNNING.value, "batch execution started")

        script_path = self.batch_dir / f"{self.id}.sh" # Path to the shell script for the batch
        script = self._script # Initialize the script content

        # Only re-run tasks that have not already succeeded.
        for task in self.tasks:
            if task.status != Status.SUCCESS.value:
                if task.task_dir.exists():
                    shutil.rmtree(task.task_dir) # Because it must have failed and we don't want those remnants
                task.task_dir.mkdir(parents=True)  # Create task directory
                await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)
                script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"



        await write_file(script_path, script, self.file_semaphore)

        # Run the script from the batch directory so relative paths created
        # by task commands land next to the batch artifacts.
        self._proc = await asyncio.create_subprocess_exec(
            "bash", f"{self.id}.sh",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=self.batch_dir,
        )
        out_bytes, err_bytes = b"", b""

        try:
            out_bytes, err_bytes = await self._proc.communicate()

        except asyncio.CancelledError:
            # NOTE(review): err_bytes is still b"" here because communicate()
            # raised before assigning, so the .err file written below is
            # empty — confirm this is intended.
            if self._proc and self._proc.returncode is None:
                self._proc.terminate()
                await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(), self.file_semaphore)
                await self._set_status(Status.FAILED.value, "batch execution cancelled")

            raise RuntimeError("Batch script execution was cancelled.")


        await write_file(self.batch_dir / f"{self.id}.out", out_bytes.decode(), self.file_semaphore)
        await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(), self.file_semaphore)
        await self._collect_task_status()

        if self._proc.returncode != 0:
            error=err_bytes.decode()
            await self._set_status(Status.FAILED.value, f"runtime error: {error.strip()}")
            raise RuntimeError(f"Batch {self.id} hit the following error at runtime:\n{error}")

        # Success requires both a zero exit code and all expected outputs present.
        if self._proc.returncode == 0 and self.outputs_ready():
            self.cleanup()
            await self._set_status(Status.SUCCESS.value, "batch outputs validated", persist_status=Status.DONE.value)

        else:
            await self._set_status(Status.FAILED.value, "missing expected outputs")
            raise FileNotFoundError(f"Batch {self.id} is done but at least one expected output is missing.")
    else:
        await self._append_batch_log("status_note", "batch already marked success; skipping run")

Messages

Bases: StrEnum

Enumeration of common messages used in task and batch management.

Source code in zipstrain/src/zipstrain/task_manager.py
149
150
151
class Messages(StrEnum):
    """Enumeration of common messages used in task and batch management."""
    CANCELLED_BY_USER = "Task was cancelled by a signal from the user."

Output

Bases: ABC

Abstract base class for task outputs. DO NOT INSTANTIATE DIRECTLY. Most commonly used Output types are provided but if you want to define a new one, subclass this and implement ready(). This method is used to check if the output is ready/valid after task completion.

Source code in zipstrain/src/zipstrain/task_manager.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
class Output(ABC):
    """Base class for all task outputs.

    Do not instantiate this class directly. Concrete output types must
    implement ``ready()``, which reports whether the output is present and
    valid once the producing task has completed.
    """

    def __init__(self) -> None:
        # Set by the producing task when it completes.
        self._value = None
        # Set when this output is registered to a task.
        self.task = None

    @property
    def value(self):
        """Read-only view of the value set by the producing task."""
        return self._value

    @abstractmethod
    def ready(self) -> bool:
        """Report whether the output is ready/valid after task completion."""

    def register_task(self, task: Task) -> None:
        """Attach the task that produces this output; rarely needs overriding.

        Args:
            task (Task): The task that produces this output.
        """
        self.task = task
register_task(task)

Registers the task that produces this output. In most cases, you won't need to override this.

Args: task (Task): The task that produces this output.

Source code in zipstrain/src/zipstrain/task_manager.py
227
228
229
230
231
232
233
def register_task(self, task: Task) -> None:
    """Attach *task* as the producer of this output.

    In most cases, subclasses will not need to override this.

    Args:
        task (Task): The task that produces this output.
    """
    self.task = task

PrepareCompareGenomeRunOutputs

Bases: Task

A Task that prepares the final output by merging all parquet files after all genome comparisons are done.

Source code in zipstrain/src/zipstrain/task_manager.py
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
class PrepareCompareGenomeRunOutputs(Task):
    """A Task that prepares the final output by merging all parquet files after all genome comparisons are done."""
    # Shell template executed by the batch; <angle-bracket> placeholders are
    # presumably substituted from the task inputs (matching other Task
    # templates in this module). It symlinks every per-batch merged parquet
    # into a scratch dir, merges them into one file, then removes the
    # scratch dir (symlinks only; original files are untouched).
    TEMPLATE_CMD="""
    mkdir -p <output-dir>/comps
    find "$(pwd)" -type f -name "Merged_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/comps/
    zipstrain utilities merge_parquet --input-dir <output-dir>/comps --output-file <output-dir>/all_comparisons.parquet
    rm -rf <output-dir>/comps
    """

    @property
    def pre_run(self) -> str:
        """Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs."""
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self._batch_obj._runner_obj.run_dir.absolute()}"
pre_run property

Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs.

PrepareGeneCompareRunOutputs

Bases: Task

A Task that prepares the final output by merging all gene comparison parquet files after all gene comparisons are done.

Source code in zipstrain/src/zipstrain/task_manager.py
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
class PrepareGeneCompareRunOutputs(Task):
    """A Task that prepares the final output by merging all gene comparison parquet files after all gene comparisons are done."""
    # Shell template executed by the batch; <angle-bracket> placeholders are
    # presumably substituted from the task inputs (matching other Task
    # templates in this module). It symlinks every per-batch merged gene
    # parquet into a scratch dir, merges them into one file, then removes
    # the scratch dir (symlinks only; original files are untouched).
    TEMPLATE_CMD="""
    mkdir -p <output-dir>/gene_comps
    find "$(pwd)" -type f -name "Merged_gene_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/gene_comps/
    zipstrain utilities merge_parquet --input-dir <output-dir>/gene_comps --output-file <output-dir>/all_gene_comparisons.parquet
    rm -rf <output-dir>/gene_comps
    """

    @property
    def pre_run(self) -> str:
        """Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs."""
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self._batch_obj._runner_obj.run_dir.absolute()}"
pre_run property

Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs.

ProfileBamTask

Bases: Task

A Task that generates a mpileup file and genome breadth file in parquet format for a given BAM file using the fast_profile profile_bam command. The inputs to this task includes:

- bam-file: The input BAM file to be profiled.

- bed-file: The BED file specifying the regions to profile.

- sample-name: The name of the sample being processed.

- gene-range-table: A BED file specifying the gene ranges for the sample.

- num-workers: The number of concurrent workers to use for processing.

- genome-length-file: A file containing the lengths of the genomes in the reference fasta.

- stb-file: The STB file used for profiling.

Parameters:

Name Type Description Default
id str

Unique identifier for the task.

required
inputs dict[str, Input]

Dictionary of input parameters for the task.

required
expected_outputs dict[str, Output]

Dictionary of expected outputs for the task.

required
engine Engine

Container engine to wrap the command.

required
Source code in zipstrain/src/zipstrain/task_manager.py
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
class ProfileBamTask(Task):
    """A Task that generates a mpileup file and genome breadth file in parquet format for a given BAM file using the fast_profile profile_bam command.
    The inputs to this task includes:

        - bam-file: The input BAM file to be profiled.

        - bed-file: The BED file specifying the regions to profile.

        - sample-name: The name of the sample being processed.

        - gene-range-table: A BED file specifying the gene ranges for the sample.

        - num-workers: The number of concurrent workers to use for processing.

        - genome-length-file: A file containing the lengths of the genomes in the reference fasta.

        - stb-file: The STB file used for profiling.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command."""

    # Shell template executed by the batch; <angle-bracket> placeholders are
    # substituted from the task inputs. Inputs are symlinked into the task
    # dir, the profiler is run, and outputs are renamed per sample name.
    # NOTE(review): the docstring lists `genome-length-file` (and generators
    # also supply `breadth-min-cov`), but neither placeholder appears in the
    # template below — confirm whether they are consumed elsewhere.
    # NOTE(review): `samtools index <bam-file>` indexes the original path,
    # while the profiler reads the `input.bam` symlink; verify the .bai is
    # found through the symlink, or index `input.bam` instead.
    TEMPLATE_CMD="""
    ln -s <bam-file> input.bam
    ln -s <bed-file> bed_file.bed
    ln -s <gene-range-table> gene-range-table.bed
    samtools index <bam-file>
    zipstrain profile profile-single --bam-file input.bam \
    --bed-file bed_file.bed \
    --gene-range-table gene-range-table.bed \
    --stb-file <stb-file> \
    --num-workers <num-workers> \
    --output-dir .
    mv input_profile.parquet <sample-name>.parquet
    mv input_genome_stats.parquet <sample-name>_genome_stats.parquet
    mv input_gene_stats.parquet <sample-name>_gene_stats.parquet
    """

ProfileRunner

Bases: Runner

Creates and schedules batches of ProfileBamTask tasks using either local or Slurm batches.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
task_generator TaskGenerator

An instance of TaskGenerator to produce tasks.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 1.0.

1.0
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
batch_type str

Type of batch to use ("local" or "slurm"). Default is "local".

'local'
slurm_config SlurmConfig | None

Configuration for Slurm batches if batch_type is "slurm". Default is None.

None
Source code in zipstrain/src/zipstrain/task_manager.py
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
class ProfileRunner(Runner):
    """
    Creates and schedules batches of ProfileBamTask tasks using either local or Slurm batches.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.

    Raises:
        ValueError: If batch_type is "slurm" but no slurm_config is provided.
    """
    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
    ) -> None:
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = SlurmBatch
        else:
            batch_factory = LocalBatch

        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=None,  # profiling runs have no final merge batch
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
        )

    def _build_batch(self, tasks):
        """Creates one batch of the configured type from *tasks*, allocating the next batch id.

        The Slurm configuration is passed through only for slurm batches,
        since local batches do not accept it.
        """
        batch_id = f"batch_{self._batch_counter}"
        self._batch_counter += 1
        kwargs = {
            "tasks": tasks,
            "id": batch_id,
            "run_dir": self.run_dir,
            "expected_outputs": [],
        }
        if self.batch_type == "slurm":
            kwargs["slurm_config"] = self.slurm_config
        return self.batch_factory(**kwargs)

    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue.

        A ``None`` task is the end-of-stream sentinel: remaining buffered
        tasks are flushed as a final (possibly short) batch, then ``None``
        is forwarded to batches_queue to signal downstream consumers.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                if buffer:
                    await self.batches_queue.put(self._build_batch(buffer))
                await self.batches_queue.put(None)
                self._batcher_done = True
                break
            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                await self.batches_queue.put(self._build_batch(buffer))
                # Rebind (don't clear) so the batch keeps its own task list.
                buffer = []

ProfileTaskGenerator

Bases: TaskGenerator

This TaskGenerator generates ProfileBamTask objects from a polars LazyFrame. Each task profiles a BAM file.

Source code in zipstrain/src/zipstrain/task_manager.py
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
class ProfileTaskGenerator(TaskGenerator):
    """This TaskGenerator generates ProfileBamTask objects from a polars LazyFrame. Each task profiles a BAM file.

    Args:
        data (pl.LazyFrame): One row per sample; must provide `bamfile` and `sample_name` columns.
        yield_size (int): Number of tasks to yield per batch.
        container_engine (Engine): Container engine used to wrap task commands.
        stb_file (str): Scaffold-to-genome mapping file.
        profile_bed_file (str): BED file specifying the regions to profile.
        gene_range_file (str): BED file specifying gene ranges.
        genome_length_file (str): File with genome lengths for the reference fasta.
        num_procs (int): Concurrent workers per profiling task. Default is 4.
        breadth_min_cov (int): Minimum coverage for breadth calculations. Default is 1.

    Raises:
        ValueError: If `data` is not a polars LazyFrame.
        FileNotFoundError: If any required input file does not exist.
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        stb_file: str,
        profile_bed_file: str,
        gene_range_file: str,
        genome_length_file: str,
        num_procs: int = 4,
        breadth_min_cov: int = 1,
    ) -> None:
        # Validate the frame type up front, before the base class consumes it.
        if not isinstance(data, pl.LazyFrame):
            raise ValueError("data must be a polars LazyFrame.")
        super().__init__(data, yield_size)
        self.stb_file = pathlib.Path(stb_file)
        self.profile_bed_file = pathlib.Path(profile_bed_file)
        self.gene_range_file = pathlib.Path(gene_range_file)
        self.genome_length_file = pathlib.Path(genome_length_file)
        self.num_procs = num_procs
        self.breadth_min_cov = breadth_min_cov
        self.engine = container_engine
        # Fail fast on missing reference files instead of at task runtime.
        for path_attr in (
            self.stb_file,
            self.profile_bed_file,
            self.gene_range_file,
            self.genome_length_file,
        ):
            if not path_attr.exists():
                raise FileNotFoundError(f"File {path_attr} does not exist.")

    def get_total_tasks(self) -> int:
        """Returns total number of profiles to be generated."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yields lists of ProfileBamTask objects based on the data in batches of yield_size. This method yields
        control back to the event loop while polars is collecting data to avoid blocking.
        """
        # NOTE(review): this is an async generator, so the `-> list[Task]`
        # annotation understates the type; AsyncIterator[list[Task]] would
        # be more accurate.
        for offset in range(0, self._total_tasks, self.yield_size):
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                # Inputs mirror the placeholders in ProfileBamTask.TEMPLATE_CMD.
                inputs = {
                    "bam-file": FileInput(row["bamfile"]),
                    "sample-name": StringInput(row["sample_name"]),
                    "stb-file": FileInput(self.stb_file),
                    "bed-file": FileInput(self.profile_bed_file),
                    "gene-range-table": FileInput(self.gene_range_file),
                    "genome-length-file": FileInput(self.genome_length_file),
                    "num-workers": IntInput(self.num_procs),
                    "breadth-min-cov": IntInput(self.breadth_min_cov),
                }
                expected_outputs = {
                    "profile": FileOutput(row["sample_name"] + ".parquet"),
                    "genome-stats": FileOutput(row["sample_name"] + "_genome_stats.parquet"),
                    "gene-stats": FileOutput(row["sample_name"] + "_gene_stats.parquet"),
                }
                tasks.append(ProfileBamTask(id=row["sample_name"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine))
            yield tasks
generate_tasks() async

Yields lists of ProfileBamTask objects based on the data in batches of yield_size. This method yields control back to the event loop while polars is collecting data to avoid blocking.

Source code in zipstrain/src/zipstrain/task_manager.py
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
async def generate_tasks(self) -> list[Task]:
    """Yields lists of ProfileBamTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
    while polars is collecting data to avoid blocking.
    """
    # Slice the LazyFrame window by window; collect_async keeps the event
    # loop responsive while polars materialises each batch.
    for offset in range(0, self._total_tasks, self.yield_size):
        batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
        tasks = []
        for row in batch_df.iter_rows(named=True):
            # One ProfileBamTask per sample row; input keys mirror the
            # placeholders in ProfileBamTask.TEMPLATE_CMD.
            inputs = {
            "bam-file": FileInput(row["bamfile"]),
            "sample-name": StringInput(row["sample_name"]),
            "stb-file": FileInput(self.stb_file),
            "bed-file": FileInput(self.profile_bed_file),
            "gene-range-table": FileInput(self.gene_range_file),
            "genome-length-file": FileInput(self.genome_length_file),
            "num-workers": IntInput(self.num_procs),
            "breadth-min-cov": IntInput(self.breadth_min_cov),
            }
            expected_outputs ={
            "profile":  FileOutput(row["sample_name"]+".parquet" ),
            "genome-stats": FileOutput(row["sample_name"]+"_genome_stats.parquet" ),
            "gene-stats": FileOutput(row["sample_name"]+"_gene_stats.parquet" ),
            }
            task = ProfileBamTask(id=row["sample_name"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
            tasks.append(task)
        yield tasks
get_total_tasks()

Returns total number of profiles to be generated.

Source code in zipstrain/src/zipstrain/task_manager.py
526
527
528
def get_total_tasks(self) -> int:
    """Returns total number of profiles to be generated."""
    row_count = self.data.select(size=pl.len()).collect(engine="streaming")
    return row_count["size"][0]

Runner

Bases: ABC

Base Runner class to manage task generation, batching, and execution.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
task_generator TaskGenerator

An instance of TaskGenerator to produce tasks.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
batch_factory Batch

The class that creates Batch instances. It should be a subclass of Batch or its subclasses.

required
final_batch_factory Batch

A callable that creates the final Batch instance.

required
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 1.0.

1.0
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
batch_type str

Type of batch to use ("local" or "slurm"). Default is "local".

'local'
slurm_config SlurmConfig | None

Configuration for Slurm batches if batch_type is "slurm". Default is None.

None
Source code in zipstrain/src/zipstrain/task_manager.py
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
class Runner(ABC):
    """Base Runner class to manage task generation, batching, and execution.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        batch_factory (Batch): The class that creates Batch instances. It should be a subclass of Batch or its subclasses.
        final_batch_factory (Batch): A callable that creates the final Batch instance.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.
        max_retries (int): Maximum number of run attempts per batch before it is
            considered permanently failed. Default is 3.

    """
    # Statuses that never change again; used to filter out finished batches.
    TERMINAL_BATCH_STATES = {Status.SUCCESS.value, Status.FAILED.value}
    def __init__(self,
                    run_dir: str | pathlib.Path,
                    task_generator: TaskGenerator,
                    container_engine: Engine,
                    batch_factory: Batch,
                    final_batch_factory: Batch,
                    max_concurrent_batches: int = 1,
                    poll_interval: float = 1.0,
                    tasks_per_batch: int = 10,
                    batch_type: str = "local",
                    slurm_config: SlurmConfig | None = None,
                    max_retries: int = 3,
                    ) -> None:
        self.run_dir = pathlib.Path(run_dir)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.task_generator = task_generator
        self.container_engine = container_engine
        self.max_concurrent_batches = max_concurrent_batches
        self.poll_interval = poll_interval
        self.tasks_per_batch = tasks_per_batch
        self.batch_type = batch_type
        self.slurm_config = slurm_config
        # Bounded queues provide backpressure between the task producer,
        # the batcher, and the batch workers.
        self.tasks_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches * tasks_per_batch)
        self.batches_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches)
        self._finished_batches_count = 0
        self._success_batches_count = 0
        self._produced_tasks_count = 0
        self._active_batches: list[Batch] = []
        self._batch_counter = 0
        self._batcher_done = False
        self._final_batch_created = False
        self.batch_factory = batch_factory
        self.final_batch_factory = final_batch_factory
        self._failed_batches_count = 0
        self.max_retries = max_retries
        self.total_expected_tasks = self.task_generator.get_total_tasks()
        # Ceiling division: number of batches required to cover all tasks.
        self.total_expected_batches = (self.total_expected_tasks + tasks_per_batch - 1) // tasks_per_batch
        self._shutdown_event = asyncio.Event()
        self._shutdown_initiated = False

    async def _refill_tasks(self):
        """Repeatedly call task_generator until it returns an empty list. This feeds tasks into the tasks_queue and waits for the queue to have space if it's full in an async manner."""
        async for tasks in self.task_generator.generate_tasks():
            for task in tasks:
                await self.tasks_queue.put(task)
                self._produced_tasks_count += 1
        # Sentinel: tells the batcher that no more tasks will be produced.
        await self.tasks_queue.put(None)

    @abstractmethod
    async def _batcher(self):
        """Consume tasks_queue and enqueue Batch objects on batches_queue; subclasses define the batching policy."""
        ...

    def _create_final_batch(self) -> Batch|None:
        """Hook for an optional final batch run after all regular batches finish.

        The base implementation returns None (no final batch); subclasses may
        override this to build one via ``final_batch_factory``.
        """
        return None


    async def _shutdown(self):
        """Cancel all active batches and signal shutdown."""
        # Guard against double invocation (e.g. SIGINT followed by SIGTERM).
        if self._shutdown_initiated:
            return
        self._shutdown_initiated = True
        console.print("[yellow]Shutdown requested. Cancelling active jobs...[/]")

        for batch in list(self._active_batches):
            try:
                await batch.cancel()
            except Exception as e:
                # Best-effort cancellation: report but keep cancelling the rest.
                console.print(f"[red]Error cancelling batch {batch.id}: {e}[/]")

        # Signal the main loop to stop
        self._shutdown_event.set()

    async def run(self):
        """
        Run the producer, batcher and worker coroutines and present a live UI while working.
        Runs the task generator to produce tasks, batches them using the batcher,
        and executes batches with up to [max_concurrent_batches] parallel workers.
        UI: displays an overall panel (produced/finished counts), active batch Progress bars,
        and system stats (CPU/RAM) using Rich Live to mirror the Runner presentation.

        """
        # Keep references to the background tasks so they are not
        # garbage-collected mid-run (the event loop holds only weak refs).
        background_tasks = [
            asyncio.create_task(self._batcher()),
            asyncio.create_task(self._refill_tasks()),
        ]
        semaphore = asyncio.Semaphore(self.max_concurrent_batches)
        # Caps concurrent file operations across all batches.
        file_semaphore = asyncio.Semaphore(20)
        async def run_batch(batch: Batch):
            """Run one batch with retries, then update counters and logs."""
            async with semaphore:
                while batch.status != Status.SUCCESS.value and batch.retry_count < self.max_retries:
                    current_attempt = batch.retry_count + 1
                    if current_attempt > 1:
                        await batch._append_batch_log("RETRY", f"starting attempt {current_attempt}")
                    try:
                        await batch.run()
                    except Exception as e:
                        await batch._append_batch_log(
                            "ERROR",
                            f"attempt {current_attempt} raised: {e}",
                        )
                        # Ensure a raised exception is reflected in the status.
                        if batch.status != Status.FAILED.value:
                            await batch._set_status(Status.FAILED.value, f"unhandled exception: {e}")
                    await batch.log_progress()
                    if batch.status == Status.SUCCESS.value:
                        break
                    else:
                        batch.retry_count += 1

                self._finished_batches_count += 1

                if batch.status == Status.SUCCESS.value:
                    self._success_batches_count += 1
                    await batch._append_batch_log(
                        "SUMMARY",
                        f"terminal=success retries_used={batch.retry_count}",
                    )

                elif batch.status == Status.FAILED.value:
                    self._failed_batches_count += 1
                    await batch._append_batch_log(
                        "SUMMARY",
                        f"terminal=failed retries_used={batch.retry_count}",
                    )
                else:
                    await batch._append_batch_log(
                        "SUMMARY",
                        f"terminal={batch.status} retries_used={batch.retry_count}",
                    )

                if batch in self._active_batches:
                    self._active_batches.remove(batch)

        # Rich progress objects

        overall_progress = Progress(
            TextColumn(f"[bold white]{type(self).__name__}[/]"),
            BarColumn(),
            TextColumn("• {task.fields[produced_tasks]}/{task.fields[total_expected_tasks]} tasks produced"),
            TextColumn("• {task.fields[finished_batches]}/{task.fields[total_expected_batches]} batches finished • {task.fields[failed_batches]} batches failed"),
            TimeElapsedColumn(),
            expand=True,
        )
        overall_task = overall_progress.add_task("overall", produced_tasks=0, total_expected_tasks=self.total_expected_tasks, finished_batches=0, total_expected_batches=self.total_expected_batches, failed_batches=0)

        batch_progress = Progress(
            TextColumn("[bold white]{task.fields[batch_id]}[/]"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TextColumn("• {task.fields[status]}"),
            TimeElapsedColumn(),
            expand=True,
        )

        batch_to_progress_id: dict[Batch, int] = {}
        batch_task_totals: dict[Batch, int] = {}

        body = Panel(Group(
            Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
            Panel(overall_progress, title="Overall Progress"),
            Panel(batch_progress, title="Active Batches", height=10),
            Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
        ), border_style="magenta")

        # Graceful shutdown on Ctrl-C / SIGTERM.
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._shutdown()))

        with Live(body, console=console, refresh_per_second=2) as live:
            while not self._shutdown_event.is_set():
                await self._update_statuses()
                # All regular work is done: optionally run a final batch, then exit.
                if self._batcher_done and self.batches_queue.empty() and len(self._active_batches) == 0:
                    if not self._final_batch_created:
                        final_batch = self._create_final_batch()
                        if final_batch is not None:
                            await self.batches_queue.put(final_batch)
                            self._final_batch_created = True
                        else:
                            self._final_batch_created = True
                            break
                    else:
                        break
                # Fill worker slots from the batches queue.
                while len(self._active_batches) < self.max_concurrent_batches and not self.batches_queue.empty():
                    batch = await self.batches_queue.get()
                    if batch is not None:
                        batch._set_file_semaphore(file_semaphore)
                        self._active_batches.append(batch)
                        asyncio.create_task(run_batch(batch))
                # Update overall progress fields
                overall_progress.update(overall_task, produced_tasks=self._produced_tasks_count, finished_batches=self._finished_batches_count, failed_batches=self._failed_batches_count)
                # Add newly queued batches into UI


                for batch in list(self._active_batches):
                    if batch not in batch_to_progress_id and batch.status not in self.TERMINAL_BATCH_STATES:
                        total = len(batch.tasks) if batch.tasks else 1
                        task_id = batch_progress.add_task("", total=total, batch_id=f"Batch {batch.id}", status=batch.status)
                        batch_to_progress_id[batch] = task_id
                        batch_task_totals[batch] = total
                # Remove finished batches from UI
                for batch, tid in list(batch_to_progress_id.items()):
                    if batch.status in self.TERMINAL_BATCH_STATES:
                        try:
                            batch_progress.remove_task(tid)
                        except Exception:
                            pass
                        del batch_to_progress_id[batch]
                        if batch in batch_task_totals:
                            del batch_task_totals[batch]
                # Update per-batch progress
                for batch, tid in batch_to_progress_id.items():
                    completed = sum(1 for t in batch.tasks if t.status in self.TERMINAL_BATCH_STATES)
                    total = batch_task_totals.get(batch, max(1, len(batch.tasks)))
                    batch_progress.update(tid, completed=completed, total=total, status=batch.status)
                # Update system panel
                body = Panel(Group(
                    Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
                    Panel(overall_progress, title="Overall Progress"),
                    Panel(batch_progress, title="Active Batches"),
                    Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
                ), border_style="magenta")
                live.update(body)
                await asyncio.sleep(self.poll_interval)

        # final UI summary
        console.clear()
        # Use the precomputed ceiling batch count; the previous expression
        # (get_total_tasks()/tasks_per_batch) was a float true-division.
        summary = Panel(
            f"[bold green]Run finished![/]\n\n{self._success_batches_count}/{self.total_expected_batches} batches succeeded.\n\nProduced tasks: {self._produced_tasks_count}\nElapsed: (see time in UI)",
            expand=True,
            title="Summary",
            border_style="green",
        )
        console.print(summary)

    async def _update_statuses(self):
        """Refresh status and progress logs of all non-terminal active batches concurrently."""
        active = [batch for batch in self._active_batches if batch.status not in self.TERMINAL_BATCH_STATES]
        await asyncio.gather(*[batch.update_status() for batch in active])
        await asyncio.gather(*[batch.log_progress() for batch in active])

    def _make_system_stats_panel(self):
        """helpers to create a system stats panel for the live UI."""
        def usage_bar(label: str, percent: float, color: str):
            # One-row Progress bar used as a gauge for a percentage value.
            p = Progress(
                TextColumn(f"[bold]{label}[/]"),
                BarColumn(bar_width=None, complete_style=color),
                TextColumn(f"{percent:.1f}%"),
                expand=True,
            )
            p.add_task("", total=100, completed=percent)
            return Panel(p, expand=True, width=30)

        cpu = psutil.cpu_percent(interval=None)
        ram = psutil.virtual_memory().percent
        cpu_panel = usage_bar("CPU", cpu, "cyan")
        ram_panel = usage_bar("RAM", ram, "magenta")
        return Columns([cpu_panel, ram_panel], expand=True, equal=True, align="center")
run() async

Run the producer, batcher and worker coroutines and present a live UI while working. Runs the task generator to produce tasks, batches them using the batcher, and executes batches with up to [max_concurrent_batches] parallel workers. UI: displays an overall panel (produced/finished counts), active batch Progress bars, and system stats (CPU/RAM) using Rich Live to mirror the Runner presentation.

Source code in zipstrain/src/zipstrain/task_manager.py
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
async def run(self):
    """
    Run the producer, batcher and worker coroutines and present a live UI while working.
    Runs the task generator to produce tasks, batches them using the batcher,
    and executes batches with up to [max_concurrent_batches] parallel workers.
    UI: displays an overall panel (produced/finished counts), active batch Progress bars,
    and system stats (CPU/RAM) using Rich Live to mirror the Runner presentation.

    """
    # NOTE(review): the task objects returned by create_task are not stored;
    # asyncio keeps only weak references, so they could be garbage-collected
    # mid-run — consider holding references. TODO confirm.
    asyncio.create_task(self._batcher())
    asyncio.create_task(self._refill_tasks())
    # Limits how many batches execute at once.
    semaphore = asyncio.Semaphore(self.max_concurrent_batches)
    # Caps concurrent file operations across all batches.
    file_semaphore = asyncio.Semaphore(20)
    async def run_batch(batch: Batch):
        # Runs one batch with up to max_retries attempts, then updates counters.
        async with semaphore:
            while batch.status != Status.SUCCESS.value and batch.retry_count < self.max_retries:
                current_attempt = batch.retry_count + 1
                if current_attempt > 1:
                    await batch._append_batch_log("RETRY", f"starting attempt {current_attempt}")
                try:
                    await batch.run()
                except Exception as e:
                    await batch._append_batch_log(
                        "ERROR",
                        f"attempt {current_attempt} raised: {e}",
                    )
                    # Make sure an unhandled exception is reflected in the status.
                    if batch.status != Status.FAILED.value:
                        await batch._set_status(Status.FAILED.value, f"unhandled exception: {e}")
                await batch.log_progress()
                if batch.status == Status.SUCCESS.value:
                    break
                else:
                    batch.retry_count += 1

            self._finished_batches_count += 1

            if batch.status == Status.SUCCESS.value:
                self._success_batches_count += 1
                await batch._append_batch_log(
                    "SUMMARY",
                    f"terminal=success retries_used={batch.retry_count}",
                )

            elif batch.status == Status.FAILED.value:
                self._failed_batches_count += 1
                await batch._append_batch_log(
                    "SUMMARY",
                    f"terminal=failed retries_used={batch.retry_count}",
                )
            else:
                await batch._append_batch_log(
                    "SUMMARY",
                    f"terminal={batch.status} retries_used={batch.retry_count}",
                )

            if batch in self._active_batches:
                self._active_batches.remove(batch)

    # Rich progress objects

    overall_progress = Progress(
        TextColumn(f"[bold white]{type(self).__name__}[/]"),
        BarColumn(),
        TextColumn("• {task.fields[produced_tasks]}/{task.fields[total_expected_tasks]} tasks produced"),
        TextColumn("• {task.fields[finished_batches]}/{task.fields[total_expected_batches]} batches finished • {task.fields[failed_batches]} batches failed"),
        TimeElapsedColumn(),
        expand=True,
    )
    overall_task = overall_progress.add_task("overall", produced_tasks=0, total_expected_tasks=self.total_expected_tasks, finished_batches=0, total_expected_batches=self.total_expected_batches, failed_batches=0)

    batch_progress = Progress(
        TextColumn("[bold white]{task.fields[batch_id]}[/]"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total}"),
        TextColumn("• {task.fields[status]}"),
        TimeElapsedColumn(),
        expand=True,
    )

    batch_to_progress_id: dict[Batch, int] = {}
    batch_task_totals: dict[Batch, int] = {}

    body = Panel(Group(
        Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
        Panel(overall_progress, title="Overall Progress"),
        Panel(batch_progress, title="Active Batches", height=10),
        Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
    ), border_style="magenta")

    # Graceful shutdown on Ctrl-C / SIGTERM.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._shutdown()))

    with Live(body, console=console, refresh_per_second=2) as live:
        while not self._shutdown_event.is_set():
            await self._update_statuses()
            # All regular work done: optionally run a final batch, then exit.
            if self._batcher_done and self.batches_queue.empty() and len(self._active_batches) == 0:
                if not self._final_batch_created:
                    final_batch = self._create_final_batch()
                    if final_batch is not None:
                        await self.batches_queue.put(final_batch)
                        self._final_batch_created = True
                    else:
                        self._final_batch_created = True
                        break
                else:
                    break
            # Fill free worker slots from the batches queue.
            while len(self._active_batches) < self.max_concurrent_batches and not self.batches_queue.empty():
                batch = await self.batches_queue.get()
                if batch is not None:
                    batch._set_file_semaphore(file_semaphore)
                    self._active_batches.append(batch)
                    asyncio.create_task(run_batch(batch))
            # Update overall progress fields
            overall_progress.update(overall_task, produced_tasks=self._produced_tasks_count, finished_batches=self._finished_batches_count, failed_batches=self._failed_batches_count)
            # Add newly queued batches into UI


            for batch in list(self._active_batches):
                if batch not in batch_to_progress_id and batch.status not in self.TERMINAL_BATCH_STATES:
                    total = len(batch.tasks) if batch.tasks else 1
                    task_id = batch_progress.add_task("", total=total, batch_id=f"Batch {batch.id}", status=batch.status)
                    batch_to_progress_id[batch] = task_id
                    batch_task_totals[batch] = total
            # Remove finished batches from UI
            for batch, tid in list(batch_to_progress_id.items()):
                if batch.status in self.TERMINAL_BATCH_STATES:
                    try:
                        batch_progress.remove_task(tid)
                    except Exception:
                        pass
                    del batch_to_progress_id[batch]
                    if batch in batch_task_totals:
                        del batch_task_totals[batch]
            # Update per-batch progress
            for batch, tid in batch_to_progress_id.items():
                completed = sum(1 for t in batch.tasks if t.status in self.TERMINAL_BATCH_STATES)
                total = batch_task_totals.get(batch, max(1, len(batch.tasks)))
                batch_progress.update(tid, completed=completed, total=total, status=batch.status)
            # Update system panel
            body = Panel(Group(
                Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
                Panel(overall_progress, title="Overall Progress"),
                Panel(batch_progress, title="Active Batches"),
                Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
            ), border_style="magenta")
            live.update(body)
            await asyncio.sleep(self.poll_interval)

    # final UI summary
    console.clear()
    total_batches = self._batch_counter + (1 if self._final_batch_created and self.final_batch_factory is not None else 0)
    # NOTE(review): the expression below is a float true-division, not the
    # ceiling batch count, and total_batches above is unused — verify intent.
    summary = Panel(
        f"[bold green]Run finished![/]\n\n{self._success_batches_count}/{self.task_generator.get_total_tasks()/self.tasks_per_batch} batches succeeded.\n\nProduced tasks: {self._produced_tasks_count}\nElapsed: (see time in UI)",
        expand=True,
        title="Summary",
        border_style="green",
    )
    console.print(summary)

SlurmBatch

Bases: Batch

Batch that submits tasks to a Slurm job scheduler.

Parameters:

Name Type Description Default
tasks list[Task]

List of Task objects to be included in the batch.

required
id str

Unique identifier for the batch.

required
run_dir Path

Directory where the batch will be executed.

required
expected_outputs list[Output]

List of expected outputs for the batch.

required
slurm_config SlurmConfig

Configuration for Slurm job submission. Refer to SlurmConfig class for details.

required
Source code in zipstrain/src/zipstrain/task_manager.py
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
class SlurmBatch(Batch):
    """Batch that submits tasks to a Slurm job scheduler.

     Args:
        tasks (list[Task]): List of Task objects to be included in the batch.
        id (str): Unique identifier for the batch.
        run_dir (pathlib.Path): Directory where the batch will be executed.
        expected_outputs (list[Output]): List of expected outputs for the batch.
        slurm_config (SlurmConfig): Configuration for Slurm job submission. Refer to SlurmConfig class for details."""
    # Every generated batch script starts with this shebang.
    TEMPLATE_CMD = "#!/bin/bash\n"

    def __init__(self, tasks, id, run_dir, expected_outputs, slurm_config: SlurmConfig) -> None:
        super().__init__(tasks, id, run_dir, expected_outputs)
        self._check_slurm_works()
        self.slurm_config = slurm_config
        # Script header: shebang + #SBATCH directives + strict bash error handling.
        self._script = self.TEMPLATE_CMD + self.slurm_config.to_slurm_args() + "\nset -euo pipefail\n"
        # Slurm job id, populated after a successful sbatch submission.
        self._job_id = None

    def _check_slurm_works(self) -> None:
        """Checks if Slurm commands are available on the system.

        Raises:
            EnvironmentError: If ``sbatch`` or ``sacct`` cannot be executed.
        """
        try:
            subprocess.run(["sbatch", "--version"], capture_output=True, text=True, check=True)
            subprocess.run(["sacct", "--version"], capture_output=True, text=True, check=True)
        except Exception:
            raise EnvironmentError("Slurm does not seem to be available or configured properly on this system.")

    async def cancel(self) -> None:
        """Cancel a running or submitted Slurm job."""
        # Only call scancel if the job was actually submitted.
        if self._job_id:
            proc = await asyncio.create_subprocess_exec(
                "scancel", self._job_id,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.wait()

        # Record the terminal state regardless of whether scancel ran.
        await self._set_status(Status.FAILED.value, Messages.CANCELLED_BY_USER.value)

    async def run(self) -> None:
        """This method submits the batch to Slurm by creating a batch script and using sbatch command. It also monitors the job status until completion.
        This method is unavoidably different from LocalBatch.run() because of the nature of Slurm job submission.
        """

        # Compare against the status *string*; comparing to the enum member
        # itself was always unequal, so already-successful batches were
        # needlessly resubmitted.
        if self.status != Status.SUCCESS.value:
            self.batch_dir.mkdir(parents=True, exist_ok=True)
            await self._set_status(Status.RUNNING.value, "batch submission started")
            # create task directories and initialize .status if needed

            for task in self.tasks:
                if task.status != Status.SUCCESS.value:
                    if task.task_dir.exists():
                        shutil.rmtree(task.task_dir) # Because it must have failed and we don't want those remnants
                    task.task_dir.mkdir(parents=True)
                    await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)

            # write the batch script (all tasks included)

            batch_path = self.batch_dir / f"{self.id}.batch"
            script=self._script
            for task in self.tasks:
                if task.status != Status.SUCCESS.value:
                    script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"

            await write_file(batch_path, script, self.file_semaphore)

            # --parsable makes sbatch print just "jobid[;cluster]".
            proc = await asyncio.create_subprocess_exec(
                "sbatch","--parsable", batch_path.name,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(self.batch_dir),
            )
            out_bytes, err_bytes = await proc.communicate()
            out = out_bytes.decode().strip() if out_bytes else ""
            err = err_bytes.decode().strip() if err_bytes else ""
            if proc.returncode == 0:
                try:
                    self._job_id = self._parse_job_id(out)
                    await self._set_status(Status.SUBMITTED.value, f"job_id={self._job_id}")
                    await self._wait_to_finish()
                except Exception as e:
                    await self._set_status(Status.FAILED.value, f"slurm submission/monitoring error: {e}")
            else:
                # Surface sbatch's stderr so submission failures are diagnosable.
                await self._set_status(Status.FAILED.value, f"failed to submit slurm job: {err}")

            # NOTE(review): reads the private _status set by the monitoring
            # loop above; self.status may re-read persisted state — confirm.
            if self._status == Status.SUCCESS.value and self.outputs_ready():
                self.cleanup()
                await self._set_status(Status.SUCCESS.value, "batch outputs validated")
            else:
                await self._set_status(Status.FAILED.value, "missing expected outputs")

        else:
            # Batch previously succeeded; just re-validate the outputs.
            if self.status == Status.SUCCESS.value and self.outputs_ready():
                await self._set_status(Status.SUCCESS.value, "batch already successful")
            else:
                await self._set_status(Status.FAILED.value, "batch marked success but outputs missing")

    def _parse_job_id(self, sbatch_output: str) -> str:
        """Extract the numeric job id from sbatch's (--parsable) output.

        Raises:
            ValueError: If no digits are found in the output.
        """
        if match := re.search(r"(\d+)", sbatch_output):
            return match.group(1)
        else:
            raise ValueError("Could not parse job ID from sbatch output.")

    async def _wait_to_finish(self,sleep_duration:float=5.0):
        """Poll update_status() until the batch reaches a terminal state."""
        while self.status not in (Status.SUCCESS.value, Status.FAILED.value):
            await self.update_status()
            await asyncio.sleep(sleep_duration)

    async def update_status(self):
        """Query sacct for the job state and translate it into a batch status.

        Raises an exception (caught by run()) when the job failed or
        completed without producing the expected outputs.
        """
        if self._job_id is None:
            self._status=Status.NOT_STARTED.value
        else:
            await self._collect_task_status()
            out= await asyncio.create_subprocess_exec(
                "sacct", "-j", self._job_id, "--format=State", "--noheader","--allocations",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            out_bytes, _ = await out.communicate()
            if out_bytes:
                state = out_bytes.decode().strip()
                # sacct reports user cancellations as "CANCELLED by <uid>",
                # so prefix matching is required rather than exact equality.
                if state.startswith(("FAILED", "CANCELLED", "TIMEOUT")):
                    await self._set_status(Status.FAILED.value, f"slurm_state={state}")
                    raise Exception(f"Job {self._job_id} failed with state: {state}")

                elif state in ["RUNNING", "COMPLETING"]:
                    await self._set_status(Status.RUNNING.value, f"slurm_state={state}")

                elif state in ["COMPLETED"]:
                    if self.outputs_ready():
                        await self._set_status(Status.SUCCESS.value, f"slurm_state={state}", persist_status=Status.DONE.value)
                    else:
                        await self._set_status(Status.FAILED.value, "slurm completed but outputs missing")
                        raise Exception(f"Job {self._job_id} finished but at least one output is missing.")

                else:
                    # Any other state (PENDING, SUSPENDED, ...) counts as pending.
                    await self._set_status(Status.PENDING.value, f"slurm_state={state}")
cancel() async

Cancel a running or submitted Slurm job.

Source code in zipstrain/src/zipstrain/task_manager.py
930
931
932
933
934
935
936
937
938
939
940
async def cancel(self) -> None:
    """Cancel a running or submitted Slurm job.

    Runs ``scancel`` for the tracked job id (when a submission happened)
    and then marks the batch as failed with a cancelled-by-user message.
    """
    # Only call scancel if a job id exists, i.e. the batch was submitted.
    if self._job_id:
        proc = await asyncio.create_subprocess_exec(
            "scancel", self._job_id,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # Wait for scancel to exit; its output is intentionally discarded.
        await proc.wait()

    # Record the terminal state regardless of whether scancel ran.
    await self._set_status(Status.FAILED.value, Messages.CANCELLED_BY_USER.value)
run() async

This method submits the batch to Slurm by creating a batch script and using sbatch command. It also monitors the job status until completion. This method is unavoidably different from LocalBatch.run() because of the nature of Slurm job submission.

Source code in zipstrain/src/zipstrain/task_manager.py
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
async def run(self) -> None:
    """This method submits the batch to Slurm by creating a batch script and using sbatch command. It also monitors the job status until completion.
    This method is unavoidably different from LocalBatch.run() because of the nature of Slurm job submission.

    Flow: (re)initialize per-task directories, concatenate every unfinished
    task's pre_run/command/post_run into one sbatch script, submit it with
    ``sbatch --parsable``, then poll via _wait_to_finish() and finally
    validate the batch outputs.
    """

    # NOTE(review): this compares against the Status member while the nested
    # checks below compare against .value -- equivalent for a StrEnum, but
    # inconsistent in style.
    if self.status != Status.SUCCESS:
        self.batch_dir.mkdir(parents=True, exist_ok=True)
        await self._set_status(Status.RUNNING.value, "batch submission started")
        # create task directories and initialize .status if needed

        for task in self.tasks:
            if task.status != Status.SUCCESS.value:
                if task.task_dir.exists():
                    shutil.rmtree(task.task_dir) # Because it must have failed and we don't want those remnants
                task.task_dir.mkdir(parents=True)
                await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)

        # write the batch script (all tasks included)

        batch_path = self.batch_dir / f"{self.id}.batch"
        script=self._script
        for task in self.tasks:
            if task.status != Status.SUCCESS.value:
                script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"

        await write_file(batch_path, script, self.file_semaphore)

        # Submit from inside the batch directory; --parsable makes sbatch print
        # just the job id so _parse_job_id can read it from stdout.
        proc = await asyncio.create_subprocess_exec(
            "sbatch","--parsable", batch_path.name,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.batch_dir),
        )
        out_bytes, out_err = await proc.communicate()
        out = out_bytes.decode().strip() if out_bytes else ""
        if proc.returncode == 0:
            try:
                self._job_id = self._parse_job_id(out)
                await self._set_status(Status.SUBMITTED.value, f"job_id={self._job_id}")
                # Blocks until the Slurm job reaches a terminal state.
                await self._wait_to_finish()
            except Exception as e:
                await self._set_status(Status.FAILED.value, f"slurm submission/monitoring error: {e}")
        else:
            await self._set_status(Status.FAILED.value, "failed to submit slurm job")

        # Final validation: only a monitored SUCCESS with all outputs present
        # triggers cleanup; any earlier failure falls through to FAILED here.
        if self._status == Status.SUCCESS.value and self.outputs_ready():
            self.cleanup()
            await self._set_status(Status.SUCCESS.value, "batch outputs validated")
        else:
            await self._set_status(Status.FAILED.value, "missing expected outputs")

    else:
        # Batch was already successful on a previous run; re-validate outputs
        # before confirming, in case they were removed since.
        if self.status == Status.SUCCESS.value and self.outputs_ready():
            await self._set_status(Status.SUCCESS.value, "batch already successful")
        else:
            await self._set_status(Status.FAILED.value, "batch marked success but outputs missing")

SlurmConfig

Bases: BaseModel

Configuration model for Slurm batch jobs.

Attributes:

Name Type Description
time str

Time limit for the job in HH:MM:SS format.

tasks int

Number of tasks.

mem int

Memory in GB.

additional_params dict

Additional SLURM parameters as key-value pairs.

NOTE: Additional parameters for slurm should be provided in the additional_params dict in the form of {"param-name": "param-value"}, e.g., {"cpus-per-task": "4"} will result in the addition of "#SBATCH --cpus-per-task=4" to the sbatch script.

Source code in zipstrain/src/zipstrain/task_manager.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class SlurmConfig(BaseModel):
    """Configuration model for Slurm batch jobs.

    Attributes:
        time (str): Time limit for the job in HH:MM:SS format.
        tasks (int): Number of tasks.
        mem (int): Memory in GB.
        additional_params (dict): Additional SLURM parameters as key-value pairs.

    NOTE: Additional parameters for slurm should be provided in the additional_params dict in the form
    of {"param-name": "param-value"}, e.g., {"cpus-per-task": "4"} will result in the addition of
    "#SBATCH --cpus-per-task=4" to the sbatch script.

    """
    time: str = Field(description="Time limit for the job.")
    tasks: int = Field(default=1, description="Number of tasks.")
    mem: int = Field(default=4, description="Memory in GB.")
    additional_params: dict[str, str] = Field(default_factory=dict, description="Additional SLURM parameters as key-value pairs.")

    @field_validator("time")
    def validate_time(cls, v):
        """Validate time format HH:MM:SS (H..HHH allowed)."""
        # Hours may be 1-3 digits (e.g. "120:00:00"); minutes and seconds must be two digits.
        if not re.match(r"^\d{1,3}:\d{2}:\d{2}$", v):
            raise ValueError("Time must be in the format HH:MM:SS (H..HHH allowed)")
        return v

    def to_slurm_args(self) -> str:
        """Generates the slurm batch file header from the configuration object."""
        args = [
            f"#SBATCH --time={self.time}",
            f"#SBATCH --ntasks={self.tasks}",
            f"#SBATCH --mem={self.mem}G",
        ]
        # Each additional_params entry becomes one "#SBATCH --key=value" directive.
        for key, value in self.additional_params.items():
            args.append(f"#SBATCH --{key}={value}")
        return "\n".join(args)

    @classmethod
    def from_json(cls, json_path: str | pathlib.Path) -> SlurmConfig:
        """Load SlurmConfig from a JSON file.

        Raises:
            FileNotFoundError: If no file exists at json_path.
        """
        path = pathlib.Path(json_path)
        if not path.exists():
            raise FileNotFoundError(f"Slurm config file {json_path} does not exist.")
        return cls.model_validate_json(path.read_text())
from_json(json_path) classmethod

Load SlurmConfig from a JSON file.

Source code in zipstrain/src/zipstrain/task_manager.py
101
102
103
104
105
106
107
@classmethod
def from_json(cls, json_path: str | pathlib.Path) -> SlurmConfig:
    """Load SlurmConfig from a JSON file."""
    config_path = pathlib.Path(json_path)
    # Give a clear error rather than letting read_text raise on a missing file.
    if not config_path.exists():
        raise FileNotFoundError(f"Slurm config file {json_path} does not exist.")
    raw_json = config_path.read_text()
    return cls.model_validate_json(raw_json)
to_slurm_args()

Generates the slurm batch file header from the configuration object

Source code in zipstrain/src/zipstrain/task_manager.py
90
91
92
93
94
95
96
97
98
99
def to_slurm_args(self) -> str:
    """Build the #SBATCH header lines for an sbatch script from this configuration.

    Returns:
        str: Newline-joined "#SBATCH --..." directives covering time, ntasks,
        mem (in GB), plus one directive per additional_params entry.
    """
    # Fixed directives first, then any user-supplied key/value pairs.
    header_lines = [
        f"#SBATCH --time={self.time}",
        f"#SBATCH --ntasks={self.tasks}",
        f"#SBATCH --mem={self.mem}G",
    ]
    header_lines.extend(
        f"#SBATCH --{name}={setting}"
        for name, setting in self.additional_params.items()
    )
    return "\n".join(header_lines)
validate_time(v)

Validate time format HH:MM:SS (H..HHH allowed).

Source code in zipstrain/src/zipstrain/task_manager.py
83
84
85
86
87
88
@field_validator("time")
def validate_time(cls, v):
    """Validate time format HH:MM:SS (H..HHH allowed)."""
    # Hours: 1-3 digits; minutes and seconds: exactly two digits each.
    time_pattern = r"^\d{1,3}:\d{2}:\d{2}$"
    if re.match(time_pattern, v) is None:
        raise ValueError("Time must be in the format HH:MM:SS (H..HHH allowed)")
    return v

Status

Bases: StrEnum

Enumeration of possible task and batch statuses.

Source code in zipstrain/src/zipstrain/task_manager.py
138
139
140
141
142
143
144
145
146
147
class Status(StrEnum):
    """Enumeration of possible task and batch statuses.

    As a StrEnum, members compare equal to their string values (e.g.
    Status.SUCCESS == "success"), so they can be written to and parsed
    from .status files directly.
    """
    BATCH_NOT_ASSIGNED = "batch_not_assigned"  # Task created but not yet attached to a Batch
    NOT_STARTED = "not_started"  # Work has not begun yet
    RUNNING = "running"
    DONE = "done"      # Means a unit is finished running but the outputs are not validated
    FAILED = "failed"
    SUBMITTED = "submitted"  # Handed to the scheduler (e.g. sbatch), awaiting execution
    SUCCESS = "success" # Means a unit is done and the outputs exist
    PENDING = "pending"

StringInput

Bases: Input

This is used when the input is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
182
183
184
185
186
187
188
189
190
191
192
class StringInput(Input):
    """Input backed by a plain Python string."""

    def validate(self) -> None:
        """Validate that the input value is a string."""
        if isinstance(self.value, str):
            return
        raise ValueError(f"Input value {self.value!r} is not a string.")

    def get_value(self) -> str:
        """Returns the string value."""
        return str(self.value)
get_value()

Returns the string value.

Source code in zipstrain/src/zipstrain/task_manager.py
190
191
192
def get_value(self) -> str:
    """Returns the string value."""
    # str() is a no-op for actual strings but guarantees the return type.
    value_as_str = str(self.value)
    return value_as_str
validate()

Validate that the input value is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
185
186
187
188
def validate(self) -> None:
    """Validate that the input value is a string."""
    # Early-return on the happy path; anything non-string is rejected.
    if isinstance(self.value, str):
        return
    raise ValueError(f"Input value {self.value!r} is not a string.")

StringOutput

Bases: Output

This is used when the output is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
278
279
280
281
282
283
284
285
286
287
class StringOutput(Output):
    """This is used when the output is a string."""

    def ready(self) -> bool:
        """Check if the output value is a string."""
        current = self._value
        if current is None:
            # No value captured yet.
            return False
        if isinstance(current, str):
            return True
        # A non-None, non-string value indicates a programming error upstream.
        raise ValueError(f"Output value for task {self.task.id} is not a string.")
ready()

Check if the output value is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
280
281
282
283
284
285
286
287
def ready(self) -> bool:
    """Check if the output value is a string."""
    current = self._value
    if current is None:
        # Value not produced yet.
        return False
    if isinstance(current, str):
        return True
    # Present but not a string: signal the inconsistency to the caller.
    raise ValueError(f"Output value for task {self.task.id} is not a string.")

Task

Bases: ABC

Abstract base class for tasks. DO NOT INSTANTIATE DIRECTLY. Any new task type should subclass this and implement the TEMPLATE_CMD class attribute. Inputs and expected outputs are specified using angle-bracket placeholders. As an example, if a task has an input file called "input-file" and an expected output file called "output-file", the TEMPLATE_CMD could be something like: TEMPLATE_CMD = "some_command --input <input-file> --output <output-file>" the outputs and inputs will be mapped to the command when map_io() is called later in the runtime.

Source code in zipstrain/src/zipstrain/task_manager.py
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
class Task(ABC):
    """Abstract base class for tasks. DO NOT INSTANTIATE DIRECTLY. Any new task type should subclass this
    and implement the TEMPLATE_CMD class attribute. Inputs and expected outputs are specified using <>.
    As an example, if a task has an input file called "input-file" and an expected output file called "output-file",
    the TEMPLATE_CMD could be something like:
        TEMPLATE_CMD = "some_command --input <input-file> --output <output-file>"
    the outputs and inputs will be mapped to the command when map_io() is called later in the runtime.
    """
    # Subclasses must override this with their command template (placeholders in <>).
    TEMPLATE_CMD = ""

    def __init__(
        self,
        id: str,
        inputs: dict[str, Input | Output],
        expected_outputs: dict[str, Output] ,
        engine: Engine,
        batch_obj: Batch | None = None,
        file_semaphore: asyncio.Semaphore | None = None
    ) -> None:
        # id doubles as the task's directory name under the batch directory.
        self.id = id
        # inputs may contain Outputs of other tasks -- that is how chaining works.
        self.inputs = inputs
        self.expected_outputs = expected_outputs
        self._batch_obj = batch_obj
        self.engine = engine
        # Derived from any pre-existing task directory / .status file on disk.
        self._status = self._get_initial_status()
        # Optional semaphore bounding concurrent .status file reads/writes.
        self.file_semaphore = file_semaphore

    def map_io(self) -> None:
        """Maps inputs and expected outputs to the command template. Note that when this method is called,
        all of the inputs and outputs in the TEMPLATE_CMD must be defined in the inputs and expected_outputs dictionaries.
        However, this method is not called by the user directly. It is called by the Batch when the task is added to a batch.
        """
        cmd = self.TEMPLATE_CMD
        for key, value in self.inputs.items():
            cmd = cmd.replace(f"<{key}>", value.get_value())
        # if any placeholders remain, report them

        for handle, output in self.expected_outputs.items():
            cmd = cmd.replace(f"<{handle}>", str(output.expected_file.absolute()))
        remaining = re.findall(r"<\w+>", cmd)
        if remaining:
            raise ValueError(f"Not all inputs were mapped in task {self.id}. Remaining placeholders: {remaining}")
        self._command = cmd

    @property
    def batch_dir(self) -> pathlib.Path:
        """Returns the batch directory path. Raises an error if the task is not associated with any batch yet."""
        if self._batch_obj is None:
            raise ValueError(f"Task {self.id} is not associated with any batch yet.")
        return self._batch_obj.batch_dir

    @property
    def task_dir(self) -> pathlib.Path:
        """Returns the task directory path."""
        return self.batch_dir / self.id

    @property
    def command(self) -> str:
        """Returns the command to be executed, wrapped with the engine if applicable."""
        # Only FileInputs are handed to the engine, presumably for bind-mounting -- confirm.
        file_inputs = [v for v in self.inputs.values() if isinstance(v, FileInput)]
        return self.engine.wrap(self._command, file_inputs)

    @property
    def pre_run(self) -> str:
        """Does the necessary setup before running the task command. This should not be overridden by subclasses unless a task needs special setup like
        batch aggregation."""
        # Marks the task RUNNING in its .status file, then cds into the task dir.
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self.task_dir.absolute()}"

    @property
    def status(self) -> str:
        """Returns the current status of the task."""
        return self._status

    @property
    def post_run(self) -> str:
        """Does the necessary steps after running the task command. This should not be overridden by subclasses unless a task needs special teardown like
        batch aggregation."""
        # Returns to the batch dir and marks the task DONE (output validation happens later).
        return f"cd {self.batch_dir.absolute()} && echo {Status.DONE.value} > {self.task_dir.absolute()}/.status"

    async def get_status(self) -> str:
        """Asynchronously reads the task status from the .status file in the task directory.

        Raises:
            ValueError: If the file reports 'done' but expected outputs are missing or invalid.
        """
        status_path = self.task_dir / ".status"
        # read the status file if it exists

        if status_path.exists():
            raw = await read_file(status_path, self.file_semaphore)
            self._status = raw.strip()

            # if task reported 'done', check outputs to decide success/failure
            if self._status == Status.DONE.value:
                all_ready = True
                try:
                    for output in self.expected_outputs.values():
                        if not output.ready():
                            all_ready = False
                            break
                except Exception:
                    # ready() may raise for malformed outputs; treat as not ready.
                    all_ready = False

                # NOTE(review): a cleaned-up batch also counts as success, presumably
                # because cleanup removes per-task outputs after aggregation -- confirm.
                if all_ready or self._batch_obj._cleaned_up:
                    self._status = Status.SUCCESS.value

                else:
                    self._status = Status.FAILED.value
                    # NOTE(review): hardcodes the 'output-file' handle in the message;
                    # raises KeyError for tasks without that handle -- confirm intent.
                    raise ValueError(f"Task {self.id} reported done but outputs are not ready or invalid. {self.expected_outputs['output-file'].expected_file.absolute()}")

        return self._status

    def _get_initial_status(self) -> str:
        """Returns the initial status of the task based on the presence of the batch and task directories."""
        if self._batch_obj is None:
            return Status.BATCH_NOT_ASSIGNED.value
        if not self.task_dir.exists():
            return Status.NOT_STARTED.value
        status_file = self.task_dir / ".status"
        # NOTE(review): assumes .status exists whenever task_dir does; raises
        # FileNotFoundError otherwise -- confirm that invariant always holds.
        with open(status_file, mode="r") as f:
            status_as_written = f.read().strip()
        if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
            all_ready = True
            try:
                for output in self.expected_outputs.values():
                    if not output.ready():
                        all_ready = False
                        break
            except Exception:
                all_ready = False

            if all_ready:
                return Status.SUCCESS.value
            else:
                return Status.FAILED.value
        # NOTE(review): any other written status (e.g. running/failed) falls through
        # and implicitly returns None, which is not a Status value -- confirm whether
        # status_as_written should be returned here instead.
batch_dir property

Returns the batch directory path. Raises an error if the task is not associated with any batch yet.

command property

Returns the command to be executed, wrapped with the engine if applicable.

post_run property

Does the necessary steps after running the task command. This should not be overridden by subclasses unless a task needs special teardown like batch aggregation.

pre_run property

Does the necessary setup before running the task command. This should not be overridden by subclasses unless a task needs special setup like batch aggregation.

status property

Returns the current status of the task.

task_dir property

Returns the task directory path.

get_status() async

Asynchronously reads the task status from the .status file in the task directory.

Source code in zipstrain/src/zipstrain/task_manager.py
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
async def get_status(self) -> str:
    """Asynchronously reads the task status from the .status file in the task directory.

    Returns:
        str: The current status. If no .status file exists yet, the cached
        status is returned unchanged.

    Raises:
        ValueError: If the file reports 'done' but at least one expected
        output is missing or invalid.
    """
    status_path = self.task_dir / ".status"
    # read the status file if it exists

    if status_path.exists():
        raw = await read_file(status_path, self.file_semaphore)
        self._status = raw.strip()

        # if task reported 'done', check outputs to decide success/failure
        if self._status == Status.DONE.value:
            all_ready = True
            try:
                for output in self.expected_outputs.values():
                    if not output.ready():
                        all_ready = False
                        break
            except Exception:
                # ready() may raise on malformed outputs; treat as not ready.
                all_ready = False

            # NOTE(review): a cleaned-up batch also counts as success, presumably
            # because cleanup removes per-task outputs after aggregation -- confirm.
            if all_ready or self._batch_obj._cleaned_up:
                self._status = Status.SUCCESS.value

            else:
                self._status = Status.FAILED.value
                # NOTE(review): hardcodes the 'output-file' handle in the message;
                # raises KeyError for tasks without that handle -- confirm intent.
                raise ValueError(f"Task {self.id} reported done but outputs are not ready or invalid. {self.expected_outputs['output-file'].expected_file.absolute()}")

    return self._status
map_io()

Maps inputs and expected outputs to the command template. Note that when this method is called, all of the inputs and outputs in the TEMPLATE_CMD must be defined in the inputs and expected_outputs dictionaries. However, this method is not called by the user directly. It is called by the Batch when the task is added to a batch.

Source code in zipstrain/src/zipstrain/task_manager.py
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
def map_io(self) -> None:
    """Maps inputs and expected outputs to the command template. Note that when this method is called,
    all of the inputs and outputs in the TEMPLATE_CMD must be defined in the inputs and expected_outputs dictionaries.
    However, this method is not called by the user directly. It is called by the Batch when the task is added to a batch.
    """
    rendered = self.TEMPLATE_CMD
    # Substitute input placeholders first, then expected-output placeholders.
    for handle, inp in self.inputs.items():
        rendered = rendered.replace(f"<{handle}>", inp.get_value())
    for handle, out in self.expected_outputs.items():
        rendered = rendered.replace(f"<{handle}>", str(out.expected_file.absolute()))
    # Any placeholder left behind means a handle was never declared.
    remaining = re.findall(r"<\w+>", rendered)
    if remaining:
        raise ValueError(f"Not all inputs were mapped in task {self.id}. Remaining placeholders: {remaining}")
    self._command = rendered

TaskGenerator

Bases: ABC

Abstract base class for task generators. DO NOT INSTANTIATE DIRECTLY. A subclass of this class should provide an async generator method called generate_tasks() that yields lists of Task objects in an async manner. Some important concepts:

  • generate_tasks() is an async generator that yields lists of Task objects.

  • yield_size determines how many tasks are generated and yielded at a time.

  • get_total_tasks() returns the total number of tasks that can be generated.

Source code in zipstrain/src/zipstrain/task_manager.py
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
class TaskGenerator(ABC):
    """Abstract base class for task generators. DO NOT INSTANTIATE DIRECTLY. A subclass of this class 
    should provide an async generator method called generate_tasks() that yields lists of Task objects in an async manner.
    Some important concepts:

    - generate_tasks() is an async generator that yields lists of Task objects.

    - yield_size determines how many tasks are generated and yielded at a time.

    - get_total_tasks() returns the total number of tasks that can be generated.

    """
    def __init__(self,
                 data,
                 yield_size:int,

                 ):
        # data: the raw task input collection; its concrete type is defined by the subclass.
        self.data = data
        # How many tasks each generate_tasks() yield should contain.
        self.yield_size = yield_size
        # Computed once up front so it does not need to be recomputed per batch.
        self._total_tasks = self.get_total_tasks()

    @abstractmethod
    async def generate_tasks(self) -> list[Task]:
        # NOTE(review): the class docstring describes this as an async generator
        # yielding lists of Tasks, while the annotation says list[Task] -- confirm
        # the intended contract for subclass implementers.
        pass

    @abstractmethod
    def get_total_tasks(self) -> int:
        pass

get_cpu_usage()

Returns the current CPU usage percentage.

Source code in zipstrain/src/zipstrain/task_manager.py
1042
1043
1044
def get_cpu_usage():
    """Returns the current CPU usage percentage."""
    # Sample over a short 0.1 s window rather than since the previous call.
    usage_percent = psutil.cpu_percent(interval=0.1)
    return usage_percent

get_memory_usage()

Returns the current memory usage percentage.

Source code in zipstrain/src/zipstrain/task_manager.py
1046
1047
1048
def get_memory_usage():
    """Returns the current memory usage percentage."""
    memory_snapshot = psutil.virtual_memory()
    return memory_snapshot.percent

lazy_run_compares(run_dir, container_engine, comps_db=None, tasks_per_batch=10, max_concurrent_batches=1, poll_interval=5.0, execution_mode='local', slurm_config=None, duckdb_memory_limit=None, duckdb_threads=None, compare_engine='polars', calculate='all')

A helper function to quickly set up and run a CompareRunner with given parameters.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
comps_db GenomeComparisonDatabase | None

An instance of GenomeComparisonDatabase containing comparison data.

None
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 5.0.

5.0
execution_mode str

Execution mode, either "local" or "slurm". Default is "local".

'local'
duckdb_threads int | None

Optional DuckDB thread cap passed to compare tasks.

None
compare_engine str

Compare engine passed to single compare tasks ("polars" or "duckdb").

'polars'
Source code in zipstrain/src/zipstrain/task_manager.py
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
def lazy_run_compares(
    run_dir: str | pathlib.Path,
    container_engine: Engine,
    comps_db: database.GenomeComparisonDatabase|None = None,
    tasks_per_batch: int = 10,
    max_concurrent_batches: int = 1,
    poll_interval: float = 5.0,
    execution_mode: str = "local",
    slurm_config: SlurmConfig | None = None,
    duckdb_memory_limit: str | None = None,
    duckdb_threads: int | None = None,
    compare_engine: str = "polars",
    calculate: str = "all",
) -> None:
    """A helper function to quickly set up and run a CompareRunner with given parameters.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        container_engine (Engine): An instance of Engine to wrap task commands.
        comps_db (GenomeComparisonDatabase | None): An instance of GenomeComparisonDatabase containing comparison data. Effectively required; a ValueError is raised when omitted.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
        execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
        slurm_config (SlurmConfig | None): Slurm configuration forwarded to the runner; relevant when execution_mode is "slurm".
        duckdb_memory_limit (str | None): Optional DuckDB memory limit passed to compare tasks.
        duckdb_threads (int | None): Optional DuckDB thread cap passed to compare tasks.
        compare_engine (str): Compare engine passed to single compare tasks ("polars" or "duckdb").
        calculate (str): Forwarded to the task generator to select which results to calculate. Default is "all".

    Raises:
        ValueError: If comps_db is None, or execution_mode is neither "local" nor "slurm".
    """
    # comps_db defaults to None only for signature flexibility; fail fast with a
    # clear message instead of an AttributeError on comps_db below.
    if comps_db is None:
        raise ValueError("comps_db must be provided to lazy_run_compares.")
    # Validate the execution mode before any setup work is done.
    if execution_mode not in ("local", "slurm"):
        raise ValueError(f"Unknown execution mode: {execution_mode}")
    batch_type = execution_mode  # mode names map 1:1 onto batch types

    task_generator = CompareTaskGenerator(
        data=comps_db.to_complete_input_table(),
        yield_size=tasks_per_batch,
        container_engine=container_engine,
        comp_config=comps_db.config,
        duckdb_memory_limit=duckdb_memory_limit,
        duckdb_threads=duckdb_threads,
        compare_engine=compare_engine,
        calculate=calculate,
    )
    runner = CompareRunner(
        run_dir=pathlib.Path(run_dir),
        task_generator=task_generator,
        container_engine=container_engine,
        max_concurrent_batches=max_concurrent_batches,
        poll_interval=poll_interval,
        tasks_per_batch=tasks_per_batch,
        batch_type=batch_type,
        slurm_config=slurm_config,
    )
    # Blocks until the whole comparison run finishes.
    asyncio.run(runner.run())

lazy_run_gene_compares(run_dir, container_engine, comps_db=None, tasks_per_batch=10, max_concurrent_batches=1, poll_interval=5.0, execution_mode='local', slurm_config=None, ani_method='popani', duckdb_memory_limit=None, duckdb_threads=None, compare_engine='polars')

A helper function to quickly set up and run a GeneCompareRunner with given parameters.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
comps_db GeneComparisonDatabase | None

An instance of GeneComparisonDatabase containing comparison data.

None
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 5.0.

5.0
execution_mode str

Execution mode, either "local" or "slurm". Default is "local".

'local'
ani_method str

ANI calculation method to use. Default is "popani".

'popani'
duckdb_threads int | None

Optional DuckDB thread cap passed to compare tasks.

None
compare_engine str

Compare engine passed to single compare tasks ("polars" or "duckdb").

'polars'
Source code in zipstrain/src/zipstrain/task_manager.py
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
def lazy_run_gene_compares(
    run_dir: str | pathlib.Path,
    container_engine: Engine,
    comps_db: database.GeneComparisonDatabase | None = None,
    tasks_per_batch: int = 10,
    max_concurrent_batches: int = 1,
    poll_interval: float = 5.0,
    execution_mode: str = "local",
    slurm_config: SlurmConfig | None = None,
    ani_method: str = "popani",
    duckdb_memory_limit: str | None = None,
    duckdb_threads: int | None = None,
    compare_engine: str = "polars",
) -> None:
    """A helper function to quickly set up and run a GeneCompareRunner with given parameters.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        container_engine (Engine): An instance of Engine to wrap task commands.
        comps_db (GeneComparisonDatabase | None): An instance of GeneComparisonDatabase containing comparison data. Effectively required; a ValueError is raised when omitted.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
        execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
        slurm_config (SlurmConfig | None): Slurm configuration forwarded to the runner; relevant when execution_mode is "slurm".
        ani_method (str): ANI calculation method to use. Default is "popani".
        duckdb_memory_limit (str | None): Optional DuckDB memory limit passed to compare tasks.
        duckdb_threads (int | None): Optional DuckDB thread cap passed to compare tasks.
        compare_engine (str): Compare engine passed to single compare tasks ("polars" or "duckdb").

    Raises:
        ValueError: If comps_db is None, or execution_mode is neither "local" nor "slurm".
    """
    # Fail fast with clear errors instead of an AttributeError on comps_db below.
    if comps_db is None:
        raise ValueError("comps_db must be provided to lazy_run_gene_compares.")
    if execution_mode not in ("local", "slurm"):
        raise ValueError(f"Unknown execution mode: {execution_mode}")
    batch_type = execution_mode  # mode names map 1:1 onto batch types

    task_generator = GeneCompareTaskGenerator(
        data=comps_db.to_complete_input_table(),
        yield_size=tasks_per_batch,
        container_engine=container_engine,
        comp_config=comps_db.config,
        ani_method=ani_method,
        duckdb_memory_limit=duckdb_memory_limit,
        duckdb_threads=duckdb_threads,
        compare_engine=compare_engine,
    )
    runner = GeneCompareRunner(
        run_dir=pathlib.Path(run_dir),
        task_generator=task_generator,
        container_engine=container_engine,
        max_concurrent_batches=max_concurrent_batches,
        poll_interval=poll_interval,
        tasks_per_batch=tasks_per_batch,
        batch_type=batch_type,
        slurm_config=slurm_config,
    )
    # Blocks until the whole gene comparison run finishes.
    asyncio.run(runner.run())