API

ZipStrain provides a Python API for programmatic access, enabling integration into larger bioinformatics pipelines and workflows. The API also allows for downstream analyses and visualizations.


Database

zipstrain.database

This module provides classes and functions to manage profile and comparison databases for efficient data handling. The ProfileDatabase class manages profiles, while the GenomeComparisonDatabase class handles comparisons between profiles. See the documentation of each class for more details.
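
A minimal end-to-end sketch of how these classes fit together is shown below; the file paths and database IDs are placeholders, not files shipped with ZipStrain.

from zipstrain.database import (
    GenomeComparisonConfig,
    GenomeComparisonDatabase,
    ProfileDatabase,
)

# Load an existing profile database (a parquet file, see ProfileDatabase below).
profiles = ProfileDatabase(db_loc="profiles.parquet")

# Describe the comparison parameters (all paths and IDs here are placeholders).
config = GenomeComparisonConfig(
    reference_id="ref_v1",
    scope="all",
    min_cov=5,
    min_gene_compare_len=100,
    stb_file_loc="scaffold_to_genome.stb",
    null_model_loc="null_model.parquet",
)

# Start an empty genome comparison database tied to the profiles and config.
comparisons = GenomeComparisonDatabase(profile_db=profiles, config=config)
# comparisons.to_complete_input_table() lists the profile pairs still to compare.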

GeneComparisonConfig

Bases: BaseModel

Configuration for gene-level comparisons between profiles.

Attributes:

Name Type Description
scope str

The scope of the comparison in format "GENOME:GENE" (e.g., "all:gene1" compares gene1 across all genomes, "genome1:gene1" compares gene1 only in genome1 across samples).

null_model_loc str

Location of the null model parquet file.

stb_file_loc str

Location of the scaffold-to-genome mapping file.

min_cov int

Minimum coverage threshold for considering a position.

min_gene_compare_len int

Minimum gene length required for comparison.
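
A small usage sketch; the gene name and file paths below are hypothetical placeholders.

from zipstrain.database import GeneComparisonConfig

config = GeneComparisonConfig(
    scope="genome1:gene1",                   # compare gene1 only within genome1
    null_model_loc="null_model.parquet",     # placeholder path
    stb_file_loc="scaffold_to_genome.stb",   # placeholder path
    min_cov=5,
    min_gene_compare_len=100,
)

config.get_genome_scope()   # "genome1"
config.get_gene_scope()     # "gene1"

# A scope without a ':' separator fails validation, e.g.
# GeneComparisonConfig(scope="gene1", ...) raises a validation error.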

Source code in zipstrain/src/zipstrain/database.py
class GeneComparisonConfig(BaseModel):
    """
    Configuration for gene-level comparisons between profiles.

    Attributes:
        scope (str): The scope of the comparison in format "GENOME:GENE" (e.g., "all:gene1" compares gene1 across all genomes, "genome1:gene1" compares gene1 only in genome1 across samples).
        null_model_loc (str): Location of the null model parquet file.
        stb_file_loc (str): Location of the scaffold-to-genome mapping file.
        min_cov (int): Minimum coverage threshold for considering a position.
        min_gene_compare_len (int): Minimum gene length required for comparison.
    """
    model_config = ConfigDict(extra="forbid")
    scope: str = Field(description="Scope in format GENOME:GENE (e.g., 'all:gene1', 'genome1:gene1')")
    null_model_loc: str = Field(description="Location of the null model parquet file")
    stb_file_loc: str = Field(description="Location of the scaffold-to-genome mapping file")
    min_cov: int = Field(default=5, description="Minimum coverage threshold")
    min_gene_compare_len: int = Field(default=100, description="Minimum gene length for comparison")

    @field_validator("scope")
    @classmethod
    def validate_scope(cls, v: str) -> str:
        """Validate that scope follows GENOME:GENE format."""
        if ":" not in v:
            raise ValueError("Scope must be in format 'GENOME:GENE' (e.g., 'all:gene1' or 'genome1:gene1')")
        parts = v.split(":")
        if len(parts) != 2:
            raise ValueError("Scope must have exactly one ':' separator")
        genome_part, gene_part = parts
        if not genome_part or not gene_part:
            raise ValueError("Both genome and gene parts must be non-empty")
        return v

    def get_genome_scope(self) -> str:
        """Extract the genome part from the scope."""
        return self.scope.split(":")[0]

    def get_gene_scope(self) -> str:
        """Extract the gene part from the scope."""
        return self.scope.split(":")[1]
get_gene_scope()

Extract the gene part from the scope.

Source code in zipstrain/src/zipstrain/database.py
def get_gene_scope(self) -> str:
    """Extract the gene part from the scope."""
    return self.scope.split(":")[1]
get_genome_scope()

Extract the genome part from the scope.

Source code in zipstrain/src/zipstrain/database.py
def get_genome_scope(self) -> str:
    """Extract the genome part from the scope."""
    return self.scope.split(":")[0]
validate_scope(v) classmethod

Validate that scope follows GENOME:GENE format.

Source code in zipstrain/src/zipstrain/database.py
@field_validator("scope")
@classmethod
def validate_scope(cls, v: str) -> str:
    """Validate that scope follows GENOME:GENE format."""
    if ":" not in v:
        raise ValueError("Scope must be in format 'GENOME:GENE' (e.g., 'all:gene1' or 'genome1:gene1')")
    parts = v.split(":")
    if len(parts) != 2:
        raise ValueError("Scope must have exactly one ':' separator")
    genome_part, gene_part = parts
    if not genome_part or not gene_part:
        raise ValueError("Both genome and gene parts must be non-empty")
    return v

GeneComparisonDatabase

A database for managing gene-level comparisons between profiles.

This class handles pairwise gene comparisons between profiles within specified genome and gene scopes. It manages profile metadata, validates compatibility, and generates comparison tasks.

Attributes:

Name Type Description
profiles dict[str, ProfileItem]

A dictionary mapping profile names to ProfileItem objects.

config GeneComparisonConfig

Configuration for gene comparisons.
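
A minimal sketch of building a database and generating the comparison tasks. It assumes ProfileItem (used throughout this module) accepts the same fields that ProfileDatabase expects; all names and paths are placeholders.

from zipstrain.database import GeneComparisonConfig, GeneComparisonDatabase, ProfileItem

config = GeneComparisonConfig(
    scope="all:gene1",
    null_model_loc="null_model.parquet",     # placeholder path
    stb_file_loc="scaffold_to_genome.stb",   # placeholder path
)

db = GeneComparisonDatabase(config=config)
for name in ("sample_A", "sample_B"):
    db.add_profile(ProfileItem(
        profile_name=name,
        profile_location=f"profiles/{name}.parquet",     # placeholder paths
        scaffold_location=f"scaffolds/{name}.parquet",
        reference_db_id="ref_v1",   # must match across all added profiles
        gene_db_id="genes_v1",
    ))

tasks = db.to_complete_input_table()   # LazyFrame of all pairwise comparisons
print(len(db), tasks.collect())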

Source code in zipstrain/src/zipstrain/database.py
class GeneComparisonDatabase:
    """
    A database for managing gene-level comparisons between profiles.

    This class handles pairwise gene comparisons between profiles within specified genome and gene scopes.
    It manages profile metadata, validates compatibility, and generates comparison tasks.

    Attributes:
        profiles (dict[str, ProfileItem]): A dictionary mapping profile names to ProfileItem objects.
        config (GeneComparisonConfig): Configuration for gene comparisons.

    """

    def __init__(self, config: GeneComparisonConfig) -> None:
        """Initialize a GeneComparisonDatabase with the given configuration.

        Args:
            config (GeneComparisonConfig): Configuration for gene comparisons.
        """
        self.profiles: dict[str, ProfileItem] = {}
        self.config = config

    def add_profile(self, profile: ProfileItem) -> None:
        """Add a profile to the database.

        Args:
            profile (ProfileItem): Profile to add.

        Raises:
            ValueError: If a profile with the same name already exists or if reference_db_id doesn't match existing profiles.
        """
        if profile.profile_name in self.profiles:
            raise ValueError(f"Profile with name {profile.profile_name} already exists in the database.")

        # Check if this is the first profile or if reference_db_id matches
        if self.profiles:
            existing_ref_id = next(iter(self.profiles.values())).reference_db_id
            if profile.reference_db_id != existing_ref_id:
                raise ValueError(
                    f"Reference database ID mismatch: {profile.reference_db_id} != {existing_ref_id}. "
                    "All profiles in a GeneComparisonDatabase must use the same reference database."
                )

        self.profiles[profile.profile_name] = profile

    def remove_profile(self, profile_name: str) -> None:
        """Remove a profile from the database by name.

        Args:
            profile_name (str): Name of the profile to remove.

        Raises:
            KeyError: If profile_name doesn't exist in the database.
        """
        if profile_name not in self.profiles:
            raise KeyError(f"Profile {profile_name} not found in database.")
        del self.profiles[profile_name]

    def to_complete_input_table(self) -> pl.LazyFrame:
        """Generate a LazyFrame containing all pairwise comparisons for the specified gene scope.

        Returns:
            pl.LazyFrame: A LazyFrame with columns:
                - sample_name_1: First sample name
                - sample_name_2: Second sample name
                - profile_location_1: Location of first profile
                - profile_location_2: Location of second profile
                - scaffold_location_1: Location of first scaffold file
                - scaffold_location_2: Location of second scaffold file
        """
        if len(self.profiles) < 2:
            raise ValueError("At least 2 profiles are required for comparisons.")

        profile_list = list(self.profiles.values())
        comparisons = []

        # Generate all pairwise combinations
        for i in range(len(profile_list)):
            for j in range(i + 1, len(profile_list)):
                profile_1 = profile_list[i]
                profile_2 = profile_list[j]

                comparisons.append({
                    "sample_name_1": profile_1.profile_name,
                    "sample_name_2": profile_2.profile_name,
                    "profile_location_1": profile_1.profile_location,
                    "profile_location_2": profile_2.profile_location,
                    "scaffold_location_1": profile_1.scaffold_location,
                    "scaffold_location_2": profile_2.scaffold_location,
                })

        return pl.LazyFrame(comparisons)

    def save_obj(self, path: pathlib.Path) -> None:
        """Save the GeneComparisonDatabase to a JSON file.

        Args:
            path (pathlib.Path): Path where the JSON file will be saved.
        """
        data = {
            "config": self.config.model_dump(),
            "profiles": {name: profile.model_dump() for name, profile in self.profiles.items()}
        }
        with open(path, "w") as f:
            json.dump(data, f, indent=2)

    @classmethod
    def load_obj(cls, path: pathlib.Path) -> GeneComparisonDatabase:
        """Load a GeneComparisonDatabase from a JSON file.

        Args:
            path (pathlib.Path): Path to the JSON file.

        Returns:
            GeneComparisonDatabase: The loaded database object.
        """
        with open(path, "r") as f:
            data = json.load(f)

        config = GeneComparisonConfig(**data["config"])
        db = cls(config=config)

        for profile_data in data["profiles"].values():
            db.add_profile(ProfileItem(**profile_data))

        return db

    def __len__(self) -> int:
        """Return the number of profiles in the database."""
        return len(self.profiles)

    def __repr__(self) -> str:
        """Return a string representation of the database."""
        genome_scope = self.config.get_genome_scope()
        gene_scope = self.config.get_gene_scope()
        return (
            f"GeneComparisonDatabase(n_profiles={len(self.profiles)}, "
            f"genome_scope='{genome_scope}', gene_scope='{gene_scope}', "
            f"reference_db='{next(iter(self.profiles.values())).reference_db_id if self.profiles else 'N/A'}')"
        )
__init__(config)

Initialize a GeneComparisonDatabase with the given configuration.

Parameters:

Name Type Description Default
config GeneComparisonConfig

Configuration for gene comparisons.

required
Source code in zipstrain/src/zipstrain/database.py
def __init__(self, config: GeneComparisonConfig) -> None:
    """Initialize a GeneComparisonDatabase with the given configuration.

    Args:
        config (GeneComparisonConfig): Configuration for gene comparisons.
    """
    self.profiles: dict[str, ProfileItem] = {}
    self.config = config
__len__()

Return the number of profiles in the database.

Source code in zipstrain/src/zipstrain/database.py
def __len__(self) -> int:
    """Return the number of profiles in the database."""
    return len(self.profiles)
__repr__()

Return a string representation of the database.

Source code in zipstrain/src/zipstrain/database.py
def __repr__(self) -> str:
    """Return a string representation of the database."""
    genome_scope = self.config.get_genome_scope()
    gene_scope = self.config.get_gene_scope()
    return (
        f"GeneComparisonDatabase(n_profiles={len(self.profiles)}, "
        f"genome_scope='{genome_scope}', gene_scope='{gene_scope}', "
        f"reference_db='{next(iter(self.profiles.values())).reference_db_id if self.profiles else 'N/A'}')"
    )
add_profile(profile)

Add a profile to the database.

Parameters:

Name Type Description Default
profile ProfileItem

Profile to add.

required

Raises:

Type Description
ValueError

If a profile with the same name already exists or if reference_db_id doesn't match existing profiles.

Source code in zipstrain/src/zipstrain/database.py
def add_profile(self, profile: ProfileItem) -> None:
    """Add a profile to the database.

    Args:
        profile (ProfileItem): Profile to add.

    Raises:
        ValueError: If a profile with the same name already exists or if reference_db_id doesn't match existing profiles.
    """
    if profile.profile_name in self.profiles:
        raise ValueError(f"Profile with name {profile.profile_name} already exists in the database.")

    # Check if this is the first profile or if reference_db_id matches
    if self.profiles:
        existing_ref_id = next(iter(self.profiles.values())).reference_db_id
        if profile.reference_db_id != existing_ref_id:
            raise ValueError(
                f"Reference database ID mismatch: {profile.reference_db_id} != {existing_ref_id}. "
                "All profiles in a GeneComparisonDatabase must use the same reference database."
            )

    self.profiles[profile.profile_name] = profile
load_obj(path) classmethod

Load a GeneComparisonDatabase from a JSON file.

Parameters:

Name Type Description Default
path Path

Path to the JSON file.

required

Returns:

Name Type Description
GeneComparisonDatabase GeneComparisonDatabase

The loaded database object.

Source code in zipstrain/src/zipstrain/database.py
@classmethod
def load_obj(cls, path: pathlib.Path) -> GeneComparisonDatabase:
    """Load a GeneComparisonDatabase from a JSON file.

    Args:
        path (pathlib.Path): Path to the JSON file.

    Returns:
        GeneComparisonDatabase: The loaded database object.
    """
    with open(path, "r") as f:
        data = json.load(f)

    config = GeneComparisonConfig(**data["config"])
    db = cls(config=config)

    for profile_data in data["profiles"].values():
        db.add_profile(ProfileItem(**profile_data))

    return db
remove_profile(profile_name)

Remove a profile from the database by name.

Parameters:

Name Type Description Default
profile_name str

Name of the profile to remove.

required

Raises:

Type Description
KeyError

If profile_name doesn't exist in the database.

Source code in zipstrain/src/zipstrain/database.py
def remove_profile(self, profile_name: str) -> None:
    """Remove a profile from the database by name.

    Args:
        profile_name (str): Name of the profile to remove.

    Raises:
        KeyError: If profile_name doesn't exist in the database.
    """
    if profile_name not in self.profiles:
        raise KeyError(f"Profile {profile_name} not found in database.")
    del self.profiles[profile_name]
save_obj(path)

Save the GeneComparisonDatabase to a JSON file.

Parameters:

Name Type Description Default
path Path

Path where the JSON file will be saved.

required
Source code in zipstrain/src/zipstrain/database.py
def save_obj(self, path: pathlib.Path) -> None:
    """Save the GeneComparisonDatabase to a JSON file.

    Args:
        path (pathlib.Path): Path where the JSON file will be saved.
    """
    data = {
        "config": self.config.model_dump(),
        "profiles": {name: profile.model_dump() for name, profile in self.profiles.items()}
    }
    with open(path, "w") as f:
        json.dump(data, f, indent=2)
to_complete_input_table()

Generate a LazyFrame containing all pairwise comparisons for the specified gene scope.

Returns:

Type Description
LazyFrame

pl.LazyFrame: A LazyFrame with the columns:

  • sample_name_1: First sample name

  • sample_name_2: Second sample name

  • profile_location_1: Location of first profile

  • profile_location_2: Location of second profile

  • scaffold_location_1: Location of first scaffold file

  • scaffold_location_2: Location of second scaffold file

Source code in zipstrain/src/zipstrain/database.py
def to_complete_input_table(self) -> pl.LazyFrame:
    """Generate a LazyFrame containing all pairwise comparisons for the specified gene scope.

    Returns:
        pl.LazyFrame: A LazyFrame with columns:
            - sample_name_1: First sample name
            - sample_name_2: Second sample name
            - profile_location_1: Location of first profile
            - profile_location_2: Location of second profile
            - scaffold_location_1: Location of first scaffold file
            - scaffold_location_2: Location of second scaffold file
    """
    if len(self.profiles) < 2:
        raise ValueError("At least 2 profiles are required for comparisons.")

    profile_list = list(self.profiles.values())
    comparisons = []

    # Generate all pairwise combinations
    for i in range(len(profile_list)):
        for j in range(i + 1, len(profile_list)):
            profile_1 = profile_list[i]
            profile_2 = profile_list[j]

            comparisons.append({
                "sample_name_1": profile_1.profile_name,
                "sample_name_2": profile_2.profile_name,
                "profile_location_1": profile_1.profile_location,
                "profile_location_2": profile_2.profile_location,
                "scaffold_location_1": profile_1.scaffold_location,
                "scaffold_location_2": profile_2.scaffold_location,
            })

    return pl.LazyFrame(comparisons)

GenomeComparisonConfig

Bases: BaseModel

This class defines an object holding all options that describe the parameters used to compare profiles:

Attributes:

Name Type Description
gene_db_id str

The ID of the gene fasta database to use for the comparison. The file name itself is a suitable ID.

reference_id str

The ID of the reference fasta database to use for the comparison. The file name itself is a suitable ID.

scope str

The scope of the comparison: 'all' if all covered positions are desired, otherwise a comma-separated list of genome names.

min_cov int

Minimum coverage a base on the reference fasta must have in order to be compared.

null_model_p_value float

P-value above which a base call is counted as a sequencing error.

min_gene_compare_len int

Minimum length of a gene that needs to be covered at min_cov to be considered for gene similarity calculations

stb_file_loc str

The location of the scaffold to bin file.

null_model_loc str

The location of the null model file.
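
A short sketch of how scope compatibility behaves; genome names and paths are placeholders.

from zipstrain.database import GenomeComparisonConfig

shared = dict(
    reference_id="ref_v1",
    min_cov=5,
    min_gene_compare_len=100,
    stb_file_loc="scaffold_to_genome.stb",   # placeholder path
    null_model_loc="null_model.parquet",     # placeholder path
)

cfg_all = GenomeComparisonConfig(scope="all", **shared)
cfg_two = GenomeComparisonConfig(scope="genomeA,genomeB", **shared)

cfg_all.is_compatible(cfg_two)               # True: 'all' overlaps any scope
merged = cfg_all.get_maximal_scope_config(cfg_two)
merged.scope                                 # "genomeA,genomeB"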

Source code in zipstrain/src/zipstrain/database.py
class GenomeComparisonConfig(BaseModel):
    """
    This class defines an object holding all options that describe the
    parameters used to compare profiles.

    Attributes:
        gene_db_id (str): The ID of the gene fasta database to use for the comparison. The file name itself is a suitable ID.
        reference_id (str): The ID of the reference fasta database to use for the comparison. The file name itself is a suitable ID.
        scope (str): The scope of the comparison: 'all' if all covered positions are desired, otherwise a comma-separated list of genome names.
        min_cov (int): Minimum coverage a base on the reference fasta must have in order to be compared.
        null_model_p_value (float): P-value above which a base call is counted as a sequencing error.
        min_gene_compare_len (int): Minimum length of a gene that needs to be covered at min_cov to be considered for gene similarity calculations.
        stb_file_loc (str): The location of the scaffold to bin file.
        null_model_loc (str): The location of the null model file.
    """
    model_config = ConfigDict(extra="forbid")
    gene_db_id:str= Field(default="",description="An ID given to the gene fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with gene database IDs in the Profile Database.")
    reference_id:str= Field(description="An ID given to the reference fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with reference IDs in the Profile Database.")
    scope: str = Field(description="The scope of the comparison: 'all' if all covered positions are desired, otherwise a comma-separated list of genome names.")
    min_cov: int =Field(description="Minimum coverage a base on the reference fasta that must have in order to be compared.")
    min_gene_compare_len: int=Field(description="Minimum length of a gene that needs to be covered at min_cov to be considered for gene similarity calculations")
    null_model_p_value:float=Field(default=0.05,description="P_value above which a base call is counted as sequencing error")
    stb_file_loc:str=Field(description="The location of the scaffold to bin file.")
    null_model_loc:str=Field(description="The location of the null model file.")

    def is_compatible(self, other: GenomeComparisonConfig) -> bool:
        """
        Check if this comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope.
        Scope can be different as long as they are not disjoint. Also, all is compatible with any scope.
        Args:
            other (GenomeComparisonConfig): The other comparison configuration to check compatibility with.
        Returns:
            bool: True if the configurations are compatible, False otherwise.
        """
        attrs=self.__dict__
        for key in attrs:
            if key!="scope":
                if attrs[key] != getattr(other, key):
                    return False
        if other.scope != "all" and self.scope != "all":
            if (set(other.scope.split(",")).intersection(set(self.scope.split(","))) == set()):
                return False
        return True

    @classmethod
    def from_json(cls,json_file_dir:str)->GenomeComparisonConfig:
        """Create a GenomeComparisonConfig instance from a json file."""
        with open(json_file_dir, 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)

    def to_json(self,json_file_dir:str)->None:
        """Writes the the current object to a json file"""
        with open(json_file_dir,"w") as f:
            json.dump(self.__dict__,f)

    def to_dict(self)->dict:
        """Returns the dictionary representation of the current object"""
        return copy.copy(self.__dict__)


    def get_maximal_scope_config(self, other: GenomeComparisonConfig) -> GenomeComparisonConfig:
        """
        Get a new GenomeComparisonConfig object with the maximal scope that is compatible with the two configurations.
        Args:
            other (GenomeComparisonConfig): The other comparison configuration to get the maximal scope with.
        Returns:
            GenomeComparisonConfig: The new comparison configuration with the maximal scope.
        """
        if not self.is_compatible(other):
            raise ValueError("The two comparison configurations are not compatible.")

        new_scope=None
        if other.scope == "all" and self.scope == "all":
            new_scope="all"

        elif other.scope == "all":
            new_scope=self.scope.split(",")

        elif self.scope == "all":
            new_scope=other.scope.split(",")

        else:
            new_scope=list(set(self.scope.split(",")).intersection(set(other.scope.split(","))))
        curr_config_dict=self.to_dict()
        curr_config_dict["scope"]=new_scope if new_scope=="all" else ",".join(sorted(new_scope))
        return GenomeComparisonConfig(**curr_config_dict)
from_json(json_file_dir) classmethod

Create a GenomeComparisonConfig instance from a json file.

Source code in zipstrain/src/zipstrain/database.py
@classmethod
def from_json(cls,json_file_dir:str)->GenomeComparisonConfig:
    """Create a GenomeComparisonConfig instance from a json file."""
    with open(json_file_dir, 'r') as f:
        config_dict = json.load(f)
    return cls(**config_dict)
get_maximal_scope_config(other)

Get a new GenomeComparisonConfig object with the maximal scope that is compatible with the two configurations.

Parameters:

Name Type Description Default
other GenomeComparisonConfig

The other comparison configuration to get the maximal scope with.

required

Returns:

Name Type Description
GenomeComparisonConfig GenomeComparisonConfig

The new comparison configuration with the maximal scope.

Source code in zipstrain/src/zipstrain/database.py
def get_maximal_scope_config(self, other: GenomeComparisonConfig) -> GenomeComparisonConfig:
    """
    Get a new GenomeComparisonConfig object with the maximal scope that is compatible with the two configurations.
    Args:
        other (GenomeComparisonConfig): The other comparison configuration to get the maximal scope with.
    Returns:
        GenomeComparisonConfig: The new comparison configuration with the maximal scope.
    """
    if not self.is_compatible(other):
        raise ValueError("The two comparison configurations are not compatible.")

    new_scope=None
    if other.scope == "all" and self.scope == "all":
        new_scope="all"

    elif other.scope == "all":
        new_scope=self.scope.split(",")

    elif self.scope == "all":
        new_scope=other.scope.split(",")

    else:
        new_scope=list(set(self.scope.split(",")).intersection(set(other.scope.split(","))))
    curr_config_dict=self.to_dict()
    curr_config_dict["scope"]=new_scope if new_scope=="all" else ",".join(sorted(new_scope))
    return GenomeComparisonConfig(**curr_config_dict)
is_compatible(other)

Check if this comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope. Scope can differ as long as the two scopes are not disjoint, and 'all' is compatible with any scope.

Parameters:

Name Type Description Default
other GenomeComparisonConfig

The other comparison configuration to check compatibility with.

required

Returns:

Type Description
bool

True if the configurations are compatible, False otherwise.

Source code in zipstrain/src/zipstrain/database.py
def is_compatible(self, other: GenomeComparisonConfig) -> bool:
    """
    Check if this comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope.
    Scope can be different as long as they are not disjoint. Also, all is compatible with any scope.
    Args:
        other (GenomeComparisonConfig): The other comparison configuration to check compatibility with.
    Returns:
        bool: True if the configurations are compatible, False otherwise.
    """
    attrs=self.__dict__
    for key in attrs:
        if key!="scope":
            if attrs[key] != getattr(other, key):
                return False
    if other.scope != "all" and self.scope != "all":
        if (set(other.scope.split(",")).intersection(set(self.scope.split(","))) == set()):
            return False
    return True
to_dict()

Returns the dictionary representation of the current object

Source code in zipstrain/src/zipstrain/database.py
def to_dict(self)->dict:
    """Returns the dictionary representation of the current object"""
    return copy.copy(self.__dict__)
to_json(json_file_dir)

Writes the current object to a json file.

Source code in zipstrain/src/zipstrain/database.py
def to_json(self,json_file_dir:str)->None:
    """Writes the the current object to a json file"""
    with open(json_file_dir,"w") as f:
        json.dump(self.__dict__,f)

GenomeComparisonDatabase

GenomeComparisonDatabase object holds a reference to a comparison parquet file. The methods in this class provide functionality for working with the comparison data in an easy and efficient manner. The comparison parquet file is the result of running compare, optionally concatenating multiple compare parquet files from single comparisons. This parquet file must contain the following columns:

  • genome

  • total_positions

  • share_allele_pos

  • genome_pop_ani

  • max_consecutive_length

  • shared_genes_count

  • identical_gene_count

  • perc_id_genes

  • sample_1

  • sample_2

A GenomeComparisonDatabase object needs a GenomeComparisonConfig object to specify the parameters used for the comparison.

Parameters:

Name Type Description Default
profile_db ProfileDatabase

The profile database used for the comparison.

required
config GenomeComparisonConfig

The comparison configuration used for the comparison.

required
comp_db_loc str | None

The location of the comparison database parquet file. If None, an empty comparison database is created.

None
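
A minimal workflow sketch; the parquet and stb paths are placeholders for files produced by profiling and compare runs.

from zipstrain.database import (
    GenomeComparisonConfig,
    GenomeComparisonDatabase,
    ProfileDatabase,
)

profiles = ProfileDatabase(db_loc="profiles.parquet")        # placeholder path
config = GenomeComparisonConfig(
    reference_id="ref_v1",
    scope="all",
    min_cov=5,
    min_gene_compare_len=100,
    stb_file_loc="scaffold_to_genome.stb",                   # placeholder path
    null_model_loc="null_model.parquet",                     # placeholder path
)

comp_db = GenomeComparisonDatabase(
    profile_db=profiles,
    config=config,
    comp_db_loc="comparisons.parquet",    # omit to start with an empty database
)

if not comp_db.is_complete():
    # LazyFrame listing the profile pairs that still need to be compared.
    todo = comp_db.to_complete_input_table().collect()
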
Source code in zipstrain/src/zipstrain/database.py
class GenomeComparisonDatabase:
    """
    GenomeComparisonDatabase object holds a reference to a comparison parquet file. The methods in this class serve to provide
    functionality for working with the comparison data in an easy and efficient manner.
    The comparison parquet file is the result of running compare, optionally concatenating multiple compare parquet files from single comparisons.
    This parquet file must contain the following columns:

    - genome

    - total_positions

    - share_allele_pos

    - genome_pop_ani

    - max_consecutive_length

    - shared_genes_count

    - identical_gene_count

    - sample_1

    - sample_2

    A GenomeComparisonDatabase object needs a GenomeComparisonConfig object to specify the parameters used for the comparison.

    Args:
        profile_db (ProfileDatabase): The profile database used for the comparison.
        config (GenomeComparisonConfig): The comparison configuration used for the comparison.
        comp_db_loc (str|None): The location of the comparison database parquet file. If
            None, an empty comparison database is created.

    """
    COLUMN_NAMES = [
        "genome",
        "total_positions",
        "share_allele_pos",
        "genome_pop_ani",
        "max_consecutive_length",
        "shared_genes_count",
        "identical_gene_count",
        "perc_id_genes",
        "sample_1",
        "sample_2"
    ]

    def __init__(self,
                 profile_db: ProfileDatabase,
                 config: GenomeComparisonConfig,
                 comp_db_loc: str|None = None,
                 ):
        self.profile_db = profile_db
        self.config = config
        if comp_db_loc is not None:
            self.comp_db_loc = pathlib.Path(comp_db_loc)
            self._comp_db = pl.scan_parquet(self.comp_db_loc)
        else:
            self.comp_db_loc = None
            self._comp_db=pl.LazyFrame({
                "genome": [],
                "total_positions": [],
                "share_allele_pos": [],
                "genome_pop_ani": [],
                "max_consecutive_length": [],
                "shared_genes_count": [],
                "identical_gene_count": [],
                "perc_id_genes": [],
                "sample_1": [],
                "sample_2": []
            }, schema={
                "genome": pl.Utf8,
                "total_positions": pl.Int64,
                "share_allele_pos": pl.Int64,
                "genome_pop_ani": pl.Float64,
                "max_consecutive_length": pl.Int64,
                "shared_genes_count": pl.Int64,
                "identical_gene_count": pl.Int64,
                "perc_id_genes": pl.Float64,
                "sample_1": pl.Utf8,
                "sample_2": pl.Utf8
            })
            self.comp_db_loc=None

    @property
    def comp_db(self):
        return self._comp_db

    def _validate_db(self)->None:
        self.profile_db._validate_db()

        if set(self._comp_db.collect_schema()) != set(self.COLUMN_NAMES):
            raise ValueError(f"Your comparison database must provide these extra columns: { set(self.COLUMN_NAMES)-set(self._comp_db.collect_schema())}")
        #check if all profile names exist in the profile database
        profile_names_in_comp_db = set(self.get_all_profile_names())
        profile_names_in_profile_db = set(self.profile_db.db.select("profile_name").collect(engine="streaming").to_series().to_list())
        if not profile_names_in_comp_db.issubset(profile_names_in_profile_db):
            raise ValueError(f"The following profile names are in the comparison database but not in the profile database: {profile_names_in_comp_db - profile_names_in_profile_db}")

    def get_all_profile_names(self) -> set[str]:
        """
        Get all profile names that are in the comparison database.
        """
        return set(self.comp_db.select(pl.col("sample_1")).collect(engine="streaming").to_series().to_list()).union(
            set(self.comp_db.select(pl.col("sample_2")).collect(engine="streaming").to_series().to_list())
        )
    def get_remaining_pairs(self) -> pl.LazyFrame:
        """
        Get pairs of profiles that are in the profile database but not in the comparison database.
        """
        profiles = self.profile_db.db.select("profile_name")
        pairs=profiles.join(profiles,how="cross").rename({"profile_name":"profile_1","profile_name_right":"profile_2"}).filter(pl.col("profile_1")<pl.col("profile_2"))
        samplepairs = self.comp_db.group_by("sample_1", "sample_2").agg().with_columns(pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"), pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2")).select(["profile_1", "profile_2"])

        remaining_pairs = pairs.join(samplepairs, on=["profile_1", "profile_2"], how="anti").sort(["profile_1","profile_2"])
        return remaining_pairs

    def is_complete(self) -> bool:
        """
        Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.
        """
        return self.get_remaining_pairs().collect(engine="streaming").is_empty()

    def add_comp_database(self, comp_database: GenomeComparisonDatabase) -> None:
        """Merge the provided comparison database into the current database.

        Args:
            comp_database (ComparisonDatabase): The comparison database to merge.
        """
        try:
            comp_database._validate_db()

        except Exception as e:
            raise ValueError(f"The comparison database provided is not valid: {e}")

        if not self.config.is_compatible(comp_database.config):
            raise ValueError("The comparison database provided is not compatible with the current comparison database.")

        self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
        self.config = self.config.get_maximal_scope_config(comp_database.config)


    def save_new_compare_database(self, output_path: str) -> None:
        """Save the database to a parquet file."""
        output_path = pathlib.Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # The new database must be written to a new location
        if self.comp_db_loc is not None and str(self.comp_db_loc.absolute()) == str(output_path.absolute()):
            raise ValueError("The output path must be different from the current database location.")

        self.comp_db.sink_parquet(output_path)


    def update_compare_database(self)->None:
        """Overwrites the comparison database saved on the disk to the current comparison database object
        """
        if self.comp_db_loc is None:
            raise Exception("comp_db_loc attribute is not determined yet!")
        try:
            tmp_path=pathlib.Path(tempfile.mktemp(suffix=".parquet",prefix="tmp_comp_db_",dir=str(self.comp_db_loc.parent)))
            self.comp_db.sink_parquet(tmp_path)
            os.replace(tmp_path,self.comp_db_loc)
            self._comp_db=pl.scan_parquet(self.comp_db_loc)
        except Exception as e:
            raise Exception(f"Something went wrong when updating the comparison database:{e}")

    def dump_obj(self, output_path: str) -> None:
        """Dump the current object to a json file.

        Args:
            output_path (str): The path to save the json file to.
        """
        obj_dict = {
            "profile_db_loc": str(self.profile_db.db_loc.absolute()) if self.profile_db.db_loc is not None else None,
            "config": self.config.to_dict(),
            "comp_db_loc": str(self.comp_db_loc.absolute()) if self.comp_db_loc is not None else None
        }
        with open(output_path, "w") as f:
            json.dump(obj_dict, f, indent=4)

    @classmethod
    def load_obj(cls, json_path: str) -> GenomeComparisonDatabase:
        """Load a GenomeComparisonDatabase object from a json file.

        Args:
            json_path (str): The path to the json file.

        Returns:
            GenomeComparisonDatabase: The loaded GenomeComparisonDatabase object.
        """
        with open(json_path, "r") as f:
            obj_dict = json.load(f)

        return cls(profile_db=ProfileDatabase(db_loc=obj_dict["profile_db_loc"]) , 
                   config=GenomeComparisonConfig(**obj_dict["config"]), 
                   comp_db_loc=obj_dict["comp_db_loc"])


    def to_complete_input_table(self)->pl.LazyFrame:
        """This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

        - sample_name_1

        - sample_name_2

        - profile_location_1

        - scaffold_location_1

        - profile_location_2

        - scaffold_location_2

        Returns:
            pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
        """
        lf=self.get_remaining_pairs().rename({"profile_1":"sample_name_1","profile_2":"sample_name_2"})
        return (lf.join(self.profile_db.db.select(["profile_name","profile_location","scaffold_location"]),left_on="sample_name_1",right_on="profile_name",how="left")
                .rename({"profile_location":"profile_location_1","scaffold_location":"scaffold_location_1"})
                .join(self.profile_db.db.select(["profile_name","profile_location","scaffold_location"]),left_on="sample_name_2",right_on="profile_name",how="left")
                .rename({"profile_location":"profile_location_2","scaffold_location":"scaffold_location_2"})
               )
add_comp_database(comp_database)

Merge the provided comparison database into the current database.

Parameters:

Name Type Description Default
comp_database ComparisonDatabase

The comparison database to merge.

required
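
A hypothetical sketch of merging two partial comparison databases built from the same profile database with compatible configurations; the JSON and parquet paths are placeholders.

from zipstrain.database import GenomeComparisonDatabase

part_1 = GenomeComparisonDatabase.load_obj("comp_db_part1.json")   # placeholder paths
part_2 = GenomeComparisonDatabase.load_obj("comp_db_part2.json")

part_1.add_comp_database(part_2)    # validates part_2, concatenates rows, updates the config scope
part_1.save_new_compare_database("merged_comparisons.parquet")
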
Source code in zipstrain/src/zipstrain/database.py
def add_comp_database(self, comp_database: GenomeComparisonDatabase) -> None:
    """Merge the provided comparison database into the current database.

    Args:
        comp_database (ComparisonDatabase): The comparison database to merge.
    """
    try:
        comp_database._validate_db()

    except Exception as e:
        raise ValueError(f"The comparison database provided is not valid: {e}")

    if not self.config.is_compatible(comp_database.config):
        raise ValueError("The comparison database provided is not compatible with the current comparison database.")

    self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
    self.config = self.config.get_maximal_scope_config(comp_database.config)
dump_obj(output_path)

Dump the current object to a json file.

Parameters:

Name Type Description Default
output_path str

The path to save the json file to.

required
Source code in zipstrain/src/zipstrain/database.py
def dump_obj(self, output_path: str) -> None:
    """Dump the current object to a json file.

    Args:
        output_path (str): The path to save the json file to.
    """
    obj_dict = {
        "profile_db_loc": str(self.profile_db.db_loc.absolute()) if self.profile_db.db_loc is not None else None,
        "config": self.config.to_dict(),
        "comp_db_loc": str(self.comp_db_loc.absolute()) if self.comp_db_loc is not None else None
    }
    with open(output_path, "w") as f:
        json.dump(obj_dict, f, indent=4)
get_all_profile_names()

Get all profile names that are in the comparison database.

Source code in zipstrain/src/zipstrain/database.py
def get_all_profile_names(self) -> set[str]:
    """
    Get all profile names that are in the comparison database.
    """
    return set(self.comp_db.select(pl.col("sample_1")).collect(engine="streaming").to_series().to_list()).union(
        set(self.comp_db.select(pl.col("sample_2")).collect(engine="streaming").to_series().to_list())
    )
get_remaining_pairs()

Get pairs of profiles that are in the profile database but not in the comparison database.

Source code in zipstrain/src/zipstrain/database.py
def get_remaining_pairs(self) -> pl.LazyFrame:
    """
    Get pairs of profiles that are in the profile database but not in the comparison database.
    """
    profiles = self.profile_db.db.select("profile_name")
    pairs=profiles.join(profiles,how="cross").rename({"profile_name":"profile_1","profile_name_right":"profile_2"}).filter(pl.col("profile_1")<pl.col("profile_2"))
    samplepairs = self.comp_db.group_by("sample_1", "sample_2").agg().with_columns(pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"), pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2")).select(["profile_1", "profile_2"])

    remaining_pairs = pairs.join(samplepairs, on=["profile_1", "profile_2"], how="anti").sort(["profile_1","profile_2"])
    return remaining_pairs
is_complete()

Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.

Source code in zipstrain/src/zipstrain/database.py
def is_complete(self) -> bool:
    """
    Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.
    """
    return self.get_remaining_pairs().collect(engine="streaming").is_empty()
load_obj(json_path) classmethod

Load a GenomeComparisonDatabase object from a json file.

Parameters:

Name Type Description Default
json_path str

The path to the json file.

required

Returns:

Name Type Description
GenomeComparisonDatabase GenomeComparisonDatabase

The loaded GenomeComparisonDatabase object.

Source code in zipstrain/src/zipstrain/database.py
@classmethod
def load_obj(cls, json_path: str) -> GenomeComparisonDatabase:
    """Load a GenomeComparisonDatabase object from a json file.

    Args:
        json_path (str): The path to the json file.

    Returns:
        GenomeComparisonDatabase: The loaded GenomeComparisonDatabase object.
    """
    with open(json_path, "r") as f:
        obj_dict = json.load(f)

    return cls(profile_db=ProfileDatabase(db_loc=obj_dict["profile_db_loc"]) , 
               config=GenomeComparisonConfig(**obj_dict["config"]), 
               comp_db_loc=obj_dict["comp_db_loc"])
save_new_compare_database(output_path)

Save the database to a parquet file.

Source code in zipstrain/src/zipstrain/database.py
def save_new_compare_database(self, output_path: str) -> None:
    """Save the database to a parquet file."""
    output_path = pathlib.Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The new database must be written to a new location
    if self.comp_db_loc is not None and str(self.comp_db_loc.absolute()) == str(output_path.absolute()):
        raise ValueError("The output path must be different from the current database location.")

    self.comp_db.sink_parquet(output_path)
to_complete_input_table()

This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

  • sample_name_1

  • sample_name_2

  • profile_location_1

  • scaffold_location_1

  • profile_location_2

  • scaffold_location_2

Returns:

Type Description
LazyFrame

pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.

Source code in zipstrain/src/zipstrain/database.py
def to_complete_input_table(self)->pl.LazyFrame:
    """This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

    - sample_name_1

    - sample_name_2

    - profile_location_1

    - scaffold_location_1

    - profile_location_2

    - scaffold_location_2

    Returns:
        pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
    """
    lf=self.get_remaining_pairs().rename({"profile_1":"sample_name_1","profile_2":"sample_name_2"})
    return (lf.join(self.profile_db.db.select(["profile_name","profile_location","scaffold_location"]),left_on="sample_name_1",right_on="profile_name",how="left")
            .rename({"profile_location":"profile_location_1","scaffold_location":"scaffold_location_1"})
            .join(self.profile_db.db.select(["profile_name","profile_location","scaffold_location"]),left_on="sample_name_2",right_on="profile_name",how="left")
            .rename({"profile_location":"profile_location_2","scaffold_location":"scaffold_location_2"})
           )
update_compare_database()

Overwrites the comparison database saved on disk with the current comparison database object.

Source code in zipstrain/src/zipstrain/database.py
def update_compare_database(self)->None:
    """Overwrites the comparison database saved on the disk to the current comparison database object
    """
    if self.comp_db_loc is None:
        raise Exception("comp_db_loc attribute is not determined yet!")
    try:
        tmp_path=pathlib.Path(tempfile.mktemp(suffix=".parquet",prefix="tmp_comp_db_",dir=str(self.comp_db_loc.parent)))
        self.comp_db.sink_parquet(tmp_path)
        os.replace(tmp_path,self.comp_db_loc)
        self._comp_db=pl.scan_parquet(self.comp_db_loc)
    except Exception as e:
        raise Exception(f"Something went wrong when updating the comparison database:{e}")

ProfileDatabase

The profile database simply holds profile information; it is not tied to any particular comparison database. The database is stored in a parquet file: a table with the following columns:

  • profile_name: An arbitrary name given to the profile (Usually sample name or name of the parquet file)

  • profile_location: The location of the profile

  • scaffold_location: The location of the scaffold

  • reference_db_id: The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.

  • gene_db_id: The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.

Parameters:

Name Type Description Default
db_loc str | None

The location of the profile database parquet file. If None, an empty database is created.

None
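
A small sketch of building a profile database from scratch. Note that add_profile validates that the referenced profile and scaffold files actually exist on disk, so the paths below must point to real files; they are placeholders here.

from zipstrain.database import ProfileDatabase

db = ProfileDatabase()          # start from an empty, in-memory database
db.add_profile({
    "profile_name": "sample_A",
    "profile_location": "profiles/sample_A.parquet",     # must exist on disk
    "scaffold_location": "scaffolds/sample_A.parquet",   # must exist on disk
    "reference_db_id": "ref_v1",
    "gene_db_id": "genes_v1",
})
db.save_as_new_database("profiles.parquet")
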
Source code in zipstrain/src/zipstrain/database.py
class ProfileDatabase:
    """
    The profile database simply holds profile information. Does not need to be specific to a comparison database.
    The data behind a profile is stored in a parquet file. It is basically a table with the following columns:

    - profile_name: An arbitrary name given to the profile (Usually sample name or name of the parquet file)

    - profile_location: The location of the profile

    - scaffold_location: The location of the scaffold

    - reference_db_id: The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.

    - gene_db_id: The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.

    Args:
        db_loc (str|None): The location of the profile database parquet file. If None, an empty database is created.

    """
    def __init__(self,
                 db_loc: str|None = None,
                 ):
        if db_loc is not None:
            self.db_loc = pathlib.Path(db_loc)
            self._db = pl.scan_parquet(self.db_loc)
        else:
            self._db=pl.LazyFrame({
                "profile_name": [],
                "profile_location": [],
                "scaffold_location": [],
                "reference_db_id": [],
                "gene_db_id": []
            }, schema={
                "profile_name": pl.Utf8,
                "profile_location": pl.Utf8,
                "scaffold_location": pl.Utf8,
                "reference_db_id": pl.Utf8,
                "gene_db_id": pl.Utf8
            })
            self.db_loc=None

    @property
    def db(self):
        return self._db

    def _validate_db(self,check_profile_exists: bool=True,check_scaffold_exists:bool=True)->None:
        """Simple method to see if the database has the minimum required structure."""

        ### Next check if the database has the required columns
        required_columns = ["profile_name","profile_location", "scaffold_location", "reference_db_id", "gene_db_id"]
        for col in required_columns:
            if col not in self.db.collect_schema().names():
                raise ValueError(f"Missing required column: {col}")

        if check_profile_exists:
            # Check if the profile exists in the database
            db_path_validated= self.db.select(pl.col("profile_location")).collect(engine="streaming").with_columns(
                (pl.col("profile_location").map_elements(lambda x: pathlib.Path(x).exists(),return_dtype=pl.Boolean)).alias("profile_exists")
            ).filter(~ pl.col("profile_exists"))
            if db_path_validated.height != 0:
                raise ValueError(f"There are {db_path_validated.height} profiles that do not exist: {db_path_validated['profile_location'].to_list()}")
            ### add log later
        if check_scaffold_exists:
            db_path_validated= self.db.select(pl.col("scaffold_location")).collect(engine="streaming").with_columns(
                (pl.col("scaffold_location").map_elements(lambda x: pathlib.Path(x).exists(),return_dtype=pl.Boolean)).alias("scaffold_exists")
            ).filter(~ pl.col("scaffold_exists"))
            if db_path_validated.height != 0:
                raise ValueError(f"There are {db_path_validated.height} scaffolds that do not exist: {db_path_validated['scaffold_location'].to_list()}")
            ### add log later

    def add_profile(self,
                    data: dict
                    ) -> None:
        """Add a profile to the database.
        The data dictionary must contain the following and only the following keys:

        - profile_name

        - profile_location

        - scaffold_location

        - reference_db_id

        - gene_db_id

        Args:
            data (dict): The profile data to add.
        """
        try:
            profile_item = ProfileItem(**data)
            lf=pl.LazyFrame({
                "profile_name": [profile_item.profile_name],
                "profile_location": [profile_item.profile_location],
                "scaffold_location": [profile_item.scaffold_location],
                "reference_db_id": [profile_item.reference_db_id],
                "gene_db_id": [profile_item.gene_db_id]
            })
            self._db = pl.concat([self.db, lf]).unique()
            self._validate_db()
        except Exception as e:
            raise ValueError(f"The profile data provided is not valid: {e}")


    def add_database(self, profile_database: ProfileDatabase) -> None:
        """Merge the provided profile database into the current database.

        Args:
            profile_database (ProfileDatabase): The profile database to merge.
        """
        try:
            profile_database._validate_db()

        except Exception as e:
            raise ValueError(f"The profile database provided is not valid: {e}")

        self._db = pl.concat([self._db, profile_database.db]).unique()


    def save_as_new_database(self, output_path: str) -> None:
        """Save the database to a parquet file.

        Args:
            output_path (str): The path to save the database to.
        """
        #The new database must be written to a new location
        if self.db_loc is not None and str(self.db_loc.absolute()) == str(pathlib.Path(output_path).absolute()):
            raise ValueError("The output path must be different from the current database location.")

        try:
            self.db.sink_parquet(output_path)
            self.db_loc=pathlib.Path(output_path)
        ### add log later
        except Exception as e:
            pass 
        ### add log later

    def update_database(self)->None:
        """Overwrites the database saved on the disk to the current database object
        """
        if self.db_loc is None:
            raise Exception("db_loc attribute is not determined yet!")
        try:
            self.db.collect(engine="streaming").write_parquet(self.db_loc)
        except Exception as e:
            raise Exception(f"Something went wrong when updating the database:{e}")


    @classmethod
    def from_csv(cls, csv_path: str) -> ProfileDatabase:
        """Create a ProfileDatabase instance from a CSV file with exactly same columns as the required columns for a profile database.

        Args:
            csv_path (str): The path to the CSV file.

        Returns:
            ProfileDatabase: The created ProfileDatabase instance.
        """
        lf=pl.scan_csv(csv_path).collect().lazy() # To avoid clash when using to_csv on same file
        prof_db=cls()
        prof_db._db=lf
        prof_db._validate_db()
        return prof_db

    def to_csv(self,output_dir:str)->None:
        """Writes the the current database object to a csv file"

        Args:
            output_dir (str): The path to save the CSV file.

        Returns:
            None
        """
        self.db.sink_csv(output_dir,engine="streaming")
add_database(profile_database)

Merge the provided profile database into the current database.

Parameters:

Name Type Description Default
profile_database ProfileDatabase

The profile database to merge.

required
Source code in zipstrain/src/zipstrain/database.py
147
148
149
150
151
152
153
154
155
156
157
158
159
def add_database(self, profile_database: ProfileDatabase) -> None:
    """Merge the provided profile database into the current database.

    Args:
        profile_database (ProfileDatabase): The profile database to merge.
    """
    try:
        profile_database._validate_db()

    except Exception as e:
        raise ValueError(f"The profile database provided is not valid: {e}")

    self._db = pl.concat([self._db, profile_database.db]).unique()
add_profile(data)

Add a profile to the database. The data dictionary must contain the following and only the following keys:

  • profile_name

  • profile_location

  • scaffold_location

  • reference_db_id

  • gene_db_id

Parameters:

Name Type Description Default
data dict

The profile data to add.

required
Source code in zipstrain/src/zipstrain/database.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def add_profile(self,
                data: dict
                ) -> None:
    """Add a profile to the database.
    The data dictionary must contain the following and only the following keys:

    - profile_name

    - profile_location

    - scaffold_location

    - reference_db_id

    - gene_db_id

    Args:
        data (dict): The profile data to add.
    """
    try:
        profile_item = ProfileItem(**data)
        lf=pl.LazyFrame({
            "profile_name": [profile_item.profile_name],
            "profile_location": [profile_item.profile_location],
            "scaffold_location": [profile_item.scaffold_location],
            "reference_db_id": [profile_item.reference_db_id],
            "gene_db_id": [profile_item.gene_db_id]
        })
        self._db = pl.concat([self.db, lf]).unique()
        self._validate_db()
    except Exception as e:
        raise ValueError(f"The profile data provided is not valid: {e}")
from_csv(csv_path) classmethod

Create a ProfileDatabase instance from a CSV file with exactly the same columns as those required for a profile database.

Parameters:

Name Type Description Default
csv_path str

The path to the CSV file.

required

Returns:

Name Type Description
ProfileDatabase ProfileDatabase

The created ProfileDatabase instance.

Source code in zipstrain/src/zipstrain/database.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
@classmethod
def from_csv(cls, csv_path: str) -> ProfileDatabase:
    """Create a ProfileDatabase instance from a CSV file with exactly same columns as the required columns for a profile database.

    Args:
        csv_path (str): The path to the CSV file.

    Returns:
        ProfileDatabase: The created ProfileDatabase instance.
    """
    lf=pl.scan_csv(csv_path).collect().lazy() # To avoid clash when using to_csv on same file
    prof_db=cls()
    prof_db._db=lf
    prof_db._validate_db()
    return prof_db
save_as_new_database(output_path)

Save the database to a parquet file.

Parameters:

Name Type Description Default
output_path str

The path to save the database to.

required
Source code in zipstrain/src/zipstrain/database.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def save_as_new_database(self, output_path: str) -> None:
    """Save the database to a parquet file.

    Args:
        output_path (str): The path to save the database to.
    """
    #The new database must be written to a new location
    if self.db_loc is not None and str(self.db_loc.absolute()) == str(pathlib.Path(output_path).absolute()):
        raise ValueError("The output path must be different from the current database location.")

    try:
        self.db.sink_parquet(output_path)
        self.db_loc=pathlib.Path(output_path)
    ### add log later
    except Exception as e:
        pass 
to_csv(output_dir)

Writes the current database object to a CSV file.

Parameters:

Name Type Description Default
output_dir str

The path to save the CSV file.

required

Returns:

Type Description
None

None

Source code in zipstrain/src/zipstrain/database.py
207
208
209
210
211
212
213
214
215
216
def to_csv(self,output_dir:str)->None:
    """Writes the the current database object to a csv file"

    Args:
        output_dir (str): The path to save the CSV file.

    Returns:
        None
    """
    self.db.sink_csv(output_dir,engine="streaming")
update_database()

Overwrites the database saved on disk with the current database object.

Source code in zipstrain/src/zipstrain/database.py
180
181
182
183
184
185
186
187
188
def update_database(self)->None:
    """Overwrites the database saved on the disk to the current database object
    """
    if self.db_loc is None:
        raise Exception("db_loc attribute is not determined yet!")
    try:
        self.db.collect(engine="streaming").write_parquet(self.db_loc)
    except Exception as e:
        raise Exception(f"Something went wrong when updating the database:{e}")

ProfileItem

Bases: BaseModel

This class describes all necessary attributes of a profile and makes sure they comply with the necessary formatting.

Source code in zipstrain/src/zipstrain/database.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class ProfileItem(BaseModel):
    """
    This class describes all necessary attributes of a profile and makes sure they comply with the necessary formatting.
    """
    model_config = ConfigDict(extra="forbid")
    profile_name: str = Field(description="An arbitrary name given to the profile (Usually sample name or name of the parquet file)")
    profile_location: str = Field(description="The location of the profile")
    scaffold_location: str = Field(description="The location of the scaffold")
    reference_db_id: str = Field(description="The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.")
    gene_db_id:str= Field(default="",description="The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.")

    @field_validator("profile_location","scaffold_location")
    def check_file_exists(cls, v):
        if not os.path.exists(v):
            raise ValueError(f"The file {v} does not exist.")
        return v

    @field_validator("reference_db_id","gene_db_id")
    def check_reference_db_id(cls, v):
        if not v:
            raise ValueError("The reference_db_id and gene_db_id cannot be empty.")
        return v

Profile

zipstrain.profile

This module provides functions and utilities to profile a BAM file. By profile we mean generating gene, genome, and nucleotide counts at each position on the reference. This is a fundamental step for downstream analysis in zipstrain.

build_gene_loc_table(fasta_file, scaffold)

Build a gene location table from a FASTA file.

Parameters: fasta_file (pathlib.Path): Path to the FASTA file. scaffold (set): Set of scaffold names to restrict the table to.

Returns: pl.DataFrame: A Polars DataFrame containing gene locations.

Source code in zipstrain/src/zipstrain/profile.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def build_gene_loc_table(fasta_file:pathlib.Path,scaffold:set)->pl.DataFrame:
    """
    Build a gene location table from a FASTA file.

    Parameters:
    fasta_file (pathlib.Path): Path to the FASTA file.

    Returns:
    pl.DataFrame: A Polars DataFrame containing gene locations.
    """
    scaffolds = []
    gene_ids = []
    pos=[]
    for genes in parse_gene_loc_table(fasta_file):
        if genes[1] in scaffold:
            scaffolds.extend([genes[1]]* (int(genes[3])-int(genes[2])+1))
            gene_ids.extend([genes[0]]* (int(genes[3])-int(genes[2])+1))
            pos.extend(list(range(int(genes[2]), int(genes[3])+1)))
    return pl.DataFrame({
        "scaffold":scaffolds,
        "gene":gene_ids,
        "pos":pos
    })

build_gene_range_table(fasta_file)

Build a gene location table in the form of <gene scaffold start end> from a FASTA file. Parameters: fasta_file (pathlib.Path): Path to the FASTA file.

Returns: pl.DataFrame: A Polars DataFrame containing gene locations.

Source code in zipstrain/src/zipstrain/profile.py
63
64
65
66
67
68
69
70
71
72
73
74
75
def build_gene_range_table(fasta_file:pathlib.Path)->pl.DataFrame:
    """
    Build a gene location table in the form of <gene scaffold start end> from a FASTA file.
    Parameters:
    fasta_file (pathlib.Path): Path to the FASTA file.

    Returns:
    pl.DataFrame: A Polars DataFrame containing gene locations.
    """
    out=[]
    for parsed_annot in parse_gene_loc_table(fasta_file):
        out.append(parsed_annot)
    return pl.DataFrame(out, schema=["gene", "scaffold", "start", "end"],orient='row')
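
As a brief illustration, the sketch below builds a gene range table from a hypothetical Prodigal gene FASTA (headers of the form ">scaffold_1_1 # 501 # 1202 # ...", which is what parse_gene_loc_table expects); the file name is a placeholder.

import pathlib
from zipstrain.profile import build_gene_range_table

# "genes.fna" is a placeholder for a Prodigal-style gene FASTA.
ranges = build_gene_range_table(pathlib.Path("genes.fna"))
print(ranges.head())   # columns: gene, scaffold, start, end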

get_strain_hetrogeneity(profile, stb, min_cov=5, freq_threshold=0.8)

Calculate strain heterogeneity for each genome based on nucleotide frequencies. The definition of strain heterogeneity here is the fraction of sites that have enough coverage (min_cov) and have a dominant nucleotide with frequency less than freq_threshold.

Parameters:

Name Type Description Default
profile LazyFrame

The profile LazyFrame containing nucleotide counts.

required
stb LazyFrame

The scaffold-to-genome mapping LazyFrame, with 'scaffold' and 'genome' columns.

required
min_cov int

The minimum coverage threshold.

5
freq_threshold float

The frequency threshold for dominant nucleotides.

0.8

Returns: pl.LazyFrame: A LazyFrame containing strain heterogeneity information grouped by genome.

Source code in zipstrain/src/zipstrain/profile.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def get_strain_hetrogeneity(profile:pl.LazyFrame,
                            stb:pl.LazyFrame, 
                            min_cov=5,
                            freq_threshold=0.8)->pl.LazyFrame:
    """
    Calculate strain heterogeneity for each genome based on nucleotide frequencies.
    The definition of strain heterogeneity here is the fraction of sites that have enough coverage
    (min_cov) and have a dominant nucleotide with frequency less than freq_threshold.

    Args:
        profile (pl.LazyFrame): The profile LazyFrame containing nucleotide counts.
        stb (pl.LazyFrame): The scaffold-to-genome mapping LazyFrame, with 'scaffold' and 'genome' columns.
        min_cov (int): The minimum coverage threshold.
        freq_threshold (float): The frequency threshold for dominant nucleotides.

    Returns:
    pl.LazyFrame: A LazyFrame containing strain heterogeneity information grouped by genome.
    """
    # Calculate the total number of sites with sufficient coverage
    profile = profile.with_columns(
        (pl.col("A")+pl.col("T")+pl.col("C")+pl.col("G")).alias("coverage")
    ).filter(pl.col("coverage") >= min_cov)

    profile = profile.with_columns(
        (pl.max_horizontal(["A", "T", "C", "G"])/pl.col("coverage") < freq_threshold)
        .cast(pl.Int8)
        .alias("heterogeneous_site")
    )

    profile = profile.join(stb, left_on="chrom", right_on="scaffold", how="left").group_by("genome").agg([
        pl.len().alias(f"total_sites_at_{min_cov}_coverage"),
        pl.sum("heterogeneous_site").alias("heterogeneous_sites")
    ])

    strain_heterogeneity = profile.with_columns(
        (pl.col("heterogeneous_sites")/pl.col(f"total_sites_at_{min_cov}_coverage")).alias("strain_heterogeneity")
    )
    return strain_heterogeneity
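
A minimal usage sketch, assuming a profile parquet with the standard chrom/pos/gene/A/C/G/T columns and a headerless tab-separated scaffold-to-genome table; both paths are hypothetical.

import polars as pl
from zipstrain.profile import get_strain_hetrogeneity

profile = pl.scan_parquet("profiles/sampleA.parquet")
stb = pl.scan_csv("genomes.stb", separator="\t", has_header=False,
                  new_columns=["scaffold", "genome"])

het = get_strain_hetrogeneity(profile, stb, min_cov=5, freq_threshold=0.8)
print(het.collect())   # one row per genome with the strain_heterogeneity fraction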

parse_gene_loc_table(fasta_file)

Extract gene locations from a FASTA file (assumed to be from Prodigal) and yield gene info.

Parameters: fasta_file (pathlib.Path): Path to the FASTA file.

Yields: A tuple containing: gene_ID, scaffold, start, end.

Source code in zipstrain/src/zipstrain/profile.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def parse_gene_loc_table(fasta_file:pathlib.Path) -> Generator[tuple,None,None]:
    """
    Extract gene locations from a FASTA file (assumed to be from Prodigal) and yield gene info.

    Parameters:
    fasta_file (pathlib.Path): Path to the FASTA file.

    Returns:
    Tuple: A tuple containing:
        - gene_ID
        - scaffold
        - start
        - end
    """
    with open(fasta_file, 'r') as f:
        for line in f:
            if line.startswith('>'):
                parts = line[1:].strip().split()
                gene_id = parts[0]
                scaffold = "_".join(gene_id.split('_')[:-1])
                start = parts[2]
                end=parts[4]      
                yield gene_id, scaffold,start,end

profile_bam(bed_file, bam_file, gene_range_table, output_dir, num_workers=4)

Profile a BAM file in chunks using provided BED files.

Parameters: bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled. bam_file (pathlib.Path): Path to the BAM file. gene_range_table (pathlib.Path): Path to the gene range table. output_dir (pathlib.Path): Directory to save output files. num_workers (int): Number of concurrent workers to use.

Source code in zipstrain/src/zipstrain/profile.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def profile_bam(
    bed_file:str,
    bam_file:str,
    gene_range_table:str,
    output_dir:str,
    num_workers:int=4
)->None:
    """
    Profile a BAM file in chunks using provided BED files.

    Parameters:
    bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled.
    bam_file (pathlib.Path): Path to the BAM file.
    gene_range_table (pathlib.Path): Path to the gene range table.
    output_dir (pathlib.Path): Directory to save output files.
    num_workers (int): Number of concurrent workers to use.
    """
    asyncio.run(profile_bam_in_chunks(
        bed_file=bed_file,
        bam_file=bam_file,
        gene_range_table=gene_range_table,
        output_dir=output_dir,
        num_workers=num_workers
    ))
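
A minimal sketch of profiling one sample; all paths are hypothetical, and since the utils docstrings describe consuming samtools mpileup output, samtools is assumed to be available on the PATH.

from zipstrain.profile import profile_bam

# The gene range table is a headerless TSV of gene coordinates on the reference.
profile_bam(
    bed_file="regions.bed",
    bam_file="alignments/sampleA.sorted.bam",
    gene_range_table="gene_ranges.tsv",
    output_dir="profiles/sampleA",
    num_workers=8,
)
# The merged profile is written to <output_dir>/<BAM stem>.parquet.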

profile_bam_in_chunks(bed_file, bam_file, gene_range_table, output_dir, num_workers=4) async

Profile a BAM file in chunks using provided BED files.

Parameters: bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled. bam_file (pathlib.Path): Path to the BAM file. gene_range_table (pathlib.Path): Path to the gene range table. output_dir (pathlib.Path): Directory to save output files. num_workers (int): Number of concurrent workers to use.

Source code in zipstrain/src/zipstrain/profile.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
async def profile_bam_in_chunks(
    bed_file:str,
    bam_file:str,
    gene_range_table:str,
    output_dir:str,
    num_workers:int=4
)->None:
    """
    Profile a BAM file in chunks using provided BED files.

    Parameters:
    bed_file (list[pathlib.Path]): A bed file describing all regions to be profiled.
    bam_file (pathlib.Path): Path to the BAM file.
    gene_range_table (pathlib.Path): Path to the gene range table.
    output_dir (pathlib.Path): Directory to save output files.
    num_workers (int): Number of concurrent workers to use.
    """

    output_dir=pathlib.Path(output_dir)
    bam_file=pathlib.Path(bam_file)
    bed_file=pathlib.Path(bed_file)
    gene_range_table=pathlib.Path(gene_range_table)

    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir/"tmp").mkdir(exist_ok=True)
    bed_lf=pl.scan_csv(bed_file,has_header=False,separator="\t")
    bed_chunks=utils.split_lf_to_chunks(bed_lf,num_workers)
    bed_chunk_files=[]
    for chunk_id, bed_file in enumerate(bed_chunks):
        bed_file.sink_csv(output_dir/"tmp"/f"bed_chunk_{chunk_id}.bed",include_header=False,separator="\t")
        bed_chunk_files.append(output_dir/"tmp"/f"bed_chunk_{chunk_id}.bed")
    tasks = []
    for chunk_id, bed_chunk_file in enumerate(bed_chunk_files):
        tasks.append(_profile_chunk_task(
            bed_file=bed_chunk_file,
            bam_file=bam_file,
            gene_range_table=gene_range_table,
            output_dir=output_dir/"tmp",
            chunk_id=chunk_id
        ))
    await asyncio.gather(*tasks) 
    pfs=[output_dir/"tmp"/f"{bam_file.stem}_{chunk_id}.parquet" for chunk_id in range(len(bed_chunk_files)) if (output_dir/"tmp"/f"{bam_file.stem}_{chunk_id}.parquet").exists()]
    mpileup_df = pl.concat([pl.scan_parquet(pf) for pf in pfs])
    mpileup_df.sink_parquet(output_dir/f"{bam_file.stem}.parquet", compression='zstd')
    os.system(f"rm -r {output_dir}/tmp")

Compare

zipstrain.compare

This module provides all comparison functions for zipstrain.

PolarsANIExpressions

Any kind of ANI calculation based on two profiles should be implemented as a method of this class. In defining this method, the following rules should be followed:

  • The method returns a Polars expression (pl.Expr).

  • When applied to a row, the method returns a zero if that position is a SNV. Otherwise it should return a number greater than zero.

  • A, T, C, G columns in the first profile are named "A", "T", "C", "G" and in the second profile they are named "A_2", "T_2", "C_2", "G_2".

  • popani: Population ANI based on the shared alleles between two profiles.

  • conani: Consensus ANI based on the consensus alleles between two profiles.

  • cosani_<threshold>: Generalized cosine similarity ANI, where threshold is a float value between 0 and 1. Once the similarity is below the threshold, the position is considered a SNV.
Source code in zipstrain/src/zipstrain/compare.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class PolarsANIExpressions:
    """ 
    Any kind of ANI calculation based on two profiles should be implemented as a method of this class.
    In defining this method, the following rules should be followed:

    -   The method returns a Polars expression (pl.Expr).

    -   When applied to a row, the method returns a zero if that position is a SNV. Otherwise it should return a number greater than zero.

    -   A, T, C, G columns in the first profile are named "A", "T", "C", "G" and in the second profile they are named "A_2", "T_2", "C_2", "G_2".

    1. popani: Population ANI based on the shared alleles between two profiles.
    2. conani: Consensus ANI based on the consensus alleles between two profiles.
    3. cosani_<threshold>: Generalized cosine similarity ANI where threshold is a float value between 0 and 1. Once the similarity is below the threshold, it is considered a SNV.
    """
    MPILE_1_BASES = ["A", "T", "C", "G"]
    MPILE_2_BASES = ["A_2", "T_2", "C_2", "G_2"]

    def popani(self):
        return pl.col("A")*pl.col("A_2") + pl.col("C")*pl.col("C_2") + pl.col("G")*pl.col("G_2") + pl.col("T")*pl.col("T_2")

    def conani(self):
        max_base_1=pl.max_horizontal(*[pl.col(base) for base in self.MPILE_1_BASES])
        max_base_2=pl.max_horizontal(*[pl.col(base) for base in self.MPILE_2_BASES])
        return pl.when((pl.col("A")==max_base_1) & (pl.col("A_2")==max_base_2) | 
                       (pl.col("T")==max_base_1) & (pl.col("T_2")==max_base_2) | 
                       (pl.col("C")==max_base_1) & (pl.col("C_2")==max_base_2) | 
                       (pl.col("G")==max_base_1) & (pl.col("G_2")==max_base_2)).then(1).otherwise(0)

    def generalized_cos_ani(self,threshold:float=0.4):
        dot_product = pl.col("A")*pl.col("A_2") + pl.col("C")*pl.col("C_2") + pl.col("G")*pl.col("G_2") + pl.col("T")*pl.col("T_2")
        magnitude_1 = (pl.col("A")**2 + pl.col("C")**2 + pl.col("G")**2 + pl.col("T")**2)**0.5
        magnitude_2 = (pl.col("A_2")**2 + pl.col("C_2")**2 + pl.col("G_2")**2 + pl.col("T_2")**2)**0.5
        cos_sim = dot_product / (magnitude_1 * magnitude_2)
        return pl.when(cos_sim >= threshold).then(1).otherwise(0)

    def __getattribute__(self, name):
        if name.startswith("cosani_"):
            try:
                threshold = float(name.split("_")[1])
            except ValueError:
                raise AttributeError(f"Invalid threshold in method name: {name}")
            return lambda: self.generalized_cos_ani(threshold)
        else:
            return super().__getattribute__(name)
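
To make the dynamic cosani_<threshold> lookup concrete, here is a small self-contained sketch applying the expressions to a toy joined frame with the A..T and A_2..T_2 column naming described above. Names such as cosani_0.9 contain a dot, so they are reached with getattr, which is how get_shared_locs resolves ani_method.

import polars as pl
from zipstrain.compare import PolarsANIExpressions

ani = PolarsANIExpressions()
pop_expr = ani.popani()                      # ordinary method call
cos_expr = getattr(ani, "cosani_0.9")()      # resolved dynamically to generalized_cos_ani(0.9)

merged = pl.DataFrame({
    "A": [10, 0], "C": [0, 5], "G": [0, 5], "T": [0, 0],
    "A_2": [8, 0], "C_2": [0, 0], "G_2": [0, 9], "T_2": [0, 0],
})
# surr > 0 means the position shares an allele under popANI; cos is 1 only where
# the cosine similarity of the two count vectors is at least 0.9.
print(merged.with_columns(surr=pop_expr, cos=cos_expr))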

add_contiguity_info(mpile_contig)

Adds group id information to the LazyFrame: a new group starts whenever the scaffold changes or a position is a SNV (surr == 0), so consecutive shared-allele positions on the same scaffold share a group id.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with group id information added.

Source code in zipstrain/src/zipstrain/compare.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def add_contiguity_info(mpile_contig:pl.LazyFrame) -> pl.LazyFrame:
    """ Adds group id information to the lazy frame. If on the same scaffold and not popANI, then they are in the same group.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.

    Returns:
        pl.LazyFrame: Updated LazyFrame with group id information added.
    """

    mpile_contig= mpile_contig.sort(["scaffold", "pos"])
    mpile_contig = mpile_contig.with_columns([
        (pl.col("scaffold").shift(1).fill_null(pl.col("scaffold")).alias("prev_scaffold")),
    ])
    mpile_contig = mpile_contig.with_columns([
        (((pl.col("scaffold") != pl.col("prev_scaffold")) | (pl.col("surr") == 0))).cum_sum().alias("group_id")
    ])
    return mpile_contig
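
A tiny illustrative frame (not real data) showing how group_id is assigned:

import polars as pl
from zipstrain.compare import add_contiguity_info

lf = pl.LazyFrame({
    "scaffold": ["s1", "s1", "s1", "s2", "s2"],
    "pos":      [1, 2, 3, 1, 2],
    "surr":     [5, 0, 4, 3, 2],    # surr == 0 marks a SNV position
})
print(add_contiguity_info(lf).collect())
# group_id: 0, 1, 1, 2, 2 — a new group starts at the SNV and at the scaffold change.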

add_genome_info(mpile_contig, scaffold_to_genome)

Adds genome information to the mpileup LazyFrame based on scaffold to genome mapping.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required
scaffold_to_genome LazyFrame

The LazyFrame mapping scaffolds to genomes.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with genome information added.

Source code in zipstrain/src/zipstrain/compare.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def add_genome_info(mpile_contig:pl.LazyFrame, scaffold_to_genome:pl.LazyFrame) -> pl.LazyFrame:
    """
    Adds genome information to the mpileup LazyFrame based on scaffold to genome mapping.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.
        scaffold_to_genome (pl.LazyFrame): The LazyFrame mapping scaffolds to genomes.

    Returns:
        pl.LazyFrame: Updated LazyFrame with genome information added.
    """
    return mpile_contig.join(
        scaffold_to_genome, on="scaffold", how="left"
    ).fill_null("NA")

adjust_for_sequence_errors(mpile_frame, null_model)

Adjust the mpile frame for sequence errors based on the null model.

Parameters:

Name Type Description Default
mpile_frame LazyFrame

The input LazyFrame containing coverage data.

required
null_model LazyFrame

The null model LazyFrame containing error counts.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Adjusted LazyFrame with sequence errors accounted for.

Source code in zipstrain/src/zipstrain/compare.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def adjust_for_sequence_errors(mpile_frame:pl.LazyFrame, null_model:pl.LazyFrame) -> pl.LazyFrame:
    """
    Adjust the mpile frame for sequence errors based on the null model.

    Args:
        mpile_frame (pl.LazyFrame): The input LazyFrame containing coverage data.
        null_model (pl.LazyFrame): The null model LazyFrame containing error counts.

    Returns:
        pl.LazyFrame: Adjusted LazyFrame with sequence errors accounted for.
    """
    return mpile_frame.join(null_model, on="cov", how="left").with_columns([
        pl.when(pl.col(base) >= pl.col("max_error_count"))
        .then(pl.col(base))
        .otherwise(0)
        .alias(base)
        for base in ["A", "T", "C", "G"]
    ]).drop("max_error_count")

calculate_pop_ani(mpile_contig)

Calculates the population ANI (Average Nucleotide Identity) for the given mpileup LazyFrame. NOTE: Remember that this function should be applied to the merged mpileup using get_shared_locs.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with population ANI information added.

Source code in zipstrain/src/zipstrain/compare.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def calculate_pop_ani(mpile_contig:pl.LazyFrame) -> pl.LazyFrame:
    """
    Calculates the population ANI (Average Nucleotide Identity) for the given mpileup LazyFrame.
    NOTE: Remember that this function should be applied to the merged mpileup using get_shared_locs.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.

    Returns:
        pl.LazyFrame: Updated LazyFrame with population ANI information added.
    """
    return mpile_contig.group_by("genome").agg(
            total_positions=pl.len(),
            share_allele_pos=(pl.col("surr") > 0 ).sum()
        ).with_columns(
            genome_pop_ani=pl.col("share_allele_pos")/pl.col("total_positions")*100,
        )

compare_genes(mpile_contig_1, mpile_contig_2, null_model, scaffold_to_genome, min_cov=5, min_gene_compare_len=100, engine='streaming', ani_method='popani')

Compares two profiles and generates gene-level comparison statistics. The final output is a Polars LazyFrame with gene comparison statistics in the following columns:

  • genome: The genome identifier.

  • gene: The gene identifier.

  • total_positions: Total number of positions compared in the gene.

  • share_allele_pos: Number of positions with shared alleles in the gene.

  • ani: Average Nucleotide Identity (ANI) percentage for the gene.

Parameters:

Name Type Description Default
mpile_contig_1 LazyFrame

The first profile as a LazyFrame.

required
mpile_contig_2 LazyFrame

The second profile as a LazyFrame.

required
null_model LazyFrame

The null model LazyFrame that contains the thresholds for sequence error adjustment.

required
scaffold_to_genome LazyFrame

A mapping LazyFrame from scaffolds to genomes.

required
min_cov int

Minimum coverage threshold for filtering positions. Default is 5.

5
min_gene_compare_len int

Minimum length of genes that needs to be covered to consider for comparison. Default is 100.

100
engine str

The Polars engine to use for computation. Default is "streaming".

'streaming'
ani_method str

The ANI calculation method to use. Default is "popani".

'popani'

Returns:

Type Description
LazyFrame

pl.LazyFrame: A LazyFrame containing gene-level comparison statistics.

Source code in zipstrain/src/zipstrain/compare.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
def compare_genes(mpile_contig_1:pl.LazyFrame,
              mpile_contig_2:pl.LazyFrame,
              null_model:pl.LazyFrame,
              scaffold_to_genome:pl.LazyFrame,
              min_cov:int=5,
              min_gene_compare_len:int=100,
              engine="streaming",
              ani_method:str="popani"
            )-> pl.LazyFrame:
    """
    Compares two profiles and generates gene-level comparison statistics.
    The final output is a Polars LazyFrame with gene comparison statistics in the following columns:
    - genome: The genome identifier.
    - gene: The gene identifier.
    - total_positions: Total number of positions compared in the gene.
    - share_allele_pos: Number of positions with shared alleles in the gene.
    - ani: Average Nucleotide Identity (ANI) percentage for the gene.

    Args:
        mpile_contig_1 (pl.LazyFrame): The first profile as a LazyFrame.
        mpile_contig_2 (pl.LazyFrame): The second profile as a LazyFrame.
        null_model (pl.LazyFrame): The null model LazyFrame that contains the thresholds for sequence error adjustment.
        scaffold_to_genome (pl.LazyFrame): A mapping LazyFrame from scaffolds to genomes.
        min_cov (int): Minimum coverage threshold for filtering positions. Default is 5.
        min_gene_compare_len (int): Minimum length of genes that needs to be covered to consider for comparison. Default is 100.
        engine (str): The Polars engine to use for computation. Default is "streaming".
        ani_method (str): The ANI calculation method to use. Default is "popani".

    Returns:
        pl.LazyFrame: A LazyFrame containing gene-level comparison statistics.
    """
    lf1=coverage_filter(mpile_contig_1, min_cov,engine=engine)
    lf1=adjust_for_sequence_errors(lf1, null_model)
    lf2=coverage_filter(mpile_contig_2, min_cov,engine=engine)
    lf2=adjust_for_sequence_errors(lf2, null_model)
    ### Now we need to only keep (scaffold, pos) that are in both lf1 and lf2
    lf = get_shared_locs(lf1, lf2, ani_method=ani_method)
    ## Let's add genome information for all scaffolds and positions
    lf = add_genome_info(lf, scaffold_to_genome)
    ## Let's calculate gene ani for each gene in each genome
    gene_comp = lf.group_by(["genome", "gene"]).agg(
        total_positions=pl.len(),
        share_allele_pos=(pl.col("surr") > 0).sum()
    ).filter(pl.col("total_positions") >= min_gene_compare_len).with_columns(
        ani=pl.col("share_allele_pos") / pl.col("total_positions") * 100,
    )
    return gene_comp

compare_genomes(mpile_contig_1, mpile_contig_2, null_model, scaffold_to_genome, min_cov=5, min_gene_compare_len=100, memory_mode='heavy', chrom_batch_size=10000, shared_scaffolds=None, scaffold_scope=None, engine='streaming', ani_method='popani')

Compares two profiles and generates genome-level comparison statistics. The final output is a Polars LazyFrame with genome comparison statistics in the following columns:

  • genome: The genome identifier.

  • total_positions: Total number of positions compared.

  • share_allele_pos: Number of positions with shared alleles.

  • genome_pop_ani: Population ANI percentage.

  • max_consecutive_length: Length of the longest consecutive block of shared alleles.

  • shared_genes_count: Number of genes compared.

  • identical_gene_count: Number of identical genes.

  • perc_id_genes: Percentage of identical genes.

Parameters:

Name Type Description Default
mpile_contig_1 LazyFrame

The first profile as a LazyFrame.

required
mpile_contig_2 LazyFrame

The second profile as a LazyFrame.

required
null_model LazyFrame

The null model LazyFrame that contains the thresholds for sequence error adjustment.

required
scaffold_to_genome LazyFrame

A mapping LazyFrame from scaffolds to genomes.

required
min_cov int

Minimum coverage threshold for filtering positions. Default is 5.

5
min_gene_compare_len int

Minimum length of genes that needs to be covered to consider for comparison. Default is 100.

100
memory_mode str

Memory mode for processing. Options are "heavy" or "light". Default is "heavy".

'heavy'
chrom_batch_size int

Batch size for processing scaffolds in light memory mode.

10000
shared_scaffolds list

List of shared scaffolds between the two profiles. Required for light memory mode.

None
scaffold_scope list

List of scaffolds to limit the comparison to. Default is None.

None
engine str

The Polars engine to use for computation. Default is "streaming".

'streaming'
ani_method str

The ANI calculation method to use. Default is "popani".

'popani'

Returns:

Type Description
LazyFrame

pl.LazyFrame: A LazyFrame containing genome-level comparison statistics.

Source code in zipstrain/src/zipstrain/compare.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def compare_genomes(mpile_contig_1:pl.LazyFrame,
              mpile_contig_2:pl.LazyFrame,
              null_model:pl.LazyFrame,
              scaffold_to_genome:pl.LazyFrame,
              min_cov:int=5,
              min_gene_compare_len:int=100,
              memory_mode:str="heavy",
              chrom_batch_size:int=10000,
              shared_scaffolds:list=None,
              scaffold_scope:list=None,
              engine="streaming",
              ani_method:str="popani"
            )-> pl.LazyFrame:
    """
    Compares two profiles and generates genome-level comparison statistics.
    The final output is a Polars LazyFrame with genome comparison statistics in the following columns:

    - genome: The genome identifier.

    - total_positions: Total number of positions compared.

    - share_allele_pos: Number of positions with shared alleles.

    - genome_pop_ani: Population ANI percentage.

    - max_consecutive_length: Length of the longest consecutive block of shared alleles.

    - shared_genes_count: Number of genes compared.

    - identical_gene_count: Number of identical genes.

    - perc_id_genes: Percentage of identical genes.

    Args:
        mpile_contig_1 (pl.LazyFrame): The first profile as a LazyFrame.
        mpile_contig_2 (pl.LazyFrame): The second profile as a LazyFrame.
        null_model (pl.LazyFrame): The null model LazyFrame that contains the thresholds for sequence error adjustment.
        scaffold_to_genome (pl.LazyFrame): A mapping LazyFrame from scaffolds to genomes.
        min_cov (int): Minimum coverage threshold for filtering positions. Default is 5.
        min_gene_compare_len (int): Minimum length of genes that needs to be covered to consider for comparison. Default is 100.
        memory_mode (str): Memory mode for processing. Options are "heavy" or "light". Default is "heavy".
        chrom_batch_size (int): Batch size for processing scaffolds in light memory mode. Default is 10000.
        shared_scaffolds (list): List of shared scaffolds between the two profiles. Required for light memory mode.
        scaffold_scope (list): List of scaffolds to limit the comparison to. Default is None.
        engine (str): The Polars engine to use for computation. Default is "streaming".
        ani_method (str): The ANI calculation method to use. Default is "popani".

    Returns:
        pl.LazyFrame: A LazyFrame containing genome-level comparison statistics.
    """
    if memory_mode == "heavy":
        if scaffold_scope is not None:
            mpile_contig_1 = mpile_contig_1.filter(pl.col("chrom").is_in(scaffold_scope)).collect(engine=engine).lazy()
            mpile_contig_2 = mpile_contig_2.filter(pl.col("chrom").is_in(scaffold_scope)).collect(engine=engine).lazy()
        lf1=coverage_filter(mpile_contig_1, min_cov,engine=engine)
        lf1=adjust_for_sequence_errors(lf1, null_model)
        lf2=coverage_filter(mpile_contig_2, min_cov,engine=engine)
        lf2=adjust_for_sequence_errors(lf2, null_model)
        ### Now we need to only keep (scaffold, pos) that are in both lf1 and lf2
        lf = get_shared_locs(lf1, lf2, ani_method=ani_method)
        ## Add Contiguity Information
        lf = add_contiguity_info(lf)
        ## Let's add genome information for all scaffolds and positions
        lf = add_genome_info(lf, scaffold_to_genome)
        ## Let's calculate popANI
        genome_comp= calculate_pop_ani(lf)
        ## Calculate longest consecutive blocks
        max_consecutive_per_genome = get_longest_consecutive_blocks(lf)
        ## Calculate gene ani for each gene in each genome
        gene= get_gene_ani(lf, min_gene_compare_len)
        genome_comp=genome_comp.join(max_consecutive_per_genome, on="genome", how="left")
        genome_comp=genome_comp.join(gene, on="genome", how="left")

    elif memory_mode == "light":
        shared_scaffolds_batches = [shared_scaffolds[i:i + chrom_batch_size] for i in range(0, len(shared_scaffolds), chrom_batch_size)]
        lf_list=[]
        for scaffold in shared_scaffolds_batches:
            lf1= coverage_filter(mpile_contig_1.filter(pl.col("chrom").is_in(scaffold)), min_cov)
            lf1=adjust_for_sequence_errors(lf1, null_model)
            lf2= coverage_filter(mpile_contig_2.filter(pl.col("chrom").is_in(scaffold)), min_cov)
            lf2=adjust_for_sequence_errors(lf2, null_model)
            ### Now we need to only keep (scaffold, pos) that are in both lf1 and lf2
            lf = get_shared_locs(lf1, lf2, ani_method=ani_method)
            ## Lets add contiguity information
            lf= add_contiguity_info(lf)
            lf_list.append(lf)
        lf= pl.concat(lf_list)
        lf= add_genome_info(lf, scaffold_to_genome)
        genome_comp= calculate_pop_ani(lf)
        max_consecutive_per_genome = get_longest_consecutive_blocks(lf)
        gene= get_gene_ani(lf, min_gene_compare_len)
        genome_comp=genome_comp.join(max_consecutive_per_genome, on="genome", how="left")
        genome_comp=genome_comp.join(gene, on="genome", how="left")
    else:
        raise ValueError("Invalid memory_mode. Choose either 'heavy' or 'light'.")
    return genome_comp
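
A minimal end-to-end sketch for a pairwise genome-level comparison. The parquet, null-model, and stb paths are placeholders; the null model is expected to provide 'cov' and 'max_error_count' columns (see adjust_for_sequence_errors), and the stb table maps scaffolds to genomes.

import polars as pl
from zipstrain.compare import compare_genomes

mpile_1 = pl.scan_parquet("profiles/sampleA.parquet")
mpile_2 = pl.scan_parquet("profiles/sampleB.parquet")
null_model = pl.scan_parquet("null_model.parquet")
stb = pl.scan_csv("genomes.stb", separator="\t", has_header=False,
                  new_columns=["scaffold", "genome"])

result = compare_genomes(mpile_1, mpile_2, null_model, stb,
                         min_cov=5, ani_method="popani")
result.collect(engine="streaming").write_parquet("sampleA_vs_sampleB.parquet")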

coverage_filter(mpile_frame, min_cov, engine)

Filter the mpile lazyframe based on minimum coverage at each loci.

Parameters:

Name Type Description Default
mpile_frame LazyFrame

The input LazyFrame containing coverage data.

required
min_cov int

The minimum coverage threshold.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Filtered LazyFrame with positions having coverage >= min_cov.

Source code in zipstrain/src/zipstrain/compare.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def coverage_filter(mpile_frame:pl.LazyFrame, min_cov:int,engine:str)-> pl.LazyFrame:
    """
    Filter the mpile lazyframe based on minimum coverage at each loci.

    Args:
        mpile_frame (pl.LazyFrame): The input LazyFrame containing coverage data.
        min_cov (int): The minimum coverage threshold.

    Returns:
        pl.LazyFrame: Filtered LazyFrame with positions having coverage >= min_cov.
    """
    mpile_frame = mpile_frame.with_columns(
        (pl.col("A") + pl.col("C") + pl.col("G") + pl.col("T")).alias("cov")
    )
    return mpile_frame.filter(pl.col("cov") >= min_cov).collect(engine=engine).lazy()

get_gene_ani(mpile_contig, min_gene_compare_len)

Calculates gene ANI (Average Nucleotide Identity) for each gene in each genome.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required
min_gene_compare_len int

Minimum length of the gene to consider for comparison.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with gene ANI information added.

Source code in zipstrain/src/zipstrain/compare.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def get_gene_ani(mpile_contig:pl.LazyFrame, min_gene_compare_len:int) -> pl.LazyFrame:
    """
    Calculates gene ANI (Average Nucleotide Identity) for each gene in each genome.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.
        min_gene_compare_len (int): Minimum length of the gene to consider for comparison.

    Returns:
        pl.LazyFrame: Updated LazyFrame with gene ANI information added.
    """
    return mpile_contig.group_by(["genome", "gene"]).agg(
        total_positions=pl.len(),
        share_allele_pos=(pl.col("surr") > 0).sum()
    ).filter(pl.col("total_positions") >= min_gene_compare_len).with_columns(
        identical=(pl.col("share_allele_pos") == pl.col("total_positions")),
    ).filter(pl.col("gene") != "NA").group_by("genome").agg(
        shared_genes_count=pl.len(),
        identical_gene_count=pl.col("identical").sum()
    ).with_columns(perc_id_genes=pl.col("identical_gene_count") / pl.col("shared_genes_count") * 100)

get_longest_consecutive_blocks(mpile_contig)

Calculates the longest consecutive block of shared alleles for each genome in the mpileup LazyFrame.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required

Returns:

Type Description
LazyFrame

pl.LazyFrame: Updated LazyFrame with longest consecutive blocks information added.

Source code in zipstrain/src/zipstrain/compare.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def get_longest_consecutive_blocks(mpile_contig:pl.LazyFrame) -> pl.LazyFrame:
    """
    Calculates the longest consecutive block of shared alleles for each genome in the mpileup LazyFrame.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.

    Returns:
        pl.LazyFrame: Updated LazyFrame with longest consecutive blocks information added.
    """
    block_lengths = (
        mpile_contig.group_by(["genome", "scaffold", "group_id"])
        .agg(pl.len().alias("length"))
    ) 
    return block_lengths.group_by("genome").agg(pl.col("length").max().alias("max_consecutive_length"))

get_shared_locs(mpile_contig_1, mpile_contig_2, ani_method='popani')

Returns a lazyframe with ATCG information for shared scaffolds and positions between two mpileup files.

Parameters:

Name Type Description Default
mpile_contig_1 LazyFrame

The first mpileup LazyFrame.

required
mpile_contig_2 LazyFrame

The second mpileup LazyFrame.

required
ani_method str

The ANI calculation method to use. Default is "popani".

'popani'

Returns:

Type Description
LazyFrame

pl.LazyFrame: Merged LazyFrame containing shared scaffolds and positions with ATCG information.

Source code in zipstrain/src/zipstrain/compare.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def get_shared_locs(mpile_contig_1:pl.LazyFrame, mpile_contig_2:pl.LazyFrame,ani_method:str="popani") -> pl.LazyFrame:
    """
    Returns a lazyframe with ATCG information for shared scaffolds and positions between two mpileup files.

    Args:
        mpile_contig_1 (pl.LazyFrame): The first mpileup LazyFrame.
        mpile_contig_2 (pl.LazyFrame): The second mpileup LazyFrame.
        ani_method (str): The ANI calculation method to use. Default is "popani".

    Returns:
        pl.LazyFrame: Merged LazyFrame containing shared scaffolds and positions with ATCG information.
    """
    ani_expr=getattr(PolarsANIExpressions(), ani_method)()

    mpile_contig= mpile_contig_1.join(
        mpile_contig_2,
        on=["chrom", "pos"],
        how="inner",
        suffix="_2"  # To distinguish lf2 columns
    ).with_columns(
        ani_expr.alias("surr")
    ).select(
        pl.col("surr"),
        scaffold=pl.col("chrom"),
        pos=pl.col("pos"),
        gene=pl.col("gene")
    )
    return mpile_contig

get_unique_scaffolds(mpile_contig, batch_size=10000)

Retrieves unique scaffolds from the mpileup LazyFrame.

Parameters:

Name Type Description Default
mpile_contig LazyFrame

The input LazyFrame containing mpileup data.

required
batch_size int

The number of rows to process in each batch. Default is 10000.

10000

Returns: set: A set of unique scaffold names.

Source code in zipstrain/src/zipstrain/compare.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def get_unique_scaffolds(mpile_contig:pl.LazyFrame,batch_size:int=10000) -> set:
    """
    Retrieves unique scaffolds from the mpileup LazyFrame.

    Args:
        mpile_contig (pl.LazyFrame): The input LazyFrame containing mpileup data.
        batch_size (int): The number of rows to process in each batch. Default is 10000.
    Returns:
        set: A set of unique scaffold names.
    """
    scaffolds = set()
    start_index = 0
    while True:
        batch = mpile_contig.slice(start_index, batch_size).select("chrom").collect()
        if batch.height == 0:
            break
        scaffolds.update(batch["chrom"].to_list())
        start_index += batch_size
    return scaffolds 

Utils

zipstrain.utils

This module provides utility functions for profiling and comparison operations.

build_null_poisson(error_rate=0.001, max_total_reads=10000, p_threshold=0.05)

Build a null model to correct for sequencing errors based on the Poisson distribution.

Parameters: error_rate (float): Error rate for the sequencing technology. max_total_reads (int): Maximum total reads to consider. p_threshold (float): Significance threshold for the Poisson distribution.

Returns: list[tuple[int, int]]: A list of (total reads, maximum error count) pairs, one per coverage value.

Source code in zipstrain/src/zipstrain/utils.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def build_null_poisson(error_rate:float=0.001,
                       max_total_reads:int=10000,
                       p_threshold:float=0.05)->list[float]:
    """
    Build a null model to correct for sequencing errors based on the Poisson distribution.

    Parameters:
    error_rate (float): Error rate for the sequencing technology.
    max_total_reads (int): Maximum total reads to consider.
    p_threshold (float): Significance threshold for the Poisson distribution.

    Returns:
    pl.DataFrame: DataFrame containing total reads and maximum error count thresholds.
    """ 
    records = []
    for n in range(1, max_total_reads + 1):
        lam = n * (error_rate / 3)
        k = 0
        while poisson.sf(k - 1, lam) > p_threshold:
            k += 1
        records.append((n, k - 1))
    return records
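
A short sketch turning the returned records into the parquet null model consumed by the compare functions. The column names 'cov' and 'max_error_count' follow what adjust_for_sequence_errors joins on; the output path is a placeholder.

import polars as pl
from zipstrain.utils import build_null_poisson

records = build_null_poisson(error_rate=0.001, max_total_reads=10000, p_threshold=0.05)
null_model = pl.DataFrame(records, schema=["cov", "max_error_count"], orient="row")
null_model.write_parquet("null_model.parquet")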

clean_bases(bases, indel_re)

Remove read start/end markers and indels from bases string using regex. Returns cleaned uppercase string of bases only. Args: bases (str): The bases string from mpileup. indel_re (re.Pattern): Compiled regex pattern to match indels and markers.

Source code in zipstrain/src/zipstrain/utils.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def clean_bases(bases: str, indel_re: re.Pattern) -> str:
    """
    Remove read start/end markers and indels from bases string using regex.
    Returns cleaned uppercase string of bases only.
    Args:
        bases (str): The bases string from mpileup.
        indel_re (re.Pattern): Compiled regex pattern to match indels and markers.

    """
    cleaned = []
    i = 0
    while i < len(bases):
        m = indel_re.match(bases, i)
        if m:
            if m.group(0).startswith('+') or m.group(0).startswith('-'):
                # indel length
                indel_len = int(m.group(1))
                i = m.end() + indel_len
            else:
                i = m.end()
        else:
            cleaned.append(bases[i].upper())
            i += 1
    return ''.join(cleaned)

collect_breadth_tables(breadth_tables)

Collect multiple genome breadth tables into a single LazyFrame.

Parameters: breadth_tables (list[pl.LazyFrame]): List of LazyFrames containing genome breadth data.

Returns: pl.LazyFrame: A LazyFrame containing the combined genome breadth data.

Source code in zipstrain/src/zipstrain/utils.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
def collect_breadth_tables(
    breadth_tables: list[pl.LazyFrame],
) -> pl.LazyFrame:
    """
    Collect multiple genome breadth tables into a single LazyFrame.

    Parameters:
    breadth_tables (list[pl.LazyFrame]): List of LazyFrames containing genome breadth data.

    Returns:
    pl.LazyFrame: A LazyFrame containing the combined genome breadth data.
    """
    if not breadth_tables:
        raise ValueError("No breadth tables provided.")

    return reduce(lambda x, y: x.join(y, on="genome", how="outer", coalesce=True), breadth_tables)

count_bases(bases)

Count occurrences of A, C, G, T in the cleaned bases string. Args: bases (str): Cleaned bases string. Returns: dict: Dictionary with counts of A, C, G, T.

Source code in zipstrain/src/zipstrain/utils.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def count_bases(bases: str):
    """
    Count occurrences of A, C, G, T in the cleaned bases string.
    Args:
        bases (str): Cleaned bases string.
    Returns:
        dict: Dictionary with counts of A, C, G, T.
    """
    counts = Counter(bases)
    return {
        'A': counts.get('A', 0),
        'C': counts.get('C', 0),
        'G': counts.get('G', 0),
        'T': counts.get('T', 0),
    }
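
A small self-contained sketch of the two helpers together; the regex mirrors the one used inside process_mpileup_function (read-start markers, read-end markers, and indel prefixes), and the base string is a toy example.

import re
from zipstrain.utils import clean_bases, count_bases

indel_re = re.compile(r'\^.|[\$]|[+-](\d+)')

bases = "^Faa$TT+2ATg"                    # toy mpileup base string
cleaned = clean_bases(bases, indel_re)    # -> "AATTG"
print(cleaned, count_bases(cleaned))      # {'A': 2, 'C': 0, 'G': 1, 'T': 2}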

extract_genome_length(stb, bed_table)

Extract the genome length information from the scaffold-to-genome mapping table.

Parameters: stb (pl.LazyFrame): Scaffold-to-bin mapping table. bed_table (pl.LazyFrame): BED table containing genomic regions.

Returns: pl.LazyFrame: A LazyFrame containing the genome lengths.

Source code in zipstrain/src/zipstrain/utils.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def extract_genome_length(stb: pl.LazyFrame, bed_table: pl.LazyFrame) -> pl.LazyFrame:
    """
    Extract the genome length information from the scaffold-to-genome mapping table.

    Parameters:
    stb (pl.LazyFrame): Scaffold-to-bin mapping table.
    bed_table (pl.LazyFrame): BED table containing genomic regions.

    Returns:
    pl.LazyFrame: A LazyFrame containing the genome lengths.
    """
    lf= bed_table.select(
        pl.col("scaffold"),
        (pl.col("end") - pl.col("start")).alias("scaffold_length")
    ).group_by("scaffold").agg(
        scaffold_length=pl.sum("scaffold_length")
    ).select(
        pl.col("scaffold").alias("scaffold"),
        pl.col("scaffold_length")
    ).join(
        stb.select(
            pl.col("scaffold").alias("scaffold"),
            pl.col("genome").alias("genome")
        ),
        on="scaffold",
        how="left"
    ).group_by("genome").agg(
        genome_length=pl.sum("scaffold_length")
    ).select(
        pl.col("genome"),
        pl.col("genome_length")
    )
    return lf

get_genome_breadth_matrix(profile, name, genome_length, stb, min_cov=1)

Get the genome breadth for a single profile from the provided genome lengths and scaffold-to-genome mapping. Parameters: profile (pl.LazyFrame): The profile LazyFrame. name (str): Name used for this profile's breadth column in the output. genome_length (pl.LazyFrame): Genome length table (see extract_genome_length). stb (pl.LazyFrame): Scaffold-to-genome mapping table. min_cov (int): Minimum coverage to consider a position. Returns: pl.LazyFrame: A LazyFrame with a 'genome' column and one breadth column named after the profile.

Source code in zipstrain/src/zipstrain/utils.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def get_genome_breadth_matrix(
                              profile:pl.LazyFrame,
                              name:str,
                              genome_length: pl.LazyFrame,
                              stb: pl.LazyFrame,
                              min_cov: int = 1)-> pl.LazyFrame:
    """
    Get the genome breadth matrix from the provided profiles and scaffold-to-genome mapping.
    Parameters:
    profiles (list): List of tuples containing profile names and their corresponding LazyFrames.
    stb (pl.LazyFrame): Scaffold-to-genome mapping table.
    min_cov (int): Minimum coverage to consider a position. 
    Returns:
    pl.LazyFrame: A LazyFrame containing the genome breadth matrix.
    """
    profile = profile.filter((pl.col("A") + pl.col("C") + pl.col("G") + pl.col("T")) >= min_cov)
    profile=profile.group_by("chrom").agg(
        breadth=pl.count()
    ).select(
        pl.col("chrom").alias("scaffold"),
        pl.col("breadth")
    ).join(
        stb,
        on="scaffold",
        how="left"
    )
    profile=profile.join(genome_length, on="genome", how="left")

    profile=profile.group_by("genome").agg(
        genome_length=pl.first("genome_length"),
        breadth=pl.col("breadth").sum())
    profile = profile.with_columns(
        (pl.col("breadth")/ pl.col("genome_length")).alias("breadth")
    )
    return profile.select(
            pl.col("genome"),
            pl.col("breadth").alias(name)
        )
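
A sketch of the intended flow across extract_genome_length, get_genome_breadth_matrix, and collect_breadth_tables; all file paths and sample names are placeholders.

import polars as pl
from zipstrain.utils import (collect_breadth_tables, extract_genome_length,
                             get_genome_breadth_matrix)

bed = pl.scan_csv("regions.bed", separator="\t", has_header=False,
                  new_columns=["scaffold", "start", "end"])
stb = pl.scan_csv("genomes.stb", separator="\t", has_header=False,
                  new_columns=["scaffold", "genome"])
genome_len = extract_genome_length(stb, bed)

breadths = [
    get_genome_breadth_matrix(pl.scan_parquet(f"profiles/{name}.parquet"),
                              name, genome_len, stb, min_cov=1)
    for name in ("sampleA", "sampleB")
]
matrix = collect_breadth_tables(breadths)   # 'genome' column plus one column per sample
print(matrix.collect())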

make_the_bed(db_fasta_dir, max_scaffold_length=500000)

Create a BED file from the database in fasta format.

Parameters: db_fasta_dir (Union[str, pathlib.Path]): Path to the fasta file. max_scaffold_length (int): Splits scaffolds longer than this into multiple entries of length <= max_scaffold_length.

Returns: pl.LazyFrame: A LazyFrame containing the BED data.

Source code in zipstrain/src/zipstrain/utils.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def make_the_bed(db_fasta_dir: str | pathlib.Path, max_scaffold_length: int = 500_000) -> pl.DataFrame:
    """
    Create a BED file from the database in fasta format.

    Parameters:
    db_fasta_dir (Union[str, pathlib.Path]): Path to the fasta file.
    max_scaffold_length (int): Splits scaffolds longer than this into multiple entries of length <= max_scaffold_length.

    Returns:
    pl.LazyFrame: A LazyFrame containing the BED data.
    """
    db_fasta_dir = pathlib.Path(db_fasta_dir)
    if not db_fasta_dir.is_file():
        raise FileNotFoundError(f"{db_fasta_dir} is not a valid fasta file.")

    records = []
    with db_fasta_dir.open() as f:
        scaffold = None
        seq_chunks = []

        for line in f:
            line = line.strip()
            if line.startswith(">"):
                # Process the previous scaffold
                if scaffold is not None:
                    seq = ''.join(seq_chunks)
                    for start in range(0, len(seq), max_scaffold_length):
                        end = min(start + max_scaffold_length, len(seq))
                        records.append((scaffold, start, end))
                # Start new scaffold
                scaffold = line[1:].split()[0]  # ID only (up to first whitespace)
                seq_chunks = []
            else:
                seq_chunks.append(line)

        # Don't forget the last scaffold
        if scaffold is not None:
            seq = ''.join(seq_chunks)
            for start in range(0, len(seq), max_scaffold_length):
                end = min(start + max_scaffold_length, len(seq))
                records.append((scaffold, start, end))

    return pl.DataFrame(records, schema=["scaffold", "start", "end"], orient="row")
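
A brief sketch; the FASTA path is a placeholder, and long scaffolds are split into windows of at most max_scaffold_length.

from zipstrain.utils import make_the_bed

bed = make_the_bed("reference_db.fasta", max_scaffold_length=500_000)
bed.write_csv("regions.bed", separator="\t", include_header=False)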

process_mpileup_function(gene_range_table_loc, batch_bed, batch_size, output_file)

Process mpileup files and save the results in a Parquet file.

Parameters: gene_range_table_loc (str): Path to the gene range table in TSV format. batch_bed (str): Path to the batch BED file. batch_size (int): Buffer size for processing stdin from samtools. output_file (str): Path to save the output Parquet file.

Source code in zipstrain/src/zipstrain/utils.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def process_mpileup_function(gene_range_table_loc, batch_bed, batch_size, output_file):
    """
    Process mpileup files and save the results in a Parquet file.

    Parameters:
    gene_range_table_loc (str): Path to the gene range table in TSV format.
    batch_bed (str): Path to the batch BED file.
    batch_size (int): Buffer size for processing stdin from samtools.
    output_file (str): Path to save the output Parquet file.
    """
    indel_re = re.compile(r'\^.|[\$]|[+-](\d+)')
    gene_ranges_pl = pl.scan_csv(gene_range_table_loc,separator='\t', has_header=False).rename({
        "column_1": "scaffold",
        "column_2": "start",
        "column_3": "end",
        "column_4": "gene"
    })
    scaffolds = pl.read_csv(batch_bed, separator='\t', has_header=False)["column_1"].unique().to_list()
    gene_ranges_pl = gene_ranges_pl.filter(pl.col("scaffold").is_in(scaffolds)).collect()
    gene_ranges = defaultdict(IntervalTree)
    for row in gene_ranges_pl.iter_rows(named=True):
        gene_ranges[row["scaffold"]].addi(row["start"], row["end"] + 1, row["gene"])

    schema = pa.schema([
        ('chrom', pa.string()),
        ('pos', pa.int32()),
        ('gene', pa.string()),
        ('A', pa.uint16()),
        ('C', pa.uint16()),
        ('G', pa.uint16()),
        ('T', pa.uint16()),
    ])

    chroms = []
    positions = []
    genes = []
    As = []
    Cs = []
    Gs = []
    Ts = []

    writer = None
    def flush_batch():
        nonlocal writer
        if not chroms:
            return
        batch = pa.RecordBatch.from_arrays([
            pa.array(chroms, type=pa.string()),
            pa.array(positions, type=pa.int32()),
            pa.array(genes, type=pa.string()),
            pa.array(As, type=pa.uint16()),
            pa.array(Cs, type=pa.uint16()),
            pa.array(Gs, type=pa.uint16()),
            pa.array(Ts, type=pa.uint16()),
        ], schema=schema)

        if writer is None:
            # Open writer for the first time
            writer = pq.ParquetWriter(output_file, schema, compression='snappy')
        writer.write_table(pa.Table.from_batches([batch]))

        # Clear buffers
        chroms.clear()
        positions.clear()
        genes.clear()
        As.clear()
        Cs.clear()
        Gs.clear()
        Ts.clear()
    for line in sys.stdin:
        if not line.strip():
            continue
        fields = line.strip().split('\t')
        if len(fields) < 5:
            continue
        chrom, pos, _, _, bases = fields[:5]

        cleaned = clean_bases(bases, indel_re)
        counts = count_bases(cleaned)

        chroms.append(chrom)
        positions.append(int(pos))
        matches = gene_ranges[chrom][int(pos)]
        genes.append(next(iter(matches)).data if matches else "NA")
        As.append(counts['A'])
        Cs.append(counts['C'])
        Gs.append(counts['G'])
        Ts.append(counts['T'])

        if len(chroms) >= batch_size:
            flush_batch()

    # Flush remaining data
    flush_batch()

    if writer:
        writer.close()
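
A minimal usage sketch, assuming the function is imported from zipstrain.utils and that stdin is fed by a samtools mpileup stream restricted to the batch BED file; file names and the wrapper script are placeholders:

# e.g. samtools mpileup -l batch.bed -f reference.fa sample.bam | python count_bases.py
from zipstrain.utils import process_mpileup_function

process_mpileup_function(
    gene_range_table_loc="gene_ranges.tsv",  # TSV: scaffold, start, end, gene (no header)
    batch_bed="batch.bed",
    batch_size=100_000,                      # rows buffered before each Parquet flush
    output_file="sample_counts.parquet",
)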

split_lf_to_chunks(lf, num_chunks)

Split a Polars LazyFrame into smaller chunks.

Parameters:
    lf (pl.LazyFrame): The input LazyFrame to be split.
    num_chunks (int): The number of chunks to split the LazyFrame into.

Returns: list[pl.LazyFrame]: A list of smaller LazyFrames.

Source code in zipstrain/src/zipstrain/utils.py
def split_lf_to_chunks(lf:pl.LazyFrame,num_chunks:int)->list[pl.LazyFrame]:
    """
    Split a Polars LazyFrame into smaller chunks.

    Parameters:
    lf (pl.LazyFrame): The input LazyFrame to be split.
    num_chunks (int): The number of chunks to split the LazyFrame into.

    Returns:
    list[pl.LazyFrame]: A list of smaller LazyFrames.
    """
    total_rows = lf.select(pl.count()).collect().item()
    chunk_size = total_rows // num_chunks
    chunks = []
    for i in range(num_chunks):
        start = i * chunk_size
        end = (i + 1) * chunk_size if i < num_chunks - 1 else total_rows
        chunk = lf.slice(start, end - start)
        chunks.append(chunk)
    return chunks
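
A small usage example; the last chunk absorbs any remainder rows:

import polars as pl

lf = pl.LazyFrame({"x": list(range(10))})
chunks = split_lf_to_chunks(lf, num_chunks=3)
print([chunk.collect().height for chunk in chunks])  # [3, 3, 4]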

Visualize

zipstrain.visualize

This module provides statistical analysis and visualization functions for profiling and compare operations.

calculate_ibs(sample_to_population, comps_lf, max_perc_id_genes=15, min_total_positions=10000)

Calculate the Identity By State (IBS) between two populations for a given genome. The IBS is defined as the percentage of genes that are identical between two populations for a given genome.

Args:
    sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
    comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples.
    max_perc_id_genes (float, optional): Maximum percentage of identical genes to consider. Defaults to 15.
    min_total_positions (int, optional): Minimum total positions required for a comparison. Defaults to 10000.

Returns:
    pl.DataFrame: DataFrame containing the IBS information for the given genome and populations.

Source code in zipstrain/src/zipstrain/visualize.py
def calculate_ibs(
    sample_to_population:pl.LazyFrame, 
    comps_lf:pl.LazyFrame,
    max_perc_id_genes:float=15,
    min_total_positions:int=10000,
)->pl.DataFrame:
    """
    Calculate the Identity By State (IBS) between two populations for a given genome.
    The IBS is defined as the percentage of genes that are identical between two populations for a given genome.
    Args:
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
        comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples.
        max_perc_id_genes (float, optional): Maximum percentage of identical genes to consider. Defaults to 15.
        min_total_positions (int, optional): Minimum total positions required for a comparison. Defaults to 10000.
    Returns:
        pl.DataFrame: DataFrame containing the IBS information for the given genome and populations.
    """
    comps_lf_filtered = comps_lf.filter(
        (pl.col('perc_id_genes') <= max_perc_id_genes) &
        (pl.col('total_positions')>min_total_positions)
    )
    comps_lf_filtered=comps_lf_filtered.join(
        sample_to_population,
        left_on='sample_1',
        right_on='sample',
        how='inner',
    ).rename(
        {"population":"population_1"}
    ).join(
        sample_to_population,
        left_on='sample_2',
        right_on='sample',
        how='inner',
        suffix='_2'
    ).rename(
        {"population":"population_2"}
    )
    comps_lf_filtered = comps_lf_filtered.with_columns(
    pl.when(pl.col("population_1") == pl.col("population_2"))
    .then(
        pl.lit("within_population_")
        + pl.col("population_1")
        + pl.lit("|")
        + pl.col("population_2")
    )
    .otherwise(
        pl.concat_str(
            [
                pl.lit("between_population_"),
                pl.concat_str(
                    [
                        pl.min_horizontal("population_1", "population_2"),
                        pl.lit("|"),
                        pl.max_horizontal("population_1", "population_2"),
                    ]
                ),
            ]
        )
    )
    .alias("comparison_type")
    ).fill_null(-1)

    return comps_lf_filtered.group_by(["genome","comparison_type"]).agg(
        pl.col("max_consecutive_length"),
    ).collect(engine="streaming").pivot(
        index="genome",
        columns="comparison_type",
        values="max_consecutive_length",
    ).with_columns(
        pl.col("*").exclude("genome").fill_null([-1])
    )
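
A usage sketch, assuming a merged comparison table from a zipstrain compare run and a two-column sample-to-population mapping (file paths are placeholders; the column names used are the ones read by the function above):

import polars as pl

comps_lf = pl.scan_parquet("all_comparisons.parquet")   # needs genome, sample_1, sample_2,
                                                         # perc_id_genes, total_positions,
                                                         # max_consecutive_length
sample_to_population = pl.scan_csv("populations.tsv", separator="\t")  # columns: sample, population
ibs_df = calculate_ibs(sample_to_population, comps_lf,
                       max_perc_id_genes=15, min_total_positions=10_000)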

calculate_identical_frac_vs_popani(genome, population_1, population_2, sample_to_population, comps_lf, min_shared_genes_count=100, min_total_positions=10000)

Calculate the fraction of identical genes vs popANI for a given genome and two samples in any possible combination of populations.

Args:
    genome (str): The genome to calculate the fraction of identical genes vs popANI for.
    population_1 (str): The first population to compare.
    population_2 (str): The second population to compare.
    sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
    comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples.
    min_shared_genes_count (int, optional): Minimum number of shared genes required for a comparison. Defaults to 100.
    min_total_positions (int, optional): Minimum total positions required for a comparison. Defaults to 10000.

Returns:
    pl.DataFrame: DataFrame containing the fraction of identical genes vs popANI information for the given genome and populations.

Source code in zipstrain/src/zipstrain/visualize.py
def calculate_identical_frac_vs_popani(
    genome:str,
    population_1:str,
    population_2:str,
    sample_to_population:pl.LazyFrame,
    comps_lf:pl.LazyFrame,
    min_shared_genes_count:int=100,
    min_total_positions:int=10000
    ):
    """
    Calculate the fraction of identical genes vs  popANI for a given genome and two samples in any possible combination of populations.
    Args:
        genome (str): The genome to calculate the fraction of identical genes vs popANI for.
        population_1 (str): The first population to compare.
        population_2 (str): The second population to compare.
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
        comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples
    Returns:
        pl.DataFrame: DataFrame containing the fraction of identical genes vs popANI information for the given genome and populations.
    """
    comps_lf_filtered=comps_lf.filter(
        (pl.col('genome') == genome) &
        (pl.col("shared_genes_count")>min_shared_genes_count) &
        (pl.col("total_positions")>min_total_positions)
    ).collect(engine="streaming").lazy()

    comps_lf_filtered=comps_lf_filtered.join(
        sample_to_population,
        left_on='sample_1',
        right_on='sample',
        how='left',
    ).rename(
        {"population":"population_1"}
    ).join(
        sample_to_population,
        left_on='sample_2',
        right_on='sample',
        how='left',
        suffix='_2'
    ).rename(
        {"population":"population_2"}
    )
    comps_lf_filtered = comps_lf_filtered.filter(
        (pl.col("population_1").is_in({population_1, population_2})) &
        (pl.col("population_2").is_in({population_1, population_2}))
    ).collect(engine="streaming").lazy()
    groups={
        "same_1":f"{population_1}-{population_1}",
        "same_2":f"{population_2}-{population_2}",
        "diff":f"{population_1}-{population_2}",
    }
    comps_lf_filtered=comps_lf_filtered.with_columns(
        pl.when((pl.col("population_1")==population_1) & (pl.col("population_2")==population_1))
        .then(pl.lit(groups["same_1"]))
        .when((pl.col("population_1")==population_2) & (pl.col("population_2")==population_2))
        .then(pl.lit(groups["same_2"]))
        .otherwise(pl.lit(groups["diff"]))
        .alias("relationship")
    )
    return comps_lf_filtered.group_by("relationship").agg(
        pl.col("perc_id_genes"),
        pl.col("genome_pop_ani")
    ).collect(engine="streaming")
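
A usage sketch reusing comps_lf and sample_to_population from the calculate_ibs example above (genome and population names are placeholders), paired with plot_identical_frac_vs_popani documented further below:

frac_df = calculate_identical_frac_vs_popani(
    genome="genome1",
    population_1="popA",
    population_2="popB",
    sample_to_population=sample_to_population,
    comps_lf=comps_lf,
)
fig = plot_identical_frac_vs_popani(frac_df, genome="genome1")
fig.show()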

calculate_strainsharing(comps_lf, breadth_lf, sample_to_population, min_breadth=0.5, strain_similarity_threshold=99.9, min_total_positions=10000)

Calculate strain sharing between populations based on popANI between genomes in their profiles. Strain sharing between two samples is defined as the ratio of genomes passing a strain similarity threshold over the total number of genomes in each sample. For two samples A and B (note the asymmetric nature of the calculation):

strain_sharing(A, B) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in A)
strain_sharing(B, A) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in B)

Parameters:

Name Type Description Default
comps_lf LazyFrame

LazyFrame containing the gene profiles of the samples.

required
breadth_lf LazyFrame

LazyFrame containing the genome breadth information.

required
sample_to_population LazyFrame

LazyFrame containing the sample to population mapping.

required
min_breadth float

Minimum genome breadth to consider a genome for strain sharing. Defaults to 0.5.

0.5
strain_similarity_threshold float

Threshold for strain similarity. Defaults to 99.9.

99.9
min_total_positions int

Minimum total positions to consider a genome for strain sharing. Defaults to 10000.

10000

Returns: dict[str, list[float]]: Dictionary keyed by population pair (population_1 + "_" + population_2) mapping to the list of per-sample-pair strain sharing rates.

Source code in zipstrain/src/zipstrain/visualize.py
def calculate_strainsharing(
                            comps_lf:pl.LazyFrame,
                            breadth_lf:pl.LazyFrame,
                            sample_to_population:pl.LazyFrame,
                            min_breadth:float=0.5,
                            strain_similarity_threshold:float=99.9,
                            min_total_positions:int=10000
                            )->dict[str, list[float]]:


    """
    Calculate strain sharing between populations based on popANI between genomes in their profiles.
    Strain sharing between two samples is defined as the ratio of genomes passing a strain similarity threshold over the total number of genomes in each sample.
    So, for two samples A and B, the strain sharing is defined as (note the asymmetric nature of the calculation):
    strain_sharing(A, B) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in A)
    strain_sharing(B, A) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in B)

    Args:
        comps_lf (pl.LazyFrame): LazyFrame containing the gene profiles of the samples.
        breadth_lf (pl.LazyFrame): LazyFrame containing the genome breadth information.
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
        min_breadth (float, optional): Minimum genome breadth to consider a genome for strain sharing. Defaults to 0.5.
        strain_similarity_threshold (float, optional): Threshold for strain similarity. Defaults to 99.9.
        min_total_positions (int, optional): Minimum total positions to consider a genome for strain sharing. Defaults to 10000.
    Returns:
        dict[str, list[float]]: Dictionary keyed by population pair (population_1 + "_" + population_2) mapping to the list of per-sample-pair strain sharing rates.
    """
    comps_lf=comps_lf.filter(
        (pl.col("total_positions")>min_total_positions)
    ).collect(engine="streaming").lazy()
    breadth_lf=breadth_lf.fill_null(0.0)
    breadth_lf_long=(
        breadth_lf.unpivot(
            index=["genome"],
            variable_name="sample",
            value_name="breadth"
        )
    )
    breadth_lf=breadth_lf_long.group_by("sample").agg(num_genomes=(pl.col("breadth")>=min_breadth).sum())
    comps_lf=comps_lf.join(breadth_lf,
        left_on='sample_1',
        right_on='sample',
        how='left',
    ).rename(
        {"num_genomes":"num_genomes_1"}
    ).join(
        breadth_lf,
        left_on='sample_2',
        right_on='sample',
        how='left',
    ).rename(
        {"num_genomes":"num_genomes_2"}
    )
    comps_lf = comps_lf.join(
        sample_to_population,
        left_on='sample_1',
        right_on='sample',
        how='left',
    ).rename(
        {"population":"population_1"}
    ).join(
        sample_to_population,
        left_on='sample_2',
        right_on='sample',
        how='left',
    ).rename(
        {"population":"population_2"}
    )
    comps_lf=comps_lf.join(
        breadth_lf_long,
        left_on=["genome","sample_1"],
        right_on=['genome','sample'],
        how='left',
    ).rename(
        {"breadth":"breadth_1"}
    ).join(
        breadth_lf_long,
        left_on=["genome","sample_2"],
        right_on=['genome','sample'],
        how='left',
    ).rename(
        {"breadth":"breadth_2"}
    )
    comps_lf=comps_lf.filter(
        (pl.col("breadth_1") >= min_breadth) &
        (pl.col("breadth_2") >= min_breadth) &
        (pl.col("genome_pop_ani") >= strain_similarity_threshold)
    )

    comps_lf=comps_lf.group_by(
        ["sample_1", "sample_2"]
    ).agg(
        pl.col("genome").count().alias("shared_strain_count"),
        pl.col("num_genomes_1").first().alias("num_genomes_1"),
        pl.col("num_genomes_2").first().alias("num_genomes_2"),
        pl.col("population_1").first().alias("population_1"),
        pl.col("population_2").first().alias("population_2"),
    ).collect(engine="streaming")
    strainsharingrates=defaultdict(list)
    for row in comps_lf.iter_rows(named=True):
        strainsharingrates[row["population_1"]+"_"+ row["population_2"]].append(row["shared_strain_count"] / row["num_genomes_1"])
        strainsharingrates[row["population_2"]+"_"+ row["population_1"]].append(row["shared_strain_count"] / row["num_genomes_2"])
    return strainsharingrates
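
A usage sketch continuing the example above; breadth_lf is assumed to be a wide table with a genome column plus one breadth column per sample, matching the unpivot in the function (the file name is a placeholder):

breadth_lf = pl.scan_parquet("genome_breadth.parquet")
rates = calculate_strainsharing(
    comps_lf,
    breadth_lf,
    sample_to_population,
    min_breadth=0.5,
    strain_similarity_threshold=99.9,
)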

get_cdf(data, num_bins=10000)

Calculate the cumulative distribution function (CDF) of the given data.

Source code in zipstrain/src/zipstrain/visualize.py
def get_cdf(data, num_bins=10000):
    """Calculate the cumulative distribution function (CDF) of the given data."""
    if data[0] == -1:
        return [-1], [-1]
    counts, bin_edges = np.histogram(data, bins=np.linspace(0, 50000, num_bins))
    counts = counts[::-1]
    bin_edges = bin_edges[::-1]
    cummulative_counts = np.cumsum(counts)
    cdf= cummulative_counts / cummulative_counts[-1]
    return bin_edges, cdf

plot_clustermap(comps_lf, genome, sample_to_population, min_comp_len=10000, impute_method=97.0, max_null_samples=200, color_map=None)

Plot a clustermap for the given genome and its associated samples.

Args:
    comps_lf (pl.LazyFrame): LazyFrame containing the comparison data.
    genome (str): The genome to plot.
    sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
    min_comp_len (int, optional): Minimum total_positions for a comparison to be included. Defaults to 10000.
    impute_method (str | float, optional): Value used to fill missing pairwise popANI values (string-based imputation is not yet implemented). Defaults to 97.0.
    max_null_samples (int, optional): Samples with more than this many missing comparisons are excluded. Defaults to 200.
    color_map (dict | None, optional): Mapping from population to color; a qualitative palette is generated when None.

Returns:
    seaborn.matrix.ClusterGrid: Seaborn clustermap figure.

Source code in zipstrain/src/zipstrain/visualize.py
def plot_clustermap(
    comps_lf:pl.LazyFrame,
    genome:str,
    sample_to_population:pl.LazyFrame,
    min_comp_len:int=10000,
    impute_method:str|float=97.0,
    max_null_samples:int=200,
    color_map:dict|None=None,
):
    """
    Plot a clustermap for the given genome and its associated samples.
    Args:
        comps_lf (pl.LazyFrame): LazyFrame containing the comparison data.
        genome (str): The genome to plot.
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
    Returns:
        seaborn.matrix.ClusterGrid: Seaborn clustermap figure.
    """
    # Filter the comparison data for the specific genome
    comps_lf_filtered = comps_lf.filter(
        (pl.col("genome") == genome) & (pl.col("total_positions") > min_comp_len)
    ).select(
        pl.col("sample_1"),
        pl.col("sample_2"),
        pl.col("genome_pop_ani"),
    )
    comps_lf_filtered_oposite = comps_lf_filtered.select(
        pl.col("sample_2").alias("sample_1"),
        pl.col("sample_1").alias("sample_2"),
        pl.col("genome_pop_ani"),
    )
    # Combine the filtered data with its opposite pairs
    comps_lf_filtered = pl.concat([comps_lf_filtered, comps_lf_filtered_oposite])
    # Build a synthetic self-similarity table (popANI = 100) that contains each sample appearing in sample_1 or sample_2 exactly once
    self_similarity =(
    pl.concat([
        comps_lf_filtered.select(pl.col("sample_1").alias("sample_1")),
        comps_lf_filtered.select(pl.col("sample_2").alias("sample_1"))
    ])
    .unique()
    .sort("sample_1").with_columns(
        pl.col("sample_1").alias("sample_2"),
        pl.lit(100.0).alias("genome_pop_ani"),
    )
    )

    # Combine the self similarity with the filtered data
    comps_lf_filtered = pl.concat([self_similarity, comps_lf_filtered]).collect()
    # Pivot the data for the clustermap
    clustermap_data = comps_lf_filtered.pivot(
        index="sample_1",
        columns="sample_2",
        values="genome_pop_ani"
    )
    # To turn this into a similarity matrix, drop samples with too many missing comparisons and keep sample_1 as the row index with the sample_2 values as columns
    # Create the clustermap
    exclude_samples=clustermap_data.null_count().transpose(include_header=True, header_name="column", column_names=["null_count"]).filter(pl.col("null_count")>max_null_samples)["column"].to_list()
    # Only include rows and cols not in exclude_samples
    clustermap_data = clustermap_data.filter(~pl.col("sample_1").is_in(exclude_samples))
    clustermap_data = clustermap_data.select(*[col for col in clustermap_data.columns if col not in exclude_samples])
    if isinstance(impute_method, str):
        pass # To be implemented later
    elif isinstance(impute_method, (int, float)):
        clustermap_data = clustermap_data.fill_null(impute_method)
    sample_to_population = clustermap_data.select(pl.col("sample_1")).join(
        sample_to_population.collect(),
        left_on="sample_1",
        right_on="sample",
        how="left")
    sample_to_population_dict = dict(zip(sample_to_population["sample_1"], sample_to_population["population"]))
    if color_map is None:

        num_categories = sample_to_population["population"].n_unique()
        groups= sample_to_population["population"].unique().sort().to_list()
        qualitative_palette = sns.color_palette("hls", num_categories)
        row_colors = [qualitative_palette[groups.index(sample_to_population_dict[sample])] for sample in clustermap_data["sample_1"]]
        col_colors = [qualitative_palette[groups.index(sample_to_population_dict[sample])] for sample in clustermap_data.columns if sample != "sample_1"]
    else:
        groups= list(color_map.keys())
        qualitative_palette= list(color_map.values())
        row_colors = [color_map[sample_to_population_dict[sample]] for sample in clustermap_data["sample_1"]]
        col_colors = [color_map[sample_to_population_dict[sample]] for sample in clustermap_data.columns if sample != "sample_1"]
    fig = sns.clustermap(
        clustermap_data.to_pandas().set_index("sample_1"),
        figsize=(30, 30),
        xticklabels=True, 
        yticklabels=True,
        row_colors=row_colors,
        col_colors=col_colors
    )
    fig.ax_heatmap.set_xticklabels(fig.ax_heatmap.get_xmajorticklabels(), fontsize=0.1)
    fig.ax_heatmap.set_yticklabels(fig.ax_heatmap.get_ymajorticklabels(), fontsize=0.1)
    legend_handles = [mpatches.Patch(color=color, label=label)
                  for label, color in zip(groups, qualitative_palette)]
    fig.ax_heatmap.legend(handles=legend_handles,
                          title='Population',
                        title_fontsize=16,   # bigger title
                        fontsize=14,         # bigger labels
                        handlelength=2.5,    # wider color boxes
                        handleheight=2,
                    bbox_to_anchor=(-0.15, 1), loc="lower left")
    return fig
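
A usage sketch continuing the example above; the returned object is a seaborn ClusterGrid, so it can be saved with savefig (genome name and output path are placeholders):

grid = plot_clustermap(
    comps_lf,
    genome="genome1",
    sample_to_population=sample_to_population,
    impute_method=97.0,
)
grid.savefig("genome1_clustermap.png", dpi=300)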

plot_ibs(df, genome, population_1, population_2, vert_thresh_hor_distance=0.001, num_bins=10000, title='IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>', xaxis_title='Max Consecutive Length', yaxis_title='CDF')

Plot the Identity By State (IBS) for a given genome and two populations.

Args:
    df (pl.DataFrame): DataFrame containing the IBS information.
    genome (str): The genome to plot the IBS for.
    population_1 (str): The first population to plot the IBS for.
    population_2 (str): The second population to plot the IBS for.
    vert_thresh_hor_distance (float, optional): CDF level at which the horizontal distance between the within- and between-population curves is measured. Defaults to 0.001.
    num_bins (int, optional): Number of histogram bins used for the CDFs. Defaults to 10000.
    title (str, optional): Title of the plot. Defaults to "IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>".
    xaxis_title (str, optional): Title of the x-axis. Defaults to "Max Consecutive Length".
    yaxis_title (str, optional): Title of the y-axis. Defaults to "CDF".

Returns:
    go.Figure: Plotly figure containing the IBS plot.

Source code in zipstrain/src/zipstrain/visualize.py
def plot_ibs(df:pl.DataFrame,
            genome:str,
            population_1:str,
            population_2:str,
            vert_thresh_hor_distance:float=0.001,
            num_bins:int=10000,
            title:str="IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>",
            xaxis_title:str="Max Consecutive Length",
            yaxis_title:str="CDF"
            ):
    """
    Plot the Identity By State (IBS) for a given genome and two populations.
    Args:
        df (pl.DataFrame): DataFrame containing the IBS information.
        genome (str): The genome to plot the IBS for.
        population_1 (str): The first population to plot the IBS for.
        population_2 (str): The second population to plot the IBS for.
        title (str, optional): Title of the plot. Defaults to "IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Max Consecutive Length".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "CDF".
    Returns:
        go.Figure: Plotly figure containing the IBS plot.
    """
    df_filtered = df.filter(pl.col("genome") == genome)
    if df_filtered.is_empty():
        raise ValueError(f"Genome {genome} not found in the dataframe.")
    plot_data = {}
    key_within_1=f"within_population_{population_1}|{population_1}"
    key_within_2=f"within_population_{population_2}|{population_2}"
    key_between=f"between_population_{min(population_1,population_2)}|{max(population_1,population_2)}"
    if df_filtered.get_column(key_within_1).list.len()[0]==0 or df_filtered.get_column(key_within_2).list.len()[0]==0 or df_filtered.get_column(key_between).list.len()[0]==0:
        raise ValueError(f"Not enough data for populations {population_1} and {population_2} in genome {genome}.")
    plot_data["within_population"]=df_filtered.get_column(key_within_1)[0].to_list()+df_filtered.get_column(key_within_2)[0].to_list()
    plot_data["between_population"]=df_filtered.get_column(key_between)[0].to_list()
    fig = go.Figure()
    between_pop_cdf=get_cdf(plot_data["between_population"], num_bins=num_bins)
    fig.add_trace(go.Scatter(
        x=between_pop_cdf[0][1:].copy(),
        y=between_pop_cdf[1][1:].copy(),
        mode='lines',
        name='between_population',
        line=dict(color='blue')
    ))
    within_pop_cdf=get_cdf(plot_data["within_population"], num_bins=num_bins)
    fig.add_trace(go.Scatter(
        x=within_pop_cdf[0][1:].copy(),
        y=within_pop_cdf[1][1:].copy(),
        mode='lines',
        name='within_population',
        line=dict(color='green')
    ))

    bin_edges=within_pop_cdf[0]
    cdf=within_pop_cdf[1]
    within_intersect=bin_edges[np.where(cdf>=vert_thresh_hor_distance)[0][0]]
    bin_edges=between_pop_cdf[0]
    cdf=between_pop_cdf[1]
    between_intersect=bin_edges[np.where(cdf>=vert_thresh_hor_distance)[0][0]]  
    distance=within_intersect-between_intersect

    fig.update_layout(
        title={"text": title.replace("<GENOME>", genome).replace("<POPULATION_1>", population_1).replace("<POPULATION_2>", population_2), "x": 0.5},
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,

    )
    ###Add a horizontal line from (between_intersect, vert_thresh_hor_distance) to (within_intersect, vert_thresh_hor_distance)
    fig.add_trace(go.Scatter(
        x=[between_intersect, within_intersect],
        y=[vert_thresh_hor_distance, vert_thresh_hor_distance],
        mode='lines+markers',
        line=dict(color='black'),
        showlegend=False
    ))
    ###Add a text annotation at the middle of the horizontal line with the distance
    fig.add_trace(go.Scatter(
        x=[(between_intersect+within_intersect)/2],
        y=[vert_thresh_hor_distance],
        mode="text",
        text=int(distance),
        textposition="top center",
        showlegend=False
    ))
    ##make both axes logarithmic
    fig.update_xaxes(type='log')
    fig.update_yaxes(type='log')

    return fig
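
A usage sketch that plots the CDFs from the calculate_ibs result built earlier (genome and population names are placeholders):

fig = plot_ibs(ibs_df, genome="genome1", population_1="popA", population_2="popB")
fig.show()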

plot_ibs_heatmap(df, vert_thresh=0.001, populations=None, num_bins=10000, min_member=50, title='IBS Heatmap', xaxis_title='Population Pair', yaxis_title='Genome')

Plot the Identity By State (IBS) heatmap across genomes and population pairs.

Args:
    df (pl.DataFrame): DataFrame containing the IBS information.
    vert_thresh (float, optional): CDF level at which the within/between intersection points are located. Defaults to 0.001.
    populations (list[str] | None, optional): Populations to include; inferred from the column names when None.
    num_bins (int, optional): Number of histogram bins used for the CDFs. Defaults to 10000.
    min_member (int, optional): Comparison lists shorter than this are treated as missing. Defaults to 50.
    title (str, optional): Title of the plot. Defaults to "IBS Heatmap".
    xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
    yaxis_title (str, optional): Title of the y-axis. Defaults to "Genome".

Returns:
    go.Figure: Plotly figure containing the IBS heatmap.

Source code in zipstrain/src/zipstrain/visualize.py
def plot_ibs_heatmap(
    df:pl.DataFrame,
    vert_thresh:float=0.001,
    populations:list[str]|None=None,
    num_bins:int=10000,
    min_member:int=50,
    title:str="IBS Heatmap",
    xaxis_title:str="Population Pair",
    yaxis_title:str="Genome",

):
    """
    Plot the Identity By State (IBS) heatmap across genomes and population pairs.
    Args:
        df (pl.DataFrame): DataFrame containing the IBS information.
        title (str, optional): Title of the plot. Defaults to "IBS Heatmap".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Genome".
    Returns:
        go.Figure: Plotly figure containing the IBS heatmap.
    """
    df = df.with_columns(
    [
        pl.when(pl.col(c).list.len() < min_member)
        .then(pl.lit([-1]))
        .otherwise(pl.col(c))
        .alias(c)
        for c in df.columns if c != "genome"
    ]
)
    if populations is None:
        populations=set(chain.from_iterable(i.replace("within_population_","").replace("between_population_","").split("|") for i in df.columns if i!="genome"))
        populations=sorted(populations)
    heatmap_data = df.rows_by_key("genome", unique=True,include_key=False,named=True)
    fig_data={}
    for genome, genome_data in heatmap_data.items():
        fig_data[genome]={}
        for pop1,pop2 in combinations(populations,2):
            key_between=f"between_population_{min(pop1,pop2)}|{max(pop1,pop2)}"
            key_within_1=f"within_population_{pop1}|{pop1}"
            key_within_2=f"within_population_{pop2}|{pop2}"
            if genome_data.get(key_between, [-1])==[-1] or genome_data.get(key_within_1, [-1])==[-1] or genome_data.get(key_within_2, [-1])==[-1]:
                fig_data[genome][f"{min(pop1,pop2)}-{max(pop1,pop2)}"]=-1
                continue
            between=get_cdf(genome_data[key_between], num_bins=num_bins)
            within=get_cdf(genome_data[key_within_1]+genome_data[key_within_2], num_bins=num_bins)

            between_intersect=between[0][np.where(between[1]>=vert_thresh)[0][0]]
            within_intersect=within[0][np.where(within[1]>=vert_thresh)[0][0]]
            distance=within_intersect-between_intersect
            fig_data[genome][f"{min(pop1,pop2)}-{max(pop1,pop2)}"]=distance
    ###Filter the dataframe to only have useful information
    heatmap_df = pd.DataFrame(fig_data).T
    heatmap_df=heatmap_df.mask(heatmap_df < 0, 0)
    heatmap_df=heatmap_df[heatmap_df.sum(axis=1)>0]
    heatmap_df=heatmap_df[[col for col in heatmap_df.columns if heatmap_df[col].sum()>0]]
    heatmap_df_sorted = heatmap_df.assign(row_sum=heatmap_df.sum(axis=1)).sort_values("row_sum", ascending=True).drop(columns="row_sum")

    fig = go.Figure(data=go.Heatmap(
        z=heatmap_df_sorted.values,
        x=heatmap_df_sorted.columns,
        y=heatmap_df_sorted.index
    ))
    return fig
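
A usage sketch summarising all genomes and population pairs from the calculate_ibs result at once (the output path is a placeholder):

fig = plot_ibs_heatmap(ibs_df, vert_thresh=0.001, min_member=50)
fig.write_html("ibs_heatmap.html")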

plot_identical_frac_vs_popani(df, genome, title='Fraction of Identical Genes vs popANI for <GENOME>', xaxis_title='Genome-Wide popANI', yaxis_title='Fraction of Identical Genes')

Plot the fraction of identical genes vs popANI for a given genome and two samples in any possible combination of populations.

Args:
    df (pl.DataFrame): DataFrame containing the fraction of identical genes vs popANI information.
    genome (str): The genome used in the plot title.
    title (str, optional): Title of the plot. Defaults to "Fraction of Identical Genes vs popANI for <GENOME>".
    xaxis_title (str, optional): Title of the x-axis. Defaults to "Genome-Wide popANI".
    yaxis_title (str, optional): Title of the y-axis. Defaults to "Fraction of Identical Genes".

Returns:
    go.Figure: Plotly figure containing the fraction of identical genes vs popANI plot.

Source code in zipstrain/src/zipstrain/visualize.py
def plot_identical_frac_vs_popani(df:pl.DataFrame,
                                  genome:str,
                                  title:str="Fraction of Identical Genes vs popANI for <GENOME>",
                                  xaxis_title:str="Genome-Wide popANI",
                                  yaxis_title:str="Fraction of Identical Genes",
                                  ):
    """
    Plot the fraction of identical genes vs popANI for a given genome and two samples in any possible combination of populations.
    Args:
        df (pl.DataFrame): DataFrame containing the fraction of identical genes vs popANI information.
        title (str, optional): Title of the plot. Defaults to "Fraction of Identical Genes vs popANI".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "popANI".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Fraction of Identical Genes".
    Returns:
        go.Figure: Plotly figure containing the fraction of identical genes vs popANI plot.
    """
    fig = go.Figure()
    for group, perc_id_genes, genome_pop_ani in zip(df["relationship"], df["perc_id_genes"], df["genome_pop_ani"]):
        fig.add_trace(go.Scatter(
            x=genome_pop_ani,
            y=perc_id_genes,
            mode='markers',
            name=group
        ))
    fig.update_layout(
        title=title.replace("<GENOME>", genome),
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title
    )
    return fig

plot_strainsharing(strainsharingrates, sample_frac=1, title='Strain Sharing Rates', xaxis_title='Population Pair', yaxis_title='Strain Sharing Rate')

Plot the strain sharing rates between populations.

Args:
    strainsharingrates (dict[str, list[float]]): Dictionary containing the strain sharing rates between populations.
    sample_frac (float, optional): Fraction of the rates to randomly subsample before plotting. Defaults to 1.
    title (str, optional): Title of the plot. Defaults to "Strain Sharing Rates".
    xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
    yaxis_title (str, optional): Title of the y-axis. Defaults to "Strain Sharing Rate".

Returns:
    go.Figure: Plotly figure containing the strain sharing plot.

Source code in zipstrain/src/zipstrain/visualize.py
def plot_strainsharing(
    strainsharingrates:dict[str, list[float]],
    sample_frac:float=1,
    title:str="Strain Sharing Rates",
    xaxis_title:str="Population Pair",
    yaxis_title:str="Strain Sharing Rate",
):
    """
    Plot the strain sharing rates between populations.
    Args:
        strainsharingrates (dict[str, list[float]]): Dictionary containing the strain sharing rates between populations.
        title (str, optional): Title of the plot. Defaults to "Strain Sharing".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Strain Sharing Rate".
    Returns:
        go.Figure: Plotly figure containing the strain sharing plot.
    """
    for key in strainsharingrates.keys():
        strainsharingrates[key] = np.random.choice(strainsharingrates[key], size=int(len(strainsharingrates[key]) * sample_frac), replace=False)
    fig = go.Figure()
    for pair, rates in strainsharingrates.items():
        fig.add_trace(go.Box(
            y=rates,
            name=pair,
            boxpoints='all',
            jitter=0.3,
            pointpos=0
        ))
    fig.update_layout(
        title={"text": title, "x": 0.5},
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title
    )
    return fig
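
A usage sketch with the dictionary returned by calculate_strainsharing above; sample_frac here subsamples half of the rates purely to keep the plot light:

fig = plot_strainsharing(rates, sample_frac=0.5)
fig.show()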

For more advanced users, new workflows can be developed using the task_manager module to create custom pipelines tailored to specific research needs.


Task Manager

zipstrain.task_manager

Lightweight, asyncio-driven orchestration primitives for building and running scientific data-processing pipelines. This module provides a small, composable framework for defining Tasks with explicit inputs/outputs, bundling Tasks into Batches (local or Slurm), and coordinating their execution with a live terminal UI. It is designed to be easy to extend with new Task types and execution environments. Most users will not use this module directly; however, it can be used to define new pipelines that chain together multiple steps with clear inputs and outputs. The unit of execution is a batch, a collection of tasks to be executed together; each batch can have an optional finalization step that runs after all of its tasks are complete. A minimal sketch of a custom task follows the list of key concepts below.

Key concepts
  • Inputs and Outputs: These classes encapsulate task inputs and outputs with validation logic. By default, input and output classes for files, strings, and integers are provided. If needed, new types can be defined by subclassing Input or Output.

  • Engines: Any task object can use a container engine (Docker or Apptainer) or run natively (LocalEngine).

  • Tasks: Each task runs a unit of bash script with defined inputs and expected outputs. If an engine is provided, the command is wrapped accordingly to run inside the container.

  • Batches: A batch is a collection of tasks to be executed together. Batches can be run locally or submitted to Slurm. Each batch monitors the status of its tasks and updates its own status accordingly. A batch can also have expected outputs that are checked after all tasks are complete. Additionally, a batch can have a finalization step that runs after all tasks are complete.

  • Runner: The Runner class orchestrates task generation, batching, and execution. It manages concurrent batch execution, monitors progress, and provides a live terminal UI using the rich library.
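
As an illustration only, the snippet below sketches what a custom task might look like. It follows the constructor pattern and the TEMPLATE_CMD placeholder convention visible in the built-in tasks documented further down (e.g. CollectComps); the task name, the placeholder names, and the argument-free LocalEngine() call are assumptions, not part of the documented API:

from zipstrain.task_manager import Task, StringInput, FileOutput, LocalEngine

class WordCountTask(Task):
    # Hypothetical task: placeholders such as <input-file> and <output-file> are
    # assumed to be substituted from the inputs / expected_outputs mappings,
    # mirroring the convention used by CollectComps below.
    TEMPLATE_CMD = """
    wc -l <input-file> > <output-file>
    """

task = WordCountTask(
    "count_lines",                                  # id
    {"input-file": StringInput("reads.tsv")},       # inputs
    {"output-file": FileOutput("line_count.txt")},  # expected outputs
    engine=LocalEngine(),                           # constructor arguments assumed
)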

Batch

Bases: ABC

Batch is a collection of tasks to be executed as a group. This is a base class and should not be instantiated directly. A batch is the unit of execution, meaning that the entire batch is either run locally or submitted to a job scheduler like Slurm.

Parameters:

Name Type Description Default
tasks list[Task]

List of Task objects to be included in the batch.

required
id str

Unique identifier for the batch.

required
run_dir Path

Directory where the batch will be executed.

required
expected_outputs list[Output]

List of expected outputs for the batch.

required
Source code in zipstrain/src/zipstrain/task_manager.py
class Batch(ABC):
    """Batch is a collection of tasks to be executed as a group. This is a base class and should not be instantiated directly.
    A batch is the unit of execution, meaning that the entire batch is either run locally or submitted to a job scheduler like Slurm.

    Args:
        tasks (list[Task]): List of Task objects to be included in the batch.
        id (str): Unique identifier for the batch.
        run_dir (pathlib.Path): Directory where the batch will be executed.
        expected_outputs (list[Output]): List of expected outputs for the batch.
    """
    TEMPLATE_CMD = ""

    def __init__(self, tasks: list[Task],
                 id: str,
                 run_dir: pathlib.Path,
                 expected_outputs: list[Output],
                 file_semaphore: asyncio.Semaphore| None = None
                 ) -> None:
        self.id = id
        self.tasks = tasks
        self.run_dir = pathlib.Path(run_dir)
        self.batch_dir = self.run_dir / self.id
        self.retry_count = 0
        self.expected_outputs = expected_outputs
        self.file_semaphore = file_semaphore
        for output in self.expected_outputs:
            if isinstance(output, BatchFileOutput):
                output.register_batch(self)
        self._status = self._get_initial_status()
        for task in self.tasks:
            task._batch_obj = self
            task.file_semaphore = self.file_semaphore
            for output in task.expected_outputs.values():
                output.register_task(task)
            task._status= task._get_initial_status()
            task.map_io()

        self._runner_obj:Runner = None



    def _get_initial_status(self) -> str:
        """Returns the initial status of the batch based on the presence of the batch directory."""
        if not self.batch_dir.exists():
            return Status.NOT_STARTED.value
        with open(self.batch_dir / ".status", mode="r") as f:
            status_as_written = f.read().strip()
        if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
            all_ready = True
            try:
                for output in self.expected_outputs:
                    if not output.ready():
                        all_ready = False
                        break
            except Exception:
                all_ready = False

            if all_ready:
                return Status.SUCCESS.value
            else:
                return Status.FAILED.value

    def cleanup(self) -> None:
        """The base class defines if any cleanup is needed after batch success. By default, it does nothing."""
        return None

    @abstractmethod
    async def cancel(self) -> None:
        """Cancels the batch. This method should be implemented by subclasses."""
        ...

    def outputs_ready(self) -> bool:
        """Check if all BATCH-LEVEL expected outputs are ready."""
        try:
            for output in self.expected_outputs:
                if not output.ready():
                    return False
            return True
        except Exception:
            return False

    async def _collect_task_status(self) -> list[str]:
        """Collects the status of all tasks asynchronously."""
        return await asyncio.gather(*[task.get_status() for task in self.tasks])

    @abstractmethod
    async def run(self) -> None:
        """Runs the batch. This method should be implemented by subclasses."""
        ...

    @abstractmethod
    def _parse_job_id(self, sbatch_output: str) -> str:
        """Parses the job ID from the sbatch output. This method should be implemented by subclasses."""
        ...

    @property
    def status(self) -> str:
        """Returns the current status of the batch."""
        return self._status

    @property
    def stats(self) -> dict[str, str]:
        """Returns a dictionary of task IDs and their statuses."""
        return {task.id: task.status for task in self.tasks}

    async def update_status(self) -> str:
        """Updates the status of the batch by collecting the status of all tasks."""
        await self._collect_task_status()
    def _set_file_semaphore(self, file_semaphore: asyncio.Semaphore) -> None:
        self.file_semaphore = file_semaphore
        for task in self.tasks:
            task.file_semaphore = file_semaphore
stats property

Returns a dictionary of task IDs and their statuses.

status property

Returns the current status of the batch.

cancel() abstractmethod async

Cancels the batch. This method should be implemented by subclasses.

Source code in zipstrain/src/zipstrain/task_manager.py
@abstractmethod
async def cancel(self) -> None:
    """Cancels the batch. This method should be implemented by subclasses."""
    ...
cleanup()

The base class defines if any cleanup is needed after batch success. By default, it does nothing.

Source code in zipstrain/src/zipstrain/task_manager.py
def cleanup(self) -> None:
    """The base class defines if any cleanup is needed after batch success. By default, it does nothing."""
    return None
outputs_ready()

Check if all BATCH-LEVEL expected outputs are ready.

Source code in zipstrain/src/zipstrain/task_manager.py
def outputs_ready(self) -> bool:
    """Check if all BATCH-LEVEL expected outputs are ready."""
    try:
        for output in self.expected_outputs:
            if not output.ready():
                return False
        return True
    except Exception:
        return False
run() abstractmethod async

Runs the batch. This method should be implemented by subclasses.

Source code in zipstrain/src/zipstrain/task_manager.py
@abstractmethod
async def run(self) -> None:
    """Runs the batch. This method should be implemented by subclasses."""
    ...
update_status() async

Updates the status of the batch by collecting the status of all tasks.

Source code in zipstrain/src/zipstrain/task_manager.py
async def update_status(self) -> str:
    """Updates the status of the batch by collecting the status of all tasks."""
    await self._collect_task_status()

BatchFileOutput

Bases: Output

This is used when the output is a file path relative to the batch directory. It is registered with the batch rather than with an individual task.

Source code in zipstrain/src/zipstrain/task_manager.py
class BatchFileOutput(Output):
    """This is used when the output is a file path relative to the batch directory.
    Also it will be registered to the batch instead of the task.
    """
    def __init__(self, expected_file:str) -> None:
        self._expected_file_name = expected_file

    def ready(self) -> bool:
        """Check if the expected output file exists."""
        return True if self.expected_file.absolute().exists() else False

    def register_batch(self, batch: Batch) -> None:
        """Registers the batch that produces this output and sets the expected file path.

        Args:
        batch (Batch): The batch that produces this output and sets the expected file path.
        """
        self.expected_file = batch.batch_dir / self._expected_file_name
ready()

Check if the expected output file exists.

Source code in zipstrain/src/zipstrain/task_manager.py
def ready(self) -> bool:
    """Check if the expected output file exists."""
    return True if self.expected_file.absolute().exists() else False
register_batch(batch)

Registers the batch that produces this output and sets the expected file path.

Args: batch (Batch): The batch that produces this output and sets the expected file path.

Source code in zipstrain/src/zipstrain/task_manager.py
def register_batch(self, batch: Batch) -> None:
    """Registers the batch that produces this output and sets the expected file path.

    Args:
    batch (Batch): The batch that produces this output and sets the expected file path.
    """
    self.expected_file = batch.batch_dir / self._expected_file_name

CollectComps

Bases: Task

A Task that collects and merges comparison parquet files from multiple FastCompareTask tasks into a single parquet file.

Parameters:

Name Type Description Default
id str

Unique identifier for the task.

required
inputs dict[str, Input]

Dictionary of input parameters for the task.

required
expected_outputs dict[str, Output]

Dictionary of expected outputs for the task.

required
engine Engine

Container engine to wrap the command.

required
Source code in zipstrain/src/zipstrain/task_manager.py
class CollectComps(Task):
    """A Task that collects and merges comparison parquet files from multiple FastCompareTask tasks into a single parquet file.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command."""
    TEMPLATE_CMD="""
    mkdir -p comps
    cp */*_comparison.parquet comps/
    zipstrain utilities merge_parquet --input-dir comps --output-file <output-file>
    rm -rf comps
    """

    @property
    def pre_run(self) -> str:
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status"

CollectGeneComps

Bases: Task

A Task that collects and merges gene comparison parquet files from multiple FastGeneCompareTask tasks into a single parquet file.

Parameters:

Name Type Description Default
id str

Unique identifier for the task.

required
inputs dict[str, Input]

Dictionary of input parameters for the task.

required
expected_outputs dict[str, Output]

Dictionary of expected outputs for the task.

required
engine Engine

Container engine to wrap the command.

required
Source code in zipstrain/src/zipstrain/task_manager.py
class CollectGeneComps(Task):
    """A Task that collects and merges gene comparison parquet files from multiple FastGeneCompareTask tasks into a single parquet file.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command."""
    TEMPLATE_CMD="""
    mkdir -p gene_comps
    cp */*_gene_comparison.parquet gene_comps/
    zipstrain utilities merge_parquet --input-dir gene_comps --output-file <output-file>
    rm -rf gene_comps
    """

    @property
    def pre_run(self) -> str:
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status"

CompareRunner

Bases: Runner

Creates and schedules batches of FastCompareTask tasks using either local or Slurm batches.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
task_generator TaskGenerator

An instance of TaskGenerator to produce tasks.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 1.0.

1.0
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
batch_type str

Type of batch to use ("local" or "slurm"). Default is "local".

'local'
slurm_config SlurmConfig | None

Configuration for Slurm batches if batch_type is "slurm". Default is None.

None
Source code in zipstrain/src/zipstrain/task_manager.py
class CompareRunner(Runner):
    """
    Creates and schedules batches of FastCompareTask tasks using either local or Slurm batches.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.    
    """

    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
    ) -> None:
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = FastCompareSlurmBatch
            final_batch_factory = PrepareCompareGenomeRunOutputsSlurmBatch
        else:
            batch_factory = FastCompareLocalBatch
            final_batch_factory = PrepareCompareGenomeRunOutputsLocalBatch
        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=final_batch_factory,
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
        )




    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue. Each batch includes a CollectComps task to merge the outputs of the tasks in the batch.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                if buffer:
                    batch_id = f"batch_{self._batch_counter}"
                    self._batch_counter += 1
                    batch_tasks = buffer + [
                        CollectComps(
                            "concat_parquet",
                            {},
                            {"output-file": FileOutput(f"Merged_batch_{batch_id}.parquet")},
                            engine=self.container_engine,
                        )
                    ]
                    expected_outputs = [BatchFileOutput(f"concat_parquet/Merged_batch_{batch_id}.parquet")]
                    if self.batch_type == "slurm":
                        batch = self.batch_factory(
                            tasks=batch_tasks,
                            id=batch_id,
                            run_dir=self.run_dir,
                            expected_outputs=expected_outputs,
                            slurm_config=self.slurm_config,
                        )
                    else:
                        batch = self.batch_factory(
                            tasks=batch_tasks,
                            id=batch_id,
                            run_dir=self.run_dir,
                            expected_outputs=expected_outputs,
                        )
                    await self.batches_queue.put(batch)
                await self.batches_queue.put(None)
                self._batcher_done = True
                break
            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                batch_id = f"batch_{self._batch_counter}"
                self._batch_counter += 1
                batch_tasks = buffer + [
                    CollectComps(
                        "concat_parquet",
                        {},
                        {"output-file": FileOutput(f"Merged_batch_{batch_id}.parquet")},
                        engine=self.container_engine,
                    )
                ]
                expected_outputs = [BatchFileOutput(f"concat_parquet/Merged_batch_{batch_id}.parquet")]
                if self.batch_type == "slurm":
                    batch = self.batch_factory(
                        tasks=batch_tasks,
                        id=batch_id,
                        run_dir=self.run_dir,
                        expected_outputs=expected_outputs,
                        slurm_config=self.slurm_config,
                    )
                else:
                    batch = self.batch_factory(
                        tasks=batch_tasks,
                        id=batch_id,
                        run_dir=self.run_dir,
                        expected_outputs=expected_outputs,
                    )
                await self.batches_queue.put(batch)
                buffer = []



    def _create_final_batch(self) -> Batch:
        """Creates the final batch that prepares the overall outputs after all comparison batches are done."""
        final_task = PrepareCompareGenomeRunOutputs(
            id="prepare_outputs",
            inputs={"output-dir": StringInput("Outputs")},
            expected_outputs={},
            engine=self.container_engine,
        )
        expected_outputs = [BatchFileOutput("all_comparisons.parquet")]
        if self.batch_type == "slurm":
            final_batch=self.final_batch_factory(
                tasks=[final_task],
                id="Outputs",
                run_dir=self.run_dir,
                expected_outputs=expected_outputs,
                slurm_config=self.slurm_config,
            )
            final_batch._runner_obj = self
            return final_batch
        else:
            final_batch = self.final_batch_factory(
                tasks=[final_task],
                id="Outputs",
                run_dir=self.run_dir,
                expected_outputs=expected_outputs,
            )
            final_batch._runner_obj = self
            return final_batch

CompareTaskGenerator

Bases: TaskGenerator

This TaskGenerator generates FastCompareTask objects from a polars DataFrame. Each task compares two profiles using compare_genomes functionality in zipstrain.compare module.

Parameters:

- data (LazyFrame): Polars LazyFrame containing the data for generating tasks. Required.
- yield_size (int): Number of tasks to yield at a time. Required.
- container_engine (Engine): Container engine used to wrap task commands. Required.
- comp_config (GenomeComparisonConfig): Configuration for genome comparison. Required.
- memory_mode (str): Memory mode for the comparison task. Default: "heavy".
- polars_engine (str): Polars engine to use. Default: "streaming".
- chrom_batch_size (int): Chromosome batch size for the comparison task in light memory mode. Default: 10000.
Source code in zipstrain/src/zipstrain/task_manager.py (lines 536-599)
class CompareTaskGenerator(TaskGenerator):
    """This TaskGenerator generates FastCompareTask objects from a polars DataFrame. Each task compares two profiles using compare_genomes functionality in
    zipstrain.compare module.

    Args:
        data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
        yield_size (int): Number of tasks to yield at a time.
        comp_config (database.GenomeComparisonConfig): Configuration for genome comparison.
        memory_mode (str): Memory mode for the comparison task. Default is "heavy".
        polars_engine (str): Polars engine to use. Default is "streaming".
        chrom_batch_size (int): Chromosome batch size for the comparison task in light memory mode. Default is 10000.
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        comp_config: database.GenomeComparisonConfig,
        memory_mode: str = "heavy",
        polars_engine: str = "streaming",
        chrom_batch_size: int = 10000,
    ) -> None:
        super().__init__(data, yield_size)
        self.comp_config = comp_config
        self.engine = container_engine
        self.memory_mode = memory_mode
        self.polars_engine = polars_engine
        self.chrom_batch_size = chrom_batch_size
        if type(self.data) is not pl.LazyFrame:
            raise ValueError("data must be a polars LazyFrame.")

    def get_total_tasks(self) -> int:
        """Returns total number of pairwise comparisons to be made."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yeilds lists of FastCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
        while polars is collecting data to avoid blocking.
        """
        for offset in range(0, self._total_tasks, self.yield_size):
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                inputs = {
                "mpile_1_file": FileInput(row["profile_location_1"]),
                "mpile_2_file": FileInput(row["profile_location_2"]),
                "scaffold_1_file": FileInput(row["scaffold_location_1"]),
                "scaffold_2_file": FileInput(row["scaffold_location_2"]),
                "null_model_file": FileInput(self.comp_config.null_model_loc),
                "stb_file": FileInput(self.comp_config.stb_file_loc),
                "min_cov": IntInput(self.comp_config.min_cov),
                "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
                "memory-mode": StringInput(self.memory_mode),
                "chrom-batch-size": IntInput(self.chrom_batch_size),
                "genome-name": StringInput(self.comp_config.scope),
                "engine": StringInput(self.polars_engine),
                }
                expected_outputs ={
                "output-file":  FileOutput(row["sample_name_1"]+"_"+row["sample_name_2"]+"_comparison.parquet" ),

                }
                task = FastCompareTask(id=row["sample_name_1"]+"_"+row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                tasks.append(task)
            yield tasks
generate_tasks() async

Yields lists of FastCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop while polars is collecting data to avoid blocking.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 571-599)
async def generate_tasks(self) -> list[Task]:
    """Yeilds lists of FastCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
    while polars is collecting data to avoid blocking.
    """
    for offset in range(0, self._total_tasks, self.yield_size):
        batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
        tasks = []
        for row in batch_df.iter_rows(named=True):
            inputs = {
            "mpile_1_file": FileInput(row["profile_location_1"]),
            "mpile_2_file": FileInput(row["profile_location_2"]),
            "scaffold_1_file": FileInput(row["scaffold_location_1"]),
            "scaffold_2_file": FileInput(row["scaffold_location_2"]),
            "null_model_file": FileInput(self.comp_config.null_model_loc),
            "stb_file": FileInput(self.comp_config.stb_file_loc),
            "min_cov": IntInput(self.comp_config.min_cov),
            "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
            "memory-mode": StringInput(self.memory_mode),
            "chrom-batch-size": IntInput(self.chrom_batch_size),
            "genome-name": StringInput(self.comp_config.scope),
            "engine": StringInput(self.polars_engine),
            }
            expected_outputs ={
            "output-file":  FileOutput(row["sample_name_1"]+"_"+row["sample_name_2"]+"_comparison.parquet" ),

            }
            task = FastCompareTask(id=row["sample_name_1"]+"_"+row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
            tasks.append(task)
        yield tasks
get_total_tasks()

Returns total number of pairwise comparisons to be made.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 567-569)
def get_total_tasks(self) -> int:
    """Returns total number of pairwise comparisons to be made."""
    return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]
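
A minimal usage sketch of building a CompareTaskGenerator from a pair table. The file paths, the scope value, and the container engine below are placeholders, and GenomeComparisonConfig is assumed to accept the fields that generate_tasks() reads above (scope, null_model_loc, stb_file_loc, min_cov, min_gene_compare_len); the column names and keyword arguments mirror the source shown here.

import polars as pl

from zipstrain import database
from zipstrain.task_manager import CompareTaskGenerator

# One row per pairwise comparison; column names match what generate_tasks() reads.
pairs = pl.LazyFrame({
    "sample_name_1": ["S1"],
    "sample_name_2": ["S2"],
    "profile_location_1": ["profiles/S1.parquet"],
    "profile_location_2": ["profiles/S2.parquet"],
    "scaffold_location_1": ["profiles/S1.parquet.scaffolds"],
    "scaffold_location_2": ["profiles/S2.parquet.scaffolds"],
})

comp_config = database.GenomeComparisonConfig(
    scope="all",                          # passed through to --genome as the genome scope
    null_model_loc="null_model.parquet",
    stb_file_loc="genomes.stb",
    min_cov=5,
    min_gene_compare_len=100,
)

container_engine = ...  # an already-constructed zipstrain Engine (see the Engine docs)

task_gen = CompareTaskGenerator(
    data=pairs,
    yield_size=100,
    container_engine=container_engine,
    comp_config=comp_config,
    memory_mode="heavy",
    polars_engine="streaming",
    chrom_batch_size=10000,
)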

FastCompareLocalBatch

Bases: LocalBatch

A LocalBatch that runs FastCompareTask tasks locally.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1548-1554)
class FastCompareLocalBatch(LocalBatch):
    """A LocalBatch that runs FastCompareTask tasks locally."""
    def cleanup(self) -> None:
        tasks_to_remove = [task for task in self.tasks if isinstance(task, FastCompareTask)]
        for task in tasks_to_remove:
            self.tasks.remove(task)
            shutil.rmtree(task.task_dir)

FastCompareSlurmBatch

Bases: SlurmBatch

A SlurmBatch that runs FastCompareTask tasks on a Slurm cluster. May be removed in a future release.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1556-1562)
class FastCompareSlurmBatch(SlurmBatch):
    """A SlurmBatch that runs FastCompareTask tasks on a Slurm cluster. Maybe removed in future"""
    def cleanup(self) -> None:
        tasks_to_remove = [task for task in self.tasks if isinstance(task, FastCompareTask)]
        for task in tasks_to_remove:
            self.tasks.remove(task)
            shutil.rmtree(task.task_dir)

FastCompareTask

Bases: Task

A Task that performs a fast genome comparison using the zipstrain compare single_compare_genome command.

Parameters:

- id (str): Unique identifier for the task. Required.
- inputs (dict[str, Input]): Dictionary of input parameters for the task. Required.
- expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task. Required.
- engine (Engine): Container engine to wrap the command. Required.
Source code in zipstrain/src/zipstrain/task_manager.py (lines 1484-1507)
class FastCompareTask(Task):
    """A Task that performs a fast genome comparison using the fast_profile compare single_compare_genome command.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
        """
    TEMPLATE_CMD="""
    zipstrain compare single_compare_genome --mpileup-contig-1 <mpile_1_file> \
    --mpileup-contig-2 <mpile_2_file> \
    --scaffolds-1 <scaffold_1_file> \
    --scaffolds-2 <scaffold_2_file> \
    --null-model <null_model_file> \
    --stb-file <stb_file> \
    --min-cov <min_cov> \
    --min-gene-compare-len <min-gene-compare-len> \
    --memory-mode <memory-mode> \
    --chrom-batch-size <chrom-batch-size> \
    --output-file <output-file> \
    --genome <genome-name> \
    --engine <engine> 
    """

FastGeneCompareLocalBatch

Bases: LocalBatch

A LocalBatch that runs FastGeneCompareTask tasks locally.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1906-1912)
class FastGeneCompareLocalBatch(LocalBatch):
    """A LocalBatch that runs FastGeneCompareTask tasks locally."""
    def cleanup(self) -> None:
        tasks_to_remove = [task for task in self.tasks if isinstance(task, FastGeneCompareTask)]
        for task in tasks_to_remove:
            self.tasks.remove(task)
            shutil.rmtree(task.task_dir)

FastGeneCompareSlurmBatch

Bases: SlurmBatch

A SlurmBatch that runs FastGeneCompareTask tasks on a Slurm cluster.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1914-1920)
class FastGeneCompareSlurmBatch(SlurmBatch):
    """A SlurmBatch that runs FastGeneCompareTask tasks on a Slurm cluster."""
    def cleanup(self) -> None:
        tasks_to_remove = [task for task in self.tasks if isinstance(task, FastGeneCompareTask)]
        for task in tasks_to_remove:
            self.tasks.remove(task)
            shutil.rmtree(task.task_dir)

FastGeneCompareTask

Bases: Task

A Task that performs a fast gene comparison using the zipstrain compare single_compare_gene command.

Parameters:

- id (str): Unique identifier for the task. Required.
- inputs (dict[str, Input]): Dictionary of input parameters for the task. Required.
- expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task. Required.
- engine (Engine): Container engine to wrap the command. Required.
Source code in zipstrain/src/zipstrain/task_manager.py (lines 1669-1688)
class FastGeneCompareTask(Task):
    """A Task that performs a fast gene comparison using the compare single_compare_gene command.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command.
    """
    TEMPLATE_CMD="""
    zipstrain compare single_compare_gene --mpileup-contig-1 <mpile_1_file> \
    --mpileup-contig-2 <mpile_2_file> \
    --null-model <null_model_file> \
    --stb-file <stb_file> \
    --min-cov <min_cov> \
    --min-gene-compare-len <min-gene-compare-len> \
    --output-file <output-file> \
    --engine <engine> \
    --ani-method <ani-method>
    """

FileInput

Bases: Input

This is used when the input is a file path. By default, the validate method checks for file existence.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 149-157)
class FileInput(Input):
    """This is used when the input is a file path. By default, the validate method checks for file existence."""
    def validate(self, check_exists: bool = True) -> None:
        if check_exists and not pathlib.Path(self.value).exists():
            raise FileNotFoundError(f"Input file {self.value} does not exist.")

    def get_value(self) -> str:
        """Returns the absolute path of the input file as a string."""
        return str(pathlib.Path(self.value).absolute())
get_value()

Returns the absolute path of the input file as a string.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 155-157)
def get_value(self) -> str:
    """Returns the absolute path of the input file as a string."""
    return str(pathlib.Path(self.value).absolute())
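
A short sketch of the behaviour described above (the path is a placeholder):

from zipstrain.task_manager import FileInput

stb = FileInput("genomes.stb")   # raises FileNotFoundError if the path does not exist
print(stb.get_value())           # absolute path of the file as a string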

FileOutput

Bases: Output

This is used when the output is a file path.

Parameters:

- expected_file (str): The expected output file name, relative to the task directory. Required.
Source code in zipstrain/src/zipstrain/task_manager.py (lines 213-233)
class FileOutput(Output):
    """This is used when the output is a file path.

    Args:
        expected_file (str): The expected output file name relative to the task directory.
    """
    def __init__(self, expected_file:str) -> None:
        self._expected_file_name = expected_file ### When the task is finished, the expected file should be in task.task_dir / expected_file otherwise ready() will return False

    def ready(self) -> bool:
        """Check if the expected output file exists."""
        return True if self.expected_file.absolute().exists() else False

    def register_task(self, task: Task) -> None:
        """Registers the task that produces this output and sets the expected file path.

        Args:
        task (Task): The task that produces this output.
        """
        super().register_task(task)
        self.expected_file = self.task.task_dir / self._expected_file_name
ready()

Check if the expected output file exists.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 222-224)
def ready(self) -> bool:
    """Check if the expected output file exists."""
    return True if self.expected_file.absolute().exists() else False
register_task(task)

Registers the task that produces this output and sets the expected file path.

Args: task (Task): The task that produces this output.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 226-233)
def register_task(self, task: Task) -> None:
    """Registers the task that produces this output and sets the expected file path.

    Args:
    task (Task): The task that produces this output.
    """
    super().register_task(task)
    self.expected_file = self.task.task_dir / self._expected_file_name
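
Registration is assumed to happen when the output is handed to a Task via expected_outputs; the sketch below only illustrates how the expected path is resolved (task is a placeholder for an existing Task object):

out = FileOutput("S1_S2_comparison.parquet")
out.register_task(task)   # sets out.expected_file to task.task_dir / "S1_S2_comparison.parquet"
out.ready()               # True once that file exists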

GeneCompareRunner

Bases: Runner

Creates and schedules batches of FastGeneCompareTask tasks using either local or Slurm batches.

Parameters:

- run_dir (str | Path): Directory where the runner will operate. Required.
- task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks. Required.
- container_engine (Engine): An instance of Engine to wrap task commands. Required.
- max_concurrent_batches (int): Maximum number of batches to run concurrently. Default: 1.
- poll_interval (float): Time interval in seconds to poll for batch status updates. Default: 1.0.
- tasks_per_batch (int): Number of tasks to include in each batch. Default: 10.
- batch_type (str): Type of batch to use ("local" or "slurm"). Default: "local".
- slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type is "slurm". Default: None.
Source code in zipstrain/src/zipstrain/task_manager.py (lines 1747-1871)
class GeneCompareRunner(Runner):
    """
    Creates and schedules batches of FastGeneCompareTask tasks using either local or Slurm batches.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.    
    """

    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
    ) -> None:
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = FastGeneCompareSlurmBatch
            final_batch_factory = PrepareGeneCompareRunOutputsSlurmBatch
        else:
            batch_factory = FastGeneCompareLocalBatch
            final_batch_factory = PrepareGeneCompareRunOutputsLocalBatch
        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=final_batch_factory,
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
        )

    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue. Each batch includes a CollectGeneComps task to merge the outputs of the tasks in the batch.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                if buffer:
                    collect_task = CollectGeneComps(
                        id="collect_gene_comps",
                        inputs={},
                        expected_outputs={"output-file": FileOutput(f"Merged_gene_batch_{self._batch_counter}.parquet")},
                        engine=self.container_engine,
                    )
                    buffer.append(collect_task)
                    batch = self.batch_factory(
                        tasks=buffer,
                        id=f"gene_batch_{self._batch_counter}",
                        run_dir=self.run_dir,
                        expected_outputs=[],
                        slurm_config=self.slurm_config if self.batch_type == "slurm" else None,
                    )
                    await self.batches_queue.put(batch)
                    self._batch_counter += 1
                self._batcher_done = True
                break

            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                collect_task = CollectGeneComps(
                    id="collect_gene_comps",
                    inputs={},
                    expected_outputs={"output-file": FileOutput(f"Merged_gene_batch_{self._batch_counter}.parquet")},
                    engine=self.container_engine,
                )
                buffer.append(collect_task)
                batch = self.batch_factory(
                    tasks=buffer,
                    id=f"gene_batch_{self._batch_counter}",
                    run_dir=self.run_dir,
                    expected_outputs=[],
                    slurm_config=self.slurm_config if self.batch_type == "slurm" else None,
                )
                await self.batches_queue.put(batch)
                self._batch_counter += 1
                buffer = []

    def _create_final_batch(self) -> Batch:
        """Creates the final batch that prepares the overall outputs after all gene comparison batches are done."""
        final_task = PrepareGeneCompareRunOutputs(
            id="prepare_gene_outputs",
            inputs={"output-dir": StringInput("Outputs")},
            expected_outputs={},
            engine=self.container_engine,
        )
        expected_outputs = [BatchFileOutput("all_gene_comparisons.parquet")]
        if self.batch_type == "slurm":
            final_batch=self.final_batch_factory(
                tasks=[final_task],
                id="Outputs",
                run_dir=self.run_dir,
                expected_outputs=expected_outputs,
                slurm_config=self.slurm_config,
            )
            final_batch._runner_obj = self
            return final_batch
        else:
            final_batch = self.final_batch_factory(
                tasks=[final_task],
                id="Outputs",
                run_dir=self.run_dir,
                expected_outputs=expected_outputs,
            )
            final_batch._runner_obj = self
            return final_batch

GeneCompareTaskGenerator

Bases: TaskGenerator

This TaskGenerator generates FastGeneCompareTask objects from a polars DataFrame. Each task compares two profiles using compare_genes functionality in zipstrain.compare module.

Parameters:

- data (LazyFrame): Polars LazyFrame containing the data for generating tasks. Required.
- yield_size (int): Number of tasks to yield at a time. Required.
- container_engine (Engine): Container engine used to wrap task commands. Required.
- comp_config (GeneComparisonConfig): Configuration for gene-level comparison. Required.
- polars_engine (str): Polars engine to use. Default: "streaming".
- ani_method (str): ANI calculation method to use. Default: "popani".
Source code in zipstrain/src/zipstrain/task_manager.py (lines 1690-1745)
class GeneCompareTaskGenerator(TaskGenerator):
    """This TaskGenerator generates FastGeneCompareTask objects from a polars DataFrame. Each task compares two profiles using compare_genes functionality in
    zipstrain.compare module.

    Args:
        data (pl.LazyFrame): Polars LazyFrame containing the data for generating tasks.
        yield_size (int): Number of tasks to yield at a time.
        comp_config (database.GenomeComparisonConfig): Configuration for genome comparison.
        polars_engine (str): Polars engine to use. Default is "streaming".
        ani_method (str): ANI calculation method to use. Default is "popani".
    """
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        comp_config: database.GeneComparisonConfig,
        polars_engine: str = "streaming",
        ani_method: str = "popani",
    ) -> None:
        super().__init__(data, yield_size)
        self.comp_config = comp_config
        self.engine = container_engine
        self.polars_engine = polars_engine
        self.ani_method = ani_method
        if type(self.data) is not pl.LazyFrame:
            raise ValueError("data must be a polars LazyFrame.")

    def get_total_tasks(self) -> int:
        """Returns total number of pairwise comparisons to be made."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yields lists of FastGeneCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
        while polars is collecting data to avoid blocking.
        """
        for offset in range(0, self._total_tasks, self.yield_size):
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                inputs = {
                "mpile_1_file": FileInput(row["profile_location_1"]),
                "mpile_2_file": FileInput(row["profile_location_2"]),
                "null_model_file": FileInput(self.comp_config.null_model_loc),
                "stb_file": FileInput(self.comp_config.stb_file_loc),
                "min_cov": IntInput(self.comp_config.min_cov),
                "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
                "engine": StringInput(self.polars_engine),
                "ani-method": StringInput(self.ani_method),
                }
                expected_outputs ={
                "output-file":  FileOutput(row["sample_name_1"]+"_"+row["sample_name_2"]+"_gene_comparison.parquet" ),
                }
                task = FastGeneCompareTask(id=row["sample_name_1"]+"_"+row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                tasks.append(task)
            yield tasks
generate_tasks() async

Yields lists of FastGeneCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop while polars is collecting data to avoid blocking.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1722-1745)
async def generate_tasks(self) -> list[Task]:
    """Yields lists of FastGeneCompareTask objects based on the data in batches of yield_size. This method yields the control back to the event loop
    while polars is collecting data to avoid blocking.
    """
    for offset in range(0, self._total_tasks, self.yield_size):
        batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
        tasks = []
        for row in batch_df.iter_rows(named=True):
            inputs = {
            "mpile_1_file": FileInput(row["profile_location_1"]),
            "mpile_2_file": FileInput(row["profile_location_2"]),
            "null_model_file": FileInput(self.comp_config.null_model_loc),
            "stb_file": FileInput(self.comp_config.stb_file_loc),
            "min_cov": IntInput(self.comp_config.min_cov),
            "min-gene-compare-len": IntInput(self.comp_config.min_gene_compare_len),
            "engine": StringInput(self.polars_engine),
            "ani-method": StringInput(self.ani_method),
            }
            expected_outputs ={
            "output-file":  FileOutput(row["sample_name_1"]+"_"+row["sample_name_2"]+"_gene_comparison.parquet" ),
            }
            task = FastGeneCompareTask(id=row["sample_name_1"]+"_"+row["sample_name_2"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
            tasks.append(task)
        yield tasks
get_total_tasks()

Returns total number of pairwise comparisons to be made.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1718-1720)
def get_total_tasks(self) -> int:
    """Returns total number of pairwise comparisons to be made."""
    return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]
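
Putting the pieces together, here is a rough end-to-end sketch for gene-level comparisons. Paths, the scope value, and the container engine are placeholders, and GeneComparisonConfig is assumed to accept the fields the generator reads; the runner arguments mirror the GeneCompareRunner documentation above.

import polars as pl

from zipstrain import database
from zipstrain.task_manager import GeneCompareRunner, GeneCompareTaskGenerator

pairs = pl.LazyFrame({
    "sample_name_1": ["S1"],
    "sample_name_2": ["S2"],
    "profile_location_1": ["profiles/S1.parquet"],
    "profile_location_2": ["profiles/S2.parquet"],
})

gene_config = database.GeneComparisonConfig(
    scope="genome1:gene1",            # GENOME:GENE scope
    null_model_loc="null_model.parquet",
    stb_file_loc="genomes.stb",
    min_cov=5,
    min_gene_compare_len=100,
)

container_engine = ...  # an already-constructed zipstrain Engine

task_gen = GeneCompareTaskGenerator(
    data=pairs,
    yield_size=100,
    container_engine=container_engine,
    comp_config=gene_config,
    polars_engine="streaming",
    ani_method="popani",
)

runner = GeneCompareRunner(
    run_dir="gene_compare_run",
    task_generator=task_gen,
    container_engine=container_engine,
    tasks_per_batch=10,
    batch_type="local",   # "slurm" additionally requires a SlurmConfig, otherwise a ValueError is raised
)

Execution itself is driven by the Runner base class, documented below.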

Input

Bases: ABC

Abstract base class for task inputs. DO NOT INSTANTIATE DIRECTLY. Most commonly used Input types are provided but if you want to define a new one, subclass this and implement validate() and get_value().

Source code in zipstrain/src/zipstrain/task_manager.py (lines 131-146)
class Input(ABC):
    """Abstract base class for task inputs. DO NOT INSTANTIATE DIRECTLY.
    Most commonly used Input types are provided but if you want to define a new one,
    subclass this and implement validate() and get_value().
    """
    def __init__(self, value: str | int) -> None:
        self.value = value
        self.validate()

    @abstractmethod
    def validate(self) -> None:
        ...

    @abstractmethod
    def get_value(self) -> str | int:
        ...
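
For illustration, a new input type only needs validate() and get_value(). The FloatInput below is hypothetical and not part of zipstrain; it simply mirrors the pattern used by IntInput (assuming the usual import path):

from zipstrain.task_manager import Input

class FloatInput(Input):
    """Hypothetical input type holding a float value."""

    def validate(self) -> None:
        # Input.__init__ stores the value and calls validate() immediately.
        if not isinstance(self.value, float):
            raise ValueError(f"Input value {self.value!r} is not a float.")

    def get_value(self) -> str:
        # Inputs are substituted into command templates as strings.
        return str(self.value)

threshold = FloatInput(0.95)
threshold.get_value()   # "0.95"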

IntInput

Bases: Input

This is used when the input is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 173-184)
class IntInput(Input):
    """This is used when the input is an integer."""
    def validate(self) -> None:
        """Validate that the input value is an integer."""
        if not isinstance(self.value, int):
            raise ValueError(f"Input value {self.value!r} is not an integer.")

    def get_value(self) -> str:
        """
        Returns the integer value as a string.
        """
        return str(self.value)
get_value()

Returns the integer value as a string.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 180-184)
def get_value(self) -> str:
    """
    Returns the integer value as a string.
    """
    return str(self.value)
validate()

Validate that the input value is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 175-178)
def validate(self) -> None:
    """Validate that the input value is an integer."""
    if not isinstance(self.value, int):
        raise ValueError(f"Input value {self.value!r} is not an integer.")

IntOutput

Bases: Output

This is used when the output is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 268-278)
class IntOutput(Output):
    """This is used when the output is an integer."""
    def ready(self) -> bool:
        """Check if the output value is an integer."""
        if isinstance(self._value, int):
            return True
        elif self._value is not None:
            raise ValueError(f"Output value for task {self.task.id} is not an integer.")
        else:
            return False
        return False
ready()

Check if the output value is an integer.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 270-278)
def ready(self) -> bool:
    """Check if the output value is an integer."""
    if isinstance(self._value, int):
        return True
    elif self._value is not None:
        raise ValueError(f"Output value for task {self.task.id} is not an integer.")
    else:
        return False
    return False

LocalBatch

Bases: Batch

Batch that runs tasks locally in a single shell script.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 715-795)
class LocalBatch(Batch):
    """Batch that runs tasks locally in a single shell script."""
    TEMPLATE_CMD = "#!/bin/bash\n"

    def __init__(self, tasks, id, run_dir, expected_outputs) -> None:
        super().__init__(tasks, id, run_dir, expected_outputs)
        self._script = self.TEMPLATE_CMD + "\nset -o pipefail\n"
        self._proc: asyncio.subprocess.Process | None = None 


    async def run(self) -> None:
        """This method runs all tasks in the batch locally by creating a shell script and executing it."""
        if self.status != Status.SUCCESS and self.status != Status.FAILED.value:
            self.batch_dir.mkdir(parents=True, exist_ok=True)
            self._status = Status.RUNNING.value


            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

            for task in self.tasks:
                if task.status == Status.NOT_STARTED.value:
                    task.task_dir.mkdir(parents=True, exist_ok=True)  # Create task directory
                    await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)

            script_path = self.batch_dir / f"{self.id}.sh" # Path to the shell script for the batch

            script = self._script
            for task in self.tasks:
                if task.status == Status.NOT_STARTED.value or task.status == Status.FAILED.value:
                    script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"

            await write_file(script_path, script, self.file_semaphore)

            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

            self._proc = await asyncio.create_subprocess_exec(
                "bash", f"{self.id}.sh",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=self.batch_dir,
            )
            try:
                out_bytes, err_bytes = await self._proc.communicate()
            except asyncio.CancelledError:
                if self._proc and self._proc.returncode is None:
                    self._proc.terminate()
                raise

            await write_file(self.batch_dir / f"{self.id}.out", out_bytes.decode(), self.file_semaphore)
            await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(), self.file_semaphore)

            if self._proc.returncode == 0 and self.outputs_ready():
                self.cleanup()
                self._status = Status.SUCCESS.value
                await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
            else:
                self._status = Status.FAILED.value
                await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

        elif self.status == Status.SUCCESS.value and self.outputs_ready():
            self._status = Status.SUCCESS.value
        else:
            self._status = Status.FAILED.value

    def _parse_job_id(self, sbatch_output):
        return super()._parse_job_id(sbatch_output)

    def cleanup(self) -> None:
        super().cleanup()

    async def cancel(self) -> None:
        """Cancels the local batch by terminating the subprocess if it's running."""
        if self._proc and self._proc.returncode is None:
            self._proc.terminate()
            try:
                await asyncio.wait_for(self._proc.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                self._proc.kill()
                await self._proc.wait()
            self._status = Status.FAILED.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
cancel() async

Cancels the local batch by terminating the subprocess if it's running.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 785-795)
async def cancel(self) -> None:
    """Cancels the local batch by terminating the subprocess if it's running."""
    if self._proc and self._proc.returncode is None:
        self._proc.terminate()
        try:
            await asyncio.wait_for(self._proc.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            self._proc.kill()
            await self._proc.wait()
        self._status = Status.FAILED.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
run() async

This method runs all tasks in the batch locally by creating a shell script and executing it.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 725-777)
async def run(self) -> None:
    """This method runs all tasks in the batch locally by creating a shell script and executing it."""
    if self.status != Status.SUCCESS and self.status != Status.FAILED.value:
        self.batch_dir.mkdir(parents=True, exist_ok=True)
        self._status = Status.RUNNING.value


        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

        for task in self.tasks:
            if task.status == Status.NOT_STARTED.value:
                task.task_dir.mkdir(parents=True, exist_ok=True)  # Create task directory
                await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)

        script_path = self.batch_dir / f"{self.id}.sh" # Path to the shell script for the batch

        script = self._script
        for task in self.tasks:
            if task.status == Status.NOT_STARTED.value or task.status == Status.FAILED.value:
                script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"

        await write_file(script_path, script, self.file_semaphore)

        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

        self._proc = await asyncio.create_subprocess_exec(
            "bash", f"{self.id}.sh",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=self.batch_dir,
        )
        try:
            out_bytes, err_bytes = await self._proc.communicate()
        except asyncio.CancelledError:
            if self._proc and self._proc.returncode is None:
                self._proc.terminate()
            raise

        await write_file(self.batch_dir / f"{self.id}.out", out_bytes.decode(), self.file_semaphore)
        await write_file(self.batch_dir / f"{self.id}.err", err_bytes.decode(), self.file_semaphore)

        if self._proc.returncode == 0 and self.outputs_ready():
            self.cleanup()
            self._status = Status.SUCCESS.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
        else:
            self._status = Status.FAILED.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

    elif self.status == Status.SUCCESS.value and self.outputs_ready():
        self._status = Status.SUCCESS.value
    else:
        self._status = Status.FAILED.value

Output

Bases: ABC

Abstract base class for task outputs. DO NOT INSTANTIATE DIRECTLY. Most commonly used Output types are provided but if you want to define a new one, subclass this and implement ready(). This method is used to check if the output is ready/valid after task completion.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 187-211)
class Output(ABC):
    """Abstract base class for task outputs. DO NOT INSTANTIATE DIRECTLY.
    Most commonly used Output types are provided but if you want to define a new one,
    subclass this and implement ready().
    This method is used to check if the output is ready/valid after task completion.
    """
    def __init__(self) -> None:
        self._value = None ## Will be set by the task when it completes
        self.task = None  ## Will be set when the output is registered to a task

    @property
    def value(self):
        return self._value

    @abstractmethod
    def ready(self) -> bool:
        ...

    def register_task(self, task: Task) -> None:
        """Registers the task that produces this output. In most cases, you won't need to override this.

        Args:
        task (Task): The task that produces this output.
        """
        self.task = task
register_task(task)

Registers the task that produces this output. In most cases, you won't need to override this.

Args: task (Task): The task that produces this output.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 205-211)
def register_task(self, task: Task) -> None:
    """Registers the task that produces this output. In most cases, you won't need to override this.

    Args:
    task (Task): The task that produces this output.
    """
    self.task = task
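
As with Input, a custom output only needs ready(); outputs tied to files usually also override register_task() to resolve paths against the task directory, as FileOutput does. The DirectoryOutput below is hypothetical and not part of zipstrain:

from zipstrain.task_manager import Output, Task

class DirectoryOutput(Output):
    """Hypothetical output that is ready once a directory exists in the task directory."""

    def __init__(self, expected_dir: str) -> None:
        super().__init__()
        self._expected_dir_name = expected_dir

    def register_task(self, task: Task) -> None:
        # Resolve the expected directory relative to the owning task's directory.
        super().register_task(task)
        self.expected_dir = self.task.task_dir / self._expected_dir_name

    def ready(self) -> bool:
        return self.expected_dir.exists()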

PrepareCompareGenomeRunOutputs

Bases: Task

A Task that prepares the final output by merging all parquet files after all genome comparisons are done.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1531-1543)
class PrepareCompareGenomeRunOutputs(Task):
    """A Task that prepares the final output by merging all parquet files after all genome comparisons are done."""
    TEMPLATE_CMD="""
    mkdir -p <output-dir>/comps
    find "$(pwd)" -type f -name "Merged_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/comps/
    zipstrain utilities merge_parquet --input-dir <output-dir>/comps --output-file <output-dir>/all_comparisons.parquet
    rm -rf <output-dir>/comps
    """

    @property
    def pre_run(self) -> str:
        """Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs."""
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self._batch_obj._runner_obj.run_dir.absolute()}"
pre_run property

Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs.

PrepareGeneCompareRunOutputs

Bases: Task

A Task that prepares the final output by merging all gene comparison parquet files after all gene comparisons are done.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 1892-1904)
class PrepareGeneCompareRunOutputs(Task):
    """A Task that prepares the final output by merging all gene comparison parquet files after all gene comparisons are done."""
    TEMPLATE_CMD="""
    mkdir -p <output-dir>/gene_comps
    find "$(pwd)" -type f -name "Merged_gene_batch_*.parquet" -print0 | xargs -0 -I {} ln -s {} <output-dir>/gene_comps/
    zipstrain utilities merge_parquet --input-dir <output-dir>/gene_comps --output-file <output-dir>/all_gene_comparisons.parquet
    rm -rf <output-dir>/gene_comps
    """

    @property
    def pre_run(self) -> str:
        """Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs."""
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self._batch_obj._runner_obj.run_dir.absolute()}"
pre_run property

Sets the task status to RUNNING and changes directory to the runner's run directory since this task may need to access multiple batch outputs.

ProfileBamTask

Bases: Task

A Task that generates an mpileup file and a genome breadth file in parquet format for a given BAM file using the zipstrain profile profile-single command. The inputs to this task include:

- bam-file: The input BAM file to be profiled.

- bed-file: The BED file specifying the regions to profile.

- sample-name: The name of the sample being processed.

- gene-range-table: A BED file specifying the gene ranges for the sample.

- num-threads: The number of threads to use for processing.

- genome-length-file: A file containing the lengths of the genomes in the reference fasta.

- stb-file: The STB file used for profiling.

Parameters:

- id (str): Unique identifier for the task. Required.
- inputs (dict[str, Input]): Dictionary of input parameters for the task. Required.
- expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task. Required.
- engine (Engine): Container engine to wrap the command. Required.
Source code in zipstrain/src/zipstrain/task_manager.py (lines 1441-1482)
class ProfileBamTask(Task):
    """A Task that generates a mpileup file and genome breadth file in parquet format for a given BAM file using the fast_profile profile_bam command.
    The inputs to this task includes:

        - bam-file: The input BAM file to be profiled.

        - bed-file: The BED file specifying the regions to profile.

        - sample-name: The name of the sample being processed.

        - gene-range-table: A BED file specifying the gene ranges for the sample.

        - num-threads: The number of threads to use for processing.

        - genome-length-file: A file containing the lengths of the genomes in the reference fasta.

        - stb-file: The STB file used for profiling.

    Args:
        id (str): Unique identifier for the task.
        inputs (dict[str, Input]): Dictionary of input parameters for the task.
        expected_outputs (dict[str, Output]): Dictionary of expected outputs for the task.
        engine (Engine): Container engine to wrap the command."""

    TEMPLATE_CMD="""
    ln -s <bam-file> input.bam
    ln -s <bed-file> bed_file.bed
    ln -s <gene-range-table> gene-range-table.bed
    samtools index <bam-file>
    zipstrain profile profile-single --bam-file input.bam \
    --bed-file bed_file.bed \
    --gene-range-table gene-range-table.bed \
    --num-workers <num-workers> \
    --output-dir .
    mv input.bam.parquet <sample-name>.parquet
    samtools idxstats <bam-file> |  awk '$3 > 0 {print $1}' > <sample-name>.parquet.scaffolds
    zipstrain utilities genome_breadth_matrix --profile <sample-name>.parquet \
        --genome-length <genome-length-file> \
        --stb <stb-file> \
        --min-cov <breadth-min-cov> \
        --output-file <sample-name>_breadth.parquet
    """

ProfileRunner

Bases: Runner

Creates and schedules batches of ProfileBamTask tasks using either local or Slurm batches.

Parameters:

- run_dir (str | Path): Directory where the runner will operate. Required.
- task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks. Required.
- container_engine (Engine): An instance of Engine to wrap task commands. Required.
- max_concurrent_batches (int): Maximum number of batches to run concurrently. Default: 1.
- poll_interval (float): Time interval in seconds to poll for batch status updates. Default: 1.0.
- tasks_per_batch (int): Number of tasks to include in each batch. Default: 10.
- batch_type (str): Type of batch to use ("local" or "slurm"). Default: "local".
- slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type is "slurm". Default: None.
Source code in zipstrain/src/zipstrain/task_manager.py (lines 1184-1283)
class ProfileRunner(Runner):
    """
    Creates and schedules batches of ProfileBamTask tasks using either local or Slurm batches.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.
    """
    def __init__(
        self,
        run_dir: str | pathlib.Path,
        task_generator: TaskGenerator,
        container_engine: Engine,
        max_concurrent_batches: int = 1,
        poll_interval: float = 1.0,
        tasks_per_batch: int = 10,
        batch_type: str = "local",
        slurm_config: SlurmConfig | None = None,
    ) -> None:
        if batch_type == "slurm":
            if slurm_config is None:
                raise ValueError("Slurm config must be provided for slurm batch type.")
            batch_factory = SlurmBatch
            final_batch_factory = None
        else:
            batch_factory = LocalBatch
            final_batch_factory = None

        super().__init__(
            run_dir=run_dir,
            task_generator=task_generator,
            container_engine=container_engine,
            batch_factory=batch_factory,
            final_batch_factory=final_batch_factory,
            max_concurrent_batches=max_concurrent_batches,
            poll_interval=poll_interval,
            tasks_per_batch=tasks_per_batch,
            batch_type=batch_type,
            slurm_config=slurm_config,
        )

    async def _batcher(self):
        """
        Defines the batcher coroutine that collects tasks from the tasks_queue, groups them into batches,
        and puts the batches into the batches_queue.
        """
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                if buffer:
                    batch_id = f"batch_{self._batch_counter}"
                    self._batch_counter += 1
                    if self.batch_type == "slurm":
                        batch = self.batch_factory(
                            tasks=buffer,
                            id=batch_id,
                            run_dir=self.run_dir,
                            expected_outputs=[],
                            slurm_config=self.slurm_config,
                        )
                    else:
                        batch = self.batch_factory(
                            tasks=buffer,
                            id=batch_id,
                            run_dir=self.run_dir,
                            expected_outputs=[],
                        )
                    await self.batches_queue.put(batch)
                await self.batches_queue.put(None)
                self._batcher_done = True
                break
            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                batch_id = f"batch_{self._batch_counter}"
                self._batch_counter += 1
                if self.batch_type == "slurm":
                    batch = self.batch_factory(
                        tasks=buffer,
                        id=batch_id,
                        run_dir=self.run_dir,
                        expected_outputs=[],
                        slurm_config=self.slurm_config,
                    )
                else:
                    batch = self.batch_factory(
                        tasks=buffer,
                        id=batch_id,
                        run_dir=self.run_dir,
                        expected_outputs=[],
                    )
                await self.batches_queue.put(batch)
                buffer = []

ProfileTaskGenerator

Bases: TaskGenerator

This TaskGenerator generates ProfileBamTask objects from a polars DataFrame. Each task profiles a BAM file.

Source code in zipstrain/src/zipstrain/task_manager.py (lines 472-534)
class ProfileTaskGenerator(TaskGenerator):
    """This TaskGenerator generates FastProfileTask objects from a polars DataFrame. Each task profiles a BAM file."""
    def __init__(
        self,
        data: pl.LazyFrame,
        yield_size: int,
        container_engine: Engine,
        stb_file: str,
        profile_bed_file: str,
        gene_range_file: str,
        genome_length_file: str,
        num_procs: int = 4,
        breadth_min_cov: int = 1,
    ) -> None:
        super().__init__(data, yield_size)
        self.stb_file = pathlib.Path(stb_file)
        self.profile_bed_file = pathlib.Path(profile_bed_file)
        self.gene_range_file = pathlib.Path(gene_range_file)
        self.genome_length_file = pathlib.Path(genome_length_file)
        self.num_procs = num_procs
        self.breadth_min_cov = breadth_min_cov
        self.engine = container_engine
        if type(self.data) is not pl.LazyFrame:
            raise ValueError("data must be a polars LazyFrame.")
        for path_attr in [
            self.stb_file,
            self.profile_bed_file,
            self.gene_range_file,
            self.genome_length_file,
        ]:
            if not path_attr.exists():
                raise FileNotFoundError(f"File {path_attr} does not exist.")

    def get_total_tasks(self) -> int:
        """Returns total number of profiles to be generated."""
        return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

    async def generate_tasks(self) -> list[Task]:
        """Yields lists of FastProfileTask objects based on the data in batches of yield_size. This method yields control back to the event loop
        while polars is collecting data to avoid blocking.
        """
        for offset in range(0, self._total_tasks, self.yield_size):
            batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
            tasks = []
            for row in batch_df.iter_rows(named=True):
                inputs = {
                    "bam-file": FileInput(row["bamfile"]),
                    "sample-name": StringInput(row["sample_name"]),
                    "stb-file": FileInput(self.stb_file),
                    "bed-file": FileInput(self.profile_bed_file),
                    "gene-range-table": FileInput(self.gene_range_file),
                    "genome-length-file": FileInput(self.genome_length_file),
                    "num-threads": IntInput(self.num_procs),
                    "breadth-min-cov": IntInput(self.breadth_min_cov),
                }
                expected_outputs = {
                    "profile": FileOutput(row["sample_name"] + ".parquet"),
                    "breadth": FileOutput(row["sample_name"] + "_breadth.parquet"),
                    "scaffold": FileOutput(row["sample_name"] + ".parquet.scaffolds"),
                }
                task = ProfileBamTask(id=row["sample_name"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
                tasks.append(task)
            yield tasks
generate_tasks() async

Yields lists of FastProfileTask objects based on the data in batches of yield_size. This method yields control back to the event loop while polars is collecting data to avoid blocking.

Source code in zipstrain/src/zipstrain/task_manager.py
async def generate_tasks(self) -> list[Task]:
    """Yields lists of FastProfileTask objects based on the data in batches of yield_size. This method yields control back to the event loop
    while polars is collecting data to avoid blocking.
    """
    for offset in range(0, self._total_tasks, self.yield_size):
        batch_df = await self.data.slice(offset, self.yield_size).collect_async(engine="streaming")
        tasks = []
        for row in batch_df.iter_rows(named=True):
            inputs = {
                "bam-file": FileInput(row["bamfile"]),
                "sample-name": StringInput(row["sample_name"]),
                "stb-file": FileInput(self.stb_file),
                "bed-file": FileInput(self.profile_bed_file),
                "gene-range-table": FileInput(self.gene_range_file),
                "genome-length-file": FileInput(self.genome_length_file),
                "num-threads": IntInput(self.num_procs),
                "breadth-min-cov": IntInput(self.breadth_min_cov),
            }
            expected_outputs = {
                "profile": FileOutput(row["sample_name"] + ".parquet"),
                "breadth": FileOutput(row["sample_name"] + "_breadth.parquet"),
                "scaffold": FileOutput(row["sample_name"] + ".parquet.scaffolds"),
            }
            task = ProfileBamTask(id=row["sample_name"], inputs=inputs, expected_outputs=expected_outputs, engine=self.engine)
            tasks.append(task)
        yield tasks
get_total_tasks()

Returns total number of profiles to be generated.

Source code in zipstrain/src/zipstrain/task_manager.py
def get_total_tasks(self) -> int:
    """Returns total number of profiles to be generated."""
    return self.data.select(size=pl.len()).collect(engine="streaming")["size"][0]

Runner

Bases: ABC

Base Runner class to manage task generation, batching, and execution.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
task_generator TaskGenerator

An instance of TaskGenerator to produce tasks.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
batch_factory Batch

The class that creates Batch instances. It should be a subclass of Batch or its subclasses.

required
final_batch_factory Batch

A callable that creates the final Batch instance.

required
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 1.0.

1.0
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
batch_type str

Type of batch to use ("local" or "slurm"). Default is "local".

'local'
slurm_config SlurmConfig | None

Configuration for Slurm batches if batch_type is "slurm". Default is None.

None
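
Runner is abstract: concrete runners implement the _batcher coroutine, which drains tasks_queue into Batch objects and pushes them onto batches_queue (the excerpt at the top of this page follows the same pattern, including the Slurm branch). Below is a minimal, illustrative _batcher for a hypothetical subclass; it omits the Slurm branch and is not the packaged implementation.

class MinimalRunner(Runner):
    async def _batcher(self):
        buffer: list[Task] = []
        while True:
            task = await self.tasks_queue.get()
            if task is None:
                # Producer sentinel: flush any remaining tasks, signal the workers, stop.
                if buffer:
                    batch_id = f"batch_{self._batch_counter}"
                    self._batch_counter += 1
                    await self.batches_queue.put(
                        self.batch_factory(tasks=buffer, id=batch_id, run_dir=self.run_dir, expected_outputs=[])
                    )
                await self.batches_queue.put(None)
                self._batcher_done = True
                break
            buffer.append(task)
            if len(buffer) == self.tasks_per_batch:
                batch_id = f"batch_{self._batch_counter}"
                self._batch_counter += 1
                await self.batches_queue.put(
                    self.batch_factory(tasks=buffer, id=batch_id, run_dir=self.run_dir, expected_outputs=[])
                )
                buffer = []
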
Source code in zipstrain/src/zipstrain/task_manager.py
class Runner(ABC):
    """Base Runner class to manage task generation, batching, and execution.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        task_generator (TaskGenerator): An instance of TaskGenerator to produce tasks.
        container_engine (Engine): An instance of Engine to wrap task commands.
        batch_factory (Batch): The class that creates Batch instances. It should be a subclass of Batch or its subclasses.
        final_batch_factory (Batch): A callable that creates the final Batch instance.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 1.0.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        batch_type (str): Type of batch to use ("local" or "slurm"). Default is "local".
        slurm_config (SlurmConfig | None): Configuration for Slurm batches if batch_type
            is "slurm". Default is None.

    """
    TERMINAL_BATCH_STATES = {Status.SUCCESS.value, Status.FAILED.value}
    def __init__(self,
                    run_dir: str | pathlib.Path,
                    task_generator: TaskGenerator,
                    container_engine: Engine,
                    batch_factory: Batch,
                    final_batch_factory: Batch,
                    max_concurrent_batches: int = 1,
                    poll_interval: float = 1.0,
                    tasks_per_batch: int = 10,
                    batch_type: str = "local",
                    slurm_config: SlurmConfig | None = None,
                    max_retries: int = 3,
                    ) -> None:
        self.run_dir = pathlib.Path(run_dir)
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.task_generator = task_generator
        self.container_engine = container_engine
        self.max_concurrent_batches = max_concurrent_batches
        self.poll_interval = poll_interval
        self.tasks_per_batch = tasks_per_batch
        self.batch_type = batch_type
        self.slurm_config = slurm_config
        self.tasks_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches * tasks_per_batch)
        self.batches_queue: asyncio.Queue = asyncio.Queue(maxsize=2 * max_concurrent_batches)
        self._finished_batches_count = 0
        self._success_batches_count = 0
        self._produced_tasks_count = 0
        self._active_batches: list[Batch] = []
        self._batch_counter = 0
        self._batcher_done = False
        self._final_batch_created = False
        self.batch_factory = batch_factory
        self.final_batch_factory = final_batch_factory
        self._failed_batches_count = 0
        self.max_retries = max_retries
        self.total_expected_tasks = self.task_generator.get_total_tasks()
        self.total_expected_batches = (self.total_expected_tasks + tasks_per_batch - 1) // tasks_per_batch
        self._shutdown_event = asyncio.Event()
        self._shutdown_initiated = False

    async def _refill_tasks(self):
        """Repeatedly call task_generator until it returns an empty list. This feeds tasks into the tasks_queue and waits for the queue to have space if it's full in an async manner."""
        async for tasks in self.task_generator.generate_tasks():
            for task in tasks:
                await self.tasks_queue.put(task)
                self._produced_tasks_count += 1
        await self.tasks_queue.put(None)

    @abstractmethod
    async def _batcher(self):
        ...

    def _create_final_batch(self) -> Batch|None:
        """Creates the final batch using the final_batch_factory callable."""
        return None


    async def _shutdown(self):
        """Cancel all active batches and signal shutdown."""
        if self._shutdown_initiated:
            return
        self._shutdown_initiated = True
        console.print("[yellow]Shutdown requested. Cancelling active jobs...[/]")

        for batch in list(self._active_batches):
            try:
                await batch.cancel()  
            except Exception as e:
                console.print(f"[red]Error cancelling batch {batch.id}: {e}[/]")

        # Signal the main loop to stop
        self._shutdown_event.set()

    async def run(self):
        """
        Run the producer, batcher and worker coroutines and present a live UI while working.
        Runs the task generator to produce tasks, batches them using the batcher,
        and executes batches with up to [max_concurrent_batches] parallel workers.
        UI: displays an overall panel (produced/finished counts), active batch Progress bars,
        and system stats (CPU/RAM) using Rich Live to mirror the Runner presentation.

        """
        asyncio.create_task(self._batcher())
        asyncio.create_task(self._refill_tasks())
        semaphore = asyncio.Semaphore(self.max_concurrent_batches)
        file_semaphore = asyncio.Semaphore(20)
        async def run_batch(batch: Batch):
            async with semaphore:
                while batch.status != Status.SUCCESS.value and batch.retry_count < self.max_retries:
                    await batch.run()
                    if batch.status == Status.SUCCESS.value:
                        break
                    else:
                        batch.retry_count += 1

                self._finished_batches_count += 1

                if batch.status == Status.SUCCESS.value:
                    self._success_batches_count += 1

                elif batch.status == Status.FAILED.value:
                    self._failed_batches_count += 1

                if batch in self._active_batches:
                    self._active_batches.remove(batch)

        # Rich progress objects

        overall_progress = Progress(
            TextColumn(f"[bold white]{type(self).__name__}[/]"),
            BarColumn(),
            TextColumn("• {task.fields[produced_tasks]}/{task.fields[total_expected_tasks]} tasks produced"),
            TextColumn("• {task.fields[finished_batches]}/{task.fields[total_expected_batches]} batches finished • {task.fields[failed_batches]} batches failed"),
            TimeElapsedColumn(),
            expand=True,
        )
        overall_task = overall_progress.add_task("overall", produced_tasks=0, total_expected_tasks=self.total_expected_tasks, finished_batches=0, total_expected_batches=self.total_expected_batches, failed_batches=0)

        batch_progress = Progress(
            TextColumn("[bold white]{task.fields[batch_id]}[/]"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TextColumn("• {task.fields[status]}"),
            TimeElapsedColumn(),
            expand=True,
        )

        batch_to_progress_id: dict[Batch, int] = {}
        batch_task_totals: dict[Batch, int] = {}

        body = Panel(Group(
            Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
            Panel(overall_progress, title="Overall Progress"),
            Panel(batch_progress, title="Active Batches", height=10),
            Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
        ), border_style="magenta")

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._shutdown()))

        with Live(body, console=console, refresh_per_second=2) as live:
            while not self._shutdown_event.is_set():
                await self._update_statuses()
                if self._batcher_done and self.batches_queue.empty() and len(self._active_batches) == 0:
                    if not self._final_batch_created:
                        final_batch = self._create_final_batch()
                        if final_batch is not None:
                            await self.batches_queue.put(final_batch)
                            self._final_batch_created = True
                        else:
                            self._final_batch_created = True
                            break
                while len(self._active_batches) < self.max_concurrent_batches and not self.batches_queue.empty(): 
                    batch = await self.batches_queue.get()
                    if batch is not None:
                        batch._set_file_semaphore(file_semaphore)
                        self._active_batches.append(batch)
                        asyncio.create_task(run_batch(batch))
                # Update overall progress fields
                overall_progress.update(overall_task, produced_tasks=self._produced_tasks_count, finished_batches=self._finished_batches_count, failed_batches=self._failed_batches_count)  
                # Add newly queued batches into UI


                for batch in list(self._active_batches):
                    if batch not in batch_to_progress_id and batch.status not in self.TERMINAL_BATCH_STATES:
                        total = len(batch.tasks) if batch.tasks else 1
                        task_id = batch_progress.add_task("", total=total, batch_id=f"Batch {batch.id}", status=batch.status)
                        batch_to_progress_id[batch] = task_id
                        batch_task_totals[batch] = total
                # Remove finished batches from UI
                for batch, tid in list(batch_to_progress_id.items()):
                    if batch.status in self.TERMINAL_BATCH_STATES:
                        try:
                            batch_progress.remove_task(tid)
                        except Exception:
                            pass
                        del batch_to_progress_id[batch]
                        if batch in batch_task_totals:
                            del batch_task_totals[batch]
                # Update per-batch progress
                for batch, tid in batch_to_progress_id.items():
                    completed = sum(1 for t in batch.tasks if t.status in self.TERMINAL_BATCH_STATES)
                    total = batch_task_totals.get(batch, max(1, len(batch.tasks)))
                    batch_progress.update(tid, completed=completed, total=total, status=batch.status)
                # Update system panel
                body = Panel(Group(
                    Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
                    Panel(overall_progress, title="Overall Progress"),
                    Panel(batch_progress, title="Active Batches"),
                    Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
                ), border_style="magenta")
                live.update(body)
                await asyncio.sleep(self.poll_interval)

        # final UI summary
        console.clear()
        total_batches = self._batch_counter + (1 if self._final_batch_created and self.final_batch_factory is not None else 0)
        summary = Panel(
            f"[bold green]Run finished![/]\n\n{self._success_batches_count}/{total_batches} batches succeeded.\n\nProduced tasks: {self._produced_tasks_count}\nElapsed: (see time in UI)",
            expand=True,
            title="Summary",
            border_style="green",
        )
        console.print(summary)

    async def _update_statuses(self):
        await asyncio.gather(*[batch.update_status() for batch in self._active_batches if batch.status not in self.TERMINAL_BATCH_STATES])

    def _make_system_stats_panel(self):
        """helpers to create a system stats panel for the live UI."""
        def usage_bar(label: str, percent: float, color: str):
            p = Progress(
                TextColumn(f"[bold]{label}[/]"),
                BarColumn(bar_width=None, complete_style=color),
                TextColumn(f"{percent:.1f}%"),
                expand=True,
            )
            p.add_task("", total=100, completed=percent)
            return Panel(p, expand=True, width=30)

        cpu = psutil.cpu_percent(interval=None)
        ram = psutil.virtual_memory().percent
        cpu_panel = usage_bar("CPU", cpu, "cyan")
        ram_panel = usage_bar("RAM", ram, "magenta")
        return Columns([cpu_panel, ram_panel], expand=True, equal=True, align="center")
run() async

Run the producer, batcher and worker coroutines and present a live UI while working. Runs the task generator to produce tasks, batches them using the batcher, and executes batches with up to [max_concurrent_batches] parallel workers. UI: displays an overall panel (produced/finished counts), active batch Progress bars, and system stats (CPU/RAM) using Rich Live to mirror the Runner presentation.

Source code in zipstrain/src/zipstrain/task_manager.py
async def run(self):
    """
    Run the producer, batcher and worker coroutines and present a live UI while working.
    Runs the task generator to produce tasks, batches them using the batcher,
    and executes batches with up to [max_concurrent_batches] parallel workers.
    UI: displays an overall panel (produced/finished counts), active batch Progress bars,
    and system stats (CPU/RAM) using Rich Live to mirror the Runner presentation.

    """
    asyncio.create_task(self._batcher())
    asyncio.create_task(self._refill_tasks())
    semaphore = asyncio.Semaphore(self.max_concurrent_batches)
    file_semaphore = asyncio.Semaphore(20)
    async def run_batch(batch: Batch):
        async with semaphore:
            while batch.status != Status.SUCCESS.value and batch.retry_count < self.max_retries:
                await batch.run()
                if batch.status == Status.SUCCESS.value:
                    break
                else:
                    batch.retry_count += 1

            self._finished_batches_count += 1

            if batch.status == Status.SUCCESS.value:
                self._success_batches_count += 1

            elif batch.status == Status.FAILED.value:
                self._failed_batches_count += 1

            if batch in self._active_batches:
                self._active_batches.remove(batch)

    # Rich progress objects

    overall_progress = Progress(
        TextColumn(f"[bold white]{type(self).__name__}[/]"),
        BarColumn(),
        TextColumn("• {task.fields[produced_tasks]}/{task.fields[total_expected_tasks]} tasks produced"),
        TextColumn("• {task.fields[finished_batches]}/{task.fields[total_expected_batches]} batches finished • {task.fields[failed_batches]} batches failed"),
        TimeElapsedColumn(),
        expand=True,
    )
    overall_task = overall_progress.add_task("overall", produced_tasks=0, total_expected_tasks=self.total_expected_tasks, finished_batches=0, total_expected_batches=self.total_expected_batches, failed_batches=0)

    batch_progress = Progress(
        TextColumn("[bold white]{task.fields[batch_id]}[/]"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total}"),
        TextColumn("• {task.fields[status]}"),
        TimeElapsedColumn(),
        expand=True,
    )

    batch_to_progress_id: dict[Batch, int] = {}
    batch_task_totals: dict[Batch, int] = {}

    body = Panel(Group(
        Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
        Panel(overall_progress, title="Overall Progress"),
        Panel(batch_progress, title="Active Batches", height=10),
        Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
    ), border_style="magenta")

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._shutdown()))

    with Live(body, console=console, refresh_per_second=2) as live:
        while not self._shutdown_event.is_set():
            await self._update_statuses()
            if self._batcher_done and self.batches_queue.empty() and len(self._active_batches) == 0:
                if not self._final_batch_created:
                    final_batch = self._create_final_batch()
                    if final_batch is not None:
                        await self.batches_queue.put(final_batch)
                        self._final_batch_created = True
                    else:
                        self._final_batch_created = True
                        break
            while len(self._active_batches) < self.max_concurrent_batches and not self.batches_queue.empty(): 
                batch = await self.batches_queue.get()
                if batch is not None:
                    batch._set_file_semaphore(file_semaphore)
                    self._active_batches.append(batch)
                    asyncio.create_task(run_batch(batch))
            # Update overall progress fields
            overall_progress.update(overall_task, produced_tasks=self._produced_tasks_count, finished_batches=self._finished_batches_count, failed_batches=self._failed_batches_count)  
            # Add newly queued batches into UI


            for batch in list(self._active_batches):
                if batch not in batch_to_progress_id and batch.status not in self.TERMINAL_BATCH_STATES:
                    total = len(batch.tasks) if batch.tasks else 1
                    task_id = batch_progress.add_task("", total=total, batch_id=f"Batch {batch.id}", status=batch.status)
                    batch_to_progress_id[batch] = task_id
                    batch_task_totals[batch] = total
            # Remove finished batches from UI
            for batch, tid in list(batch_to_progress_id.items()):
                if batch.status in self.TERMINAL_BATCH_STATES:
                    try:
                        batch_progress.remove_task(tid)
                    except Exception:
                        pass
                    del batch_to_progress_id[batch]
                    if batch in batch_task_totals:
                        del batch_task_totals[batch]
            # Update per-batch progress
            for batch, tid in batch_to_progress_id.items():
                completed = sum(1 for t in batch.tasks if t.status in self.TERMINAL_BATCH_STATES)
                total = batch_task_totals.get(batch, max(1, len(batch.tasks)))
                batch_progress.update(tid, completed=completed, total=total, status=batch.status)
            # Update system panel
            body = Panel(Group(
                Align.center(f"[bold magenta]ZipStrain {type(self).__name__}[/]\n", vertical="middle"),
                Panel(overall_progress, title="Overall Progress"),
                Panel(batch_progress, title="Active Batches"),
                Panel(self._make_system_stats_panel(), title="System Stats", expand=True),
            ), border_style="magenta")
            live.update(body)
            await asyncio.sleep(self.poll_interval)

    # final UI summary
    console.clear()
    total_batches = self._batch_counter + (1 if self._final_batch_created and self.final_batch_factory is not None else 0)
    summary = Panel(
        f"[bold green]Run finished![/]\n\n{self._success_batches_count}/{total_batches} batches succeeded.\n\nProduced tasks: {self._produced_tasks_count}\nElapsed: (see time in UI)",
        expand=True,
        title="Summary",
        border_style="green",
    )
    console.print(summary)

SlurmBatch

Bases: Batch

Batch that submits tasks to a Slurm job scheduler.

Parameters:

Name Type Description Default
tasks list[Task]

List of Task objects to be included in the batch.

required
id str

Unique identifier for the batch.

required
run_dir Path

Directory where the batch will be executed.

required
expected_outputs list[Output]

List of expected outputs for the batch.

required
slurm_config SlurmConfig

Configuration for Slurm job submission. Refer to SlurmConfig class for details.

required
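
A minimal sketch of constructing a SlurmBatch, assuming tasks is a list of already-built Task objects and that sbatch/sacct are available on the submitting host (the constructor checks for them). In normal use, batches are created and driven by a Runner, which also hands each batch an internal file semaphore before awaiting run().

import pathlib

config = SlurmConfig(time="04:00:00", tasks=1, mem=16, additional_params={"cpus-per-task": "4"})
batch = SlurmBatch(
    tasks=tasks,                          # list[Task], assumed to exist
    id="batch_0",
    run_dir=pathlib.Path("runs/example"),
    expected_outputs=[],
    slurm_config=config,
)
# A Runner would then schedule `await batch.run()`, which writes the sbatch
# script, submits it, and polls sacct until the job reaches a terminal state.
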
Source code in zipstrain/src/zipstrain/task_manager.py
class SlurmBatch(Batch):
    """Batch that submits tasks to a Slurm job scheduler.

     Args:
        tasks (list[Task]): List of Task objects to be included in the batch.
        id (str): Unique identifier for the batch.
        run_dir (pathlib.Path): Directory where the batch will be executed.
        expected_outputs (list[Output]): List of expected outputs for the batch.
        slurm_config (SlurmConfig): Configuration for Slurm job submission. Refer to SlurmConfig class for details."""
    TEMPLATE_CMD = "#!/bin/bash\n"

    def __init__(self, tasks, id, run_dir, expected_outputs, slurm_config: SlurmConfig) -> None:
        super().__init__(tasks, id, run_dir, expected_outputs)
        self._check_slurm_works()
        self.slurm_config = slurm_config
        self._script = self.TEMPLATE_CMD + self.slurm_config.to_slurm_args() + "\nset -o pipefail\n"
        self._job_id = None

    def _check_slurm_works(self) -> None:
        """Checks if Slurm commands are available on the system."""
        try:
            subprocess.run(["sbatch", "--version"], capture_output=True, text=True, check=True)
            subprocess.run(["sacct", "--version"], capture_output=True, text=True, check=True)
        except:
            raise EnvironmentError("Slurm does not seem to be available or configured properly on this system.")

    async def cancel(self) -> None:
        """Cancel a running or submitted Slurm job."""
        if self._job_id:
            proc = await asyncio.create_subprocess_exec(
                "scancel", self._job_id,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.wait()

        self._status = Status.FAILED.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

    async def run(self) -> None:
        """This method submits the batch to Slurm by creating a batch script and using sbatch command. It also monitors the job status until completion.
        This method is unavoidably different from LocalBatch.run() because of the nature of Slurm job submission.
        """

        if self.status != Status.SUCCESS and self.status != Status.FAILED.value:
            self.batch_dir.mkdir(parents=True, exist_ok=True)
            self._status = Status.RUNNING.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
            # create task directories and initialize .status if needed

            for task in self.tasks:
                if task.status == Status.NOT_STARTED.value:
                    task.task_dir.mkdir(parents=True, exist_ok=True)
                    await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)
            # write the batch script (all tasks included)

            batch_path = self.batch_dir / f"{self.id}.batch"
            script=self._script
            for task in self.tasks:
                if task.status == Status.NOT_STARTED.value:
                    script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"

            await write_file(batch_path, script, self.file_semaphore)

            proc = await asyncio.create_subprocess_exec(
                "sbatch","--parsable", batch_path.name,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=str(self.batch_dir),
            )
            out_bytes, out_err = await proc.communicate()
            out = out_bytes.decode().strip() if out_bytes else ""
            if proc.returncode == 0:
                try:
                    self._job_id = self._parse_job_id(out)
                    self._status = Status.SUBMITTED.value
                    await self._wait_to_finish()
                except Exception:
                    self._status = Status.FAILED.value
            else:
                self._status = Status.FAILED.value

            if self._status == Status.SUCCESS.value and self.outputs_ready():
                self.cleanup()
                self._status = Status.SUCCESS.value
                await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
            else:
                self._status = Status.FAILED.value
                await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

        else:
            if self.status == Status.SUCCESS.value and self.outputs_ready():
                self._status = Status.SUCCESS.value
            else:
                self._status = Status.FAILED.value

    def _parse_job_id(self, sbatch_output: str) -> str:
        if match := re.search(r"(\d+)", sbatch_output):
            return match.group(1)
        else:
            raise ValueError("Could not parse job ID from sbatch output.")

    async def _wait_to_finish(self,sleep_duration:float=1.0):
        while self.status not in (Status.SUCCESS.value, Status.FAILED.value):
            await self.update_status()
            await asyncio.sleep(sleep_duration)

    async def update_status(self):
        if self._job_id is None:
            self._status=Status.NOT_STARTED.value
        else:
            await self._collect_task_status()
            out= await asyncio.create_subprocess_exec(
                "sacct", "-j", self._job_id, "--format=State", "--noheader","--allocations",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            out_bytes, _ = await out.communicate()
            if out_bytes:
                state = out_bytes.decode().strip()
                if state in ["FAILED", "CANCELLED", "TIMEOUT"]:
                    self._status = Status.FAILED.value
                elif state=="RUNNING":
                    self._status = Status.RUNNING.value
                elif state in ["COMPLETED", "COMPLETING"]:
                    self._status = Status.SUCCESS.value
                else:
                    self._status = Status.PENDING.value
                await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
cancel() async

Cancel a running or submitted Slurm job.

Source code in zipstrain/src/zipstrain/task_manager.py
async def cancel(self) -> None:
    """Cancel a running or submitted Slurm job."""
    if self._job_id:
        proc = await asyncio.create_subprocess_exec(
            "scancel", self._job_id,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        await proc.wait()

    self._status = Status.FAILED.value
    await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
run() async

This method submits the batch to Slurm by creating a batch script and submitting it with the sbatch command. It also monitors the job status until completion. This method is unavoidably different from LocalBatch.run() because of the nature of Slurm job submission.

Source code in zipstrain/src/zipstrain/task_manager.py
async def run(self) -> None:
    """This method submits the batch to Slurm by creating a batch script and using sbatch command. It also monitors the job status until completion.
    This method is unavoidably different from LocalBatch.run() because of the nature of Slurm job submission.
    """

    if self.status != Status.SUCCESS and self.status != Status.FAILED.value:
        self.batch_dir.mkdir(parents=True, exist_ok=True)
        self._status = Status.RUNNING.value
        await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
        # create task directories and initialize .status if needed

        for task in self.tasks:
            if task.status == Status.NOT_STARTED.value:
                task.task_dir.mkdir(parents=True, exist_ok=True)
                await write_file(task.task_dir / ".status", Status.NOT_STARTED.value, self.file_semaphore)
        # write the batch script (all tasks included)

        batch_path = self.batch_dir / f"{self.id}.batch"
        script=self._script
        for task in self.tasks:
            if task.status == Status.NOT_STARTED.value:
                script += f"\n{task.pre_run}\n{task.command}\n{task.post_run}\n"

        await write_file(batch_path, script, self.file_semaphore)

        proc = await asyncio.create_subprocess_exec(
            "sbatch","--parsable", batch_path.name,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.batch_dir),
        )
        out_bytes, out_err = await proc.communicate()
        out = out_bytes.decode().strip() if out_bytes else ""
        if proc.returncode == 0:
            try:
                self._job_id = self._parse_job_id(out)
                self._status = Status.SUBMITTED.value
                await self._wait_to_finish()
            except Exception:
                self._status = Status.FAILED.value
        else:
            self._status = Status.FAILED.value

        if self._status == Status.SUCCESS.value and self.outputs_ready():
            self.cleanup()
            self._status = Status.SUCCESS.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)
        else:
            self._status = Status.FAILED.value
            await write_file(self.batch_dir / ".status", self._status, self.file_semaphore)

    else:
        if self.status == Status.SUCCESS.value and self.outputs_ready():
            self._status = Status.SUCCESS.value
        else:
            self._status = Status.FAILED.value

SlurmConfig

Bases: BaseModel

Configuration model for Slurm batch jobs.

Attributes:

Name Type Description
time str

Time limit for the job in HH:MM:SS format.

tasks int

Number of tasks.

mem int

Memory in GB.

additional_params dict

Additional SLURM parameters as key-value pairs.

NOTE: Additional parameters for Slurm should be provided in the additional_params dict in the form of {"param-name": "param-value"}, e.g., {"cpus-per-task": "4"} will result in the addition of "#SBATCH --cpus-per-task=4" to the sbatch script.
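
For example, with illustrative values, to_slurm_args() produces the following header:

config = SlurmConfig(
    time="12:00:00",
    tasks=1,
    mem=8,
    additional_params={"cpus-per-task": "4", "partition": "short"},
)
print(config.to_slurm_args())
# #SBATCH --time=12:00:00
# #SBATCH --ntasks=1
# #SBATCH --mem=8G
# #SBATCH --cpus-per-task=4
# #SBATCH --partition=short

The same configuration could be loaded from a JSON file (path is a placeholder) with SlurmConfig.from_json("slurm.json").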

Source code in zipstrain/src/zipstrain/task_manager.py
class SlurmConfig(BaseModel):
    """Configuration model for Slurm batch jobs.

    Attributes:
        time (str): Time limit for the job in HH:MM:SS format.
        tasks (int): Number of tasks.
        mem (int): Memory in GB.
        additional_params (dict): Additional SLURM parameters as key-value pairs.

    NOTE: Additional parameters for Slurm should be provided in the additional_params dict in the form
    of {"param-name": "param-value"}, e.g., {"cpus-per-task": "4"} will result in the addition of
    "#SBATCH --cpus-per-task=4" to the sbatch script.

    """
    time: str = Field(description="Time limit for the job.")
    tasks: int = Field(default=1, description="Number of tasks.")
    mem: int = Field(default=4, description="Memory in GB.")
    additional_params: dict[str, str] = Field(default_factory=dict, description="Additional SLURM parameters as key-value pairs.")

    @field_validator("time")
    def validate_time(cls, v):
        """Validate time format HH:MM:SS (H..HHH allowed)."""
        if not re.match(r"^\d{1,3}:\d{2}:\d{2}$", v):
            raise ValueError("Time must be in the format HH:MM:SS (H..HHH allowed)")
        return v

    def to_slurm_args(self) -> str:
        """Generates the slurm batch file header form the configuration object"""
        args = [
            f"#SBATCH --time={self.time}",
            f"#SBATCH --ntasks={self.tasks}",
            f"#SBATCH --mem={self.mem}G",
        ]
        for key, value in self.additional_params.items():
            args.append(f"#SBATCH --{key}={value}")
        return "\n".join(args)

    @classmethod
    def from_json(cls, json_path: str | pathlib.Path) -> SlurmConfig:
        """Load SlurmConfig from a JSON file."""
        path = pathlib.Path(json_path)
        if not path.exists():
            raise FileNotFoundError(f"Slurm config file {json_path} does not exist.")
        return cls.model_validate_json(path.read_text())
from_json(json_path) classmethod

Load SlurmConfig from a JSON file.

Source code in zipstrain/src/zipstrain/task_manager.py
@classmethod
def from_json(cls, json_path: str | pathlib.Path) -> SlurmConfig:
    """Load SlurmConfig from a JSON file."""
    path = pathlib.Path(json_path)
    if not path.exists():
        raise FileNotFoundError(f"Slurm config file {json_path} does not exist.")
    return cls.model_validate_json(path.read_text())
to_slurm_args()

Generates the Slurm batch file header from the configuration object.

Source code in zipstrain/src/zipstrain/task_manager.py
def to_slurm_args(self) -> str:
    """Generates the slurm batch file header form the configuration object"""
    args = [
        f"#SBATCH --time={self.time}",
        f"#SBATCH --ntasks={self.tasks}",
        f"#SBATCH --mem={self.mem}G",
    ]
    for key, value in self.additional_params.items():
        args.append(f"#SBATCH --{key}={value}")
    return "\n".join(args)
validate_time(v)

Validate time format HH:MM:SS (H..HHH allowed).

Source code in zipstrain/src/zipstrain/task_manager.py
@field_validator("time")
def validate_time(cls, v):
    """Validate time format HH:MM:SS (H..HHH allowed)."""
    if not re.match(r"^\d{1,3}:\d{2}:\d{2}$", v):
        raise ValueError("Time must be in the format HH:MM:SS (H..HHH allowed)")
    return v

Status

Bases: StrEnum

Enumeration of possible task and batch statuses.
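
Status is a StrEnum, so its members compare equal to their plain string values; this is why the rest of the module freely mixes Status.X and Status.X.value. A quick illustration:

print(Status.SUCCESS.value)            # "success"
print(Status.SUCCESS == "success")     # True: StrEnum members equal their strings
terminal = {Status.SUCCESS.value, Status.FAILED.value}
print("failed" in terminal)            # True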

Source code in zipstrain/src/zipstrain/task_manager.py
class Status(StrEnum):
    """Enumeration of possible task and batch statuses."""
    BATCH_NOT_ASSIGNED = "batch_not_assigned"
    NOT_STARTED = "not_started"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"
    SUBMITTED = "submitted"
    SUCCESS = "success"
    PENDING = "pending"

StringInput

Bases: Input

This is used when the input is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
class StringInput(Input):
    """This is used when the input is a string."""

    def validate(self) -> None:
        """Validate that the input value is a string."""
        if not isinstance(self.value, str):
            raise ValueError(f"Input value {self.value!r} is not a string.")

    def get_value(self) -> str:
        """Returns the string value."""
        return str(self.value)
get_value()

Returns the string value.

Source code in zipstrain/src/zipstrain/task_manager.py
def get_value(self) -> str:
    """Returns the string value."""
    return str(self.value)
validate()

Validate that the input value is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
def validate(self) -> None:
    """Validate that the input value is a string."""
    if not isinstance(self.value, str):
        raise ValueError(f"Input value {self.value!r} is not a string.")

StringOutput

Bases: Output

This is used when the output is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
class StringOutput(Output):
    """This is used when the output is a string."""
    def ready(self) -> bool:
        """Check if the output value is a string."""
        if isinstance(self._value, str):
            return True
        elif self._value is not None:
            raise ValueError(f"Output value for task {self.task.id} is not a string.")
        else:
            return False
ready()

Check if the output value is a string.

Source code in zipstrain/src/zipstrain/task_manager.py
def ready(self) -> bool:
    """Check if the output value is a string."""
    if isinstance(self._value, str):
        return True
    elif self._value is not None:
        raise ValueError(f"Output value for task {self.task.id} is not a string.")
    else:
        return False

Task

Bases: ABC

Abstract base class for tasks. DO NOT INSTANTIATE DIRECTLY. Any new task type should subclass this and implement the TEMPLATE_CMD class attribute. Inputs and expected outputs are specified using <> placeholders. As an example, if a task has an input file called "input-file" and an expected output file called "output-file", the TEMPLATE_CMD could be something like: TEMPLATE_CMD = "some_command --input <input-file> --output <output-file>". The outputs and inputs will be mapped to the command when map_io() is called later in the runtime.
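
A minimal, hypothetical subclass to make the placeholder convention concrete; WordCountTask, its command, and the file paths are illustrative and not part of ZipStrain, and engine stands for an existing Engine instance.

class WordCountTask(Task):
    """Hypothetical task: count the lines of an input file."""
    TEMPLATE_CMD = "wc -l <input-file> > <output-file>"

task = WordCountTask(
    id="sample_A",
    inputs={"input-file": FileInput("data/sample_A.txt")},
    expected_outputs={"output-file": FileOutput("sample_A.counts")},
    engine=engine,
)
# When the task is added to a Batch, the Batch calls task.map_io(), which replaces
# <input-file> and <output-file> with the resolved paths, and task.command wraps the
# mapped command with the container engine.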

Source code in zipstrain/src/zipstrain/task_manager.py
class Task(ABC):
    """Abstract base class for tasks. DO NOT INSTANTIATE DIRECTLY. Any new task type should subclass this
    and implement the TEMPLATE_CMD class attribute. Inputs and expected outputs are specified using <>.
    As an example, if a task has an input file called "input-file" and an expected output file called "output-file",
    the TEMPLATE_CMD could be something like:
        TEMPLATE_CMD = "some_command --input <input-file> --output <output-file>"
    the outputs and inputs will be mapped to the command when map_io() is called later in the runtime.
    """
    TEMPLATE_CMD = ""

    def __init__(
        self,
        id: str,
        inputs: dict[str, Input | Output],
        expected_outputs: dict[str, Output] ,
        engine: Engine,
        batch_obj: Batch | None = None,
        file_semaphore: asyncio.Semaphore | None = None
    ) -> None:
        self.id = id
        self.inputs = inputs
        self.expected_outputs = expected_outputs
        self._batch_obj = batch_obj
        self.engine = engine
        self._status = self._get_initial_status()
        self.file_semaphore = file_semaphore

    def map_io(self) -> None:
        """Maps inputs and expected outputs to the command template. Note that when this method is called,
        all of the inputs and outputs in the TEMPLATE_CMD must be defined in the inputs and expected_outputs dictionaries.
        However, this method is not called by the user directly. It is called by the Batch when the task is added to a batch.
        """
        cmd = self.TEMPLATE_CMD
        for key, value in self.inputs.items():
            cmd = cmd.replace(f"<{key}>", value.get_value())
        # if any placeholders remain, report them

        for handle, output in self.expected_outputs.items():
            cmd = cmd.replace(f"<{handle}>", str(output.expected_file.absolute()))
        remaining = re.findall(r"<\w+>", cmd)
        if remaining:
            raise ValueError(f"Not all inputs were mapped in task {self.id}. Remaining placeholders: {remaining}")
        self._command = cmd

    @property
    def batch_dir(self) -> pathlib.Path:
        """Returns the batch directory path. Raises an error if the task is not associated with any batch yet."""
        if self._batch_obj is None:
            raise ValueError(f"Task {self.id} is not associated with any batch yet.")
        return self._batch_obj.batch_dir

    @property
    def task_dir(self) -> pathlib.Path:
        """Returns the task directory path."""
        return self.batch_dir / self.id

    @property
    def command(self) -> str:
        """Returns the command to be executed, wrapped with the engine if applicable."""
        file_inputs = [v for v in self.inputs.values() if isinstance(v, FileInput)]
        return self.engine.wrap(self._command, file_inputs)

    @property
    def pre_run(self) -> str:
        """Does the necessary setup before running the task command. This should not be overridden by subclasses unless a task needs special setup like
        batch aggregation."""
        return f"echo {Status.RUNNING.value} > {self.task_dir.absolute()}/.status && cd {self.task_dir.absolute()}"

    @property
    def status(self) -> str:
        """Returns the current status of the task."""
        return self._status

    @property
    def post_run(self) -> str:
        """Does the necessary steps after running the task command. This should not be overridden by subclasses unless a task needs special teardown like
        batch aggregation."""
        return f"cd {self.batch_dir.absolute()} && echo {Status.DONE.value} > {self.task_dir.absolute()}/.status"

    async def get_status(self) -> str:
        """Asynchronously reads the task status from the .status file in the task directory."""
        status_path = self.task_dir / ".status"
        # read the status file if it exists
        if status_path.exists():
            raw = await read_file(status_path, self.file_semaphore)
            self._status = raw.strip()

            # if task reported 'done', check outputs to decide success/failure
            if self._status == Status.DONE.value:
                all_ready = True
                try:
                    for output in self.expected_outputs.values():
                        if not output.ready():
                            all_ready = False
                            break
                except Exception:
                    all_ready = False

                if all_ready:
                    self._status = Status.SUCCESS.value
                    await write_file(status_path, Status.SUCCESS.value, self.file_semaphore)
                else:
                    self._status = Status.FAILED.value
                    await write_file(status_path, Status.FAILED.value, self.file_semaphore)
                    raise ValueError(f"Task {self.id} reported done but outputs are not ready or invalid. {self.expected_outputs['output-file'].expected_file.absolute()}")

        return self._status

    def _get_initial_status(self) -> str:
        """Returns the initial status of the task based on the presence of the batch and task directories."""
        if self._batch_obj is None:
            return Status.BATCH_NOT_ASSIGNED.value
        if not self.task_dir.exists():
            return Status.NOT_STARTED.value
        status_file = self.task_dir / ".status"
        with open(status_file, mode="r") as f:
            status_as_written = f.read().strip()
        if status_as_written in (Status.DONE.value, Status.SUCCESS.value):
            all_ready = True
            try:
                for output in self.expected_outputs.values():
                    if not output.ready():
                        all_ready = False
                        break
            except Exception:
                all_ready = False

            if all_ready:
                return Status.SUCCESS.value
            else:
                return Status.FAILED.value
batch_dir property

Returns the batch directory path. Raises an error if the task is not associated with any batch yet.

command property

Returns the command to be executed, wrapped with the engine if applicable.

post_run property

Does the necessary steps after running the task command. This should not be overridden by subclasses unless a task needs special teardown like batch aggregation.

pre_run property

Does the necessary setup before running the task command. This should not be overridden by subclasses unless a task needs special setup like batch aggregation.

status property

Returns the current status of the task.

task_dir property

Returns the task directory path.

get_status() async

Asynchronously reads the task status from the .status file in the task directory.

Source code in zipstrain/src/zipstrain/task_manager.py
async def get_status(self) -> str:
    """Asynchronously reads the task status from the .status file in the task directory."""
    status_path = self.task_dir / ".status"
    # read the status file if it exists
    if status_path.exists():
        raw = await read_file(status_path, self.file_semaphore)
        self._status = raw.strip()

        # if task reported 'done', check outputs to decide success/failure
        if self._status == Status.DONE.value:
            all_ready = True
            try:
                for output in self.expected_outputs.values():
                    if not output.ready():
                        all_ready = False
                        break
            except Exception:
                all_ready = False

            if all_ready:
                self._status = Status.SUCCESS.value
                await write_file(status_path, Status.SUCCESS.value, self.file_semaphore)
            else:
                self._status = Status.FAILED.value
                await write_file(status_path, Status.FAILED.value, self.file_semaphore)
                raise ValueError(f"Task {self.id} reported done but outputs are not ready or invalid. {self.expected_outputs['output-file'].expected_file.absolute()}")

    return self._status
map_io()

Maps inputs and expected outputs to the command template. Note that when this method is called, all of the inputs and outputs in the TEMPLATE_CMD must be defined in the inputs and expected_outputs dictionaries. However, this method is not called by the user directly. It is called by the Batch when the task is added to a batch.

Source code in zipstrain/src/zipstrain/task_manager.py
def map_io(self) -> None:
    """Maps inputs and expected outputs to the command template. Note that when this method is called,
    all of the inputs and outputs in the TEMPLATE_CMD must be defined in the inputs and expected_outputs dictionaries.
    However, this method is not called by the user directly. It is called by the Batch when the task is added to a batch.
    """
    cmd = self.TEMPLATE_CMD
    for key, value in self.inputs.items():
        cmd = cmd.replace(f"<{key}>", value.get_value())
    # if any placeholders remain, report them

    for handle, output in self.expected_outputs.items():
        cmd = cmd.replace(f"<{handle}>", str(output.expected_file.absolute()))
    remaining = re.findall(r"<\w+>", cmd)
    if remaining:
        raise ValueError(f"Not all inputs were mapped in task {self.id}. Remaining placeholders: {remaining}")
    self._command = cmd

TaskGenerator

Bases: ABC

Abstract base class for task generators. DO NOT INSTANTIATE DIRECTLY. A subclass of this class should provide an async generator method called generate_tasks() that yields lists of Task objects in an async manner. Some important concepts:

  • generate_tasks() is an async generator that yields lists of Task objects.

  • yield_size determines how many tasks are generated and yielded at a time.

  • get_total_tasks() returns the total number of tasks that can be generated.
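
A minimal, illustrative subclass showing the contract; it simply wraps a plain Python list of pre-built Task objects rather than a polars DataFrame.

import asyncio

class ListTaskGenerator(TaskGenerator):
    """Hypothetical generator over an in-memory list of Task objects."""

    def get_total_tasks(self) -> int:
        return len(self.data)

    async def generate_tasks(self):
        # Yield tasks in chunks of yield_size, handing control back to the
        # event loop between chunks.
        for start in range(0, self._total_tasks, self.yield_size):
            yield self.data[start:start + self.yield_size]
            await asyncio.sleep(0)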

Source code in zipstrain/src/zipstrain/task_manager.py
class TaskGenerator(ABC):
    """Abstract base class for task generators. DO NOT INSTANTIATE DIRECTLY. A subclass of this class 
    should provide an async generator method called generate_tasks() that yields lists of Task objects in an async manner.
    Some important concepts:

    - generate_tasks() is an async generator that yields lists of Task objects.

    - yield_size determines how many tasks are generated and yielded at a time.

    - get_total_tasks() returns the total number of tasks that can be generated.

    """
    def __init__(self,
                 data,
                 yield_size: int,
                 ):
        self.data = data
        self.yield_size = yield_size
        self._total_tasks = self.get_total_tasks()

    @abstractmethod
    async def generate_tasks(self) -> list[Task]:
        pass

    @abstractmethod
    def get_total_tasks(self) -> int:
        pass

get_cpu_usage()

Returns the current CPU usage percentage.

Source code in zipstrain/src/zipstrain/task_manager.py
def get_cpu_usage():
    """Returns the current CPU usage percentage."""
    return psutil.cpu_percent(interval=0.1)

get_memory_usage()

Returns the current memory usage percentage.

Source code in zipstrain/src/zipstrain/task_manager.py
def get_memory_usage():
    """Returns the current memory usage percentage."""
    return psutil.virtual_memory().percent

lazy_run_compares(run_dir, container_engine, comps_db=None, tasks_per_batch=10, max_concurrent_batches=1, poll_interval=5.0, execution_mode='local', slurm_config=None, memory_mode='heavy', chrom_batch_size=10000, polars_engine='streaming')

A helper function to quickly set up and run a CompareRunner with given parameters.

Parameters:

Name Type Description Default
run_dir str | Path

Directory where the runner will operate.

required
container_engine Engine

An instance of Engine to wrap task commands.

required
comps_db GenomeComparisonDatabase | None

An instance of GenomeComparisonDatabase containing comparison data.

None
tasks_per_batch int

Number of tasks to include in each batch. Default is 10.

10
max_concurrent_batches int

Maximum number of batches to run concurrently. Default is 1.

1
poll_interval float

Time interval in seconds to poll for batch status updates. Default is 5.0.

5.0
execution_mode str

Execution mode, either "local" or "slurm". Default is "local".

'local'
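
A typical invocation might look like the sketch below; engine and comps_db are placeholders for an existing Engine and GenomeComparisonDatabase, and the run directory is illustrative.

lazy_run_compares(
    run_dir="runs/compare_all",
    container_engine=engine,
    comps_db=comps_db,
    tasks_per_batch=20,
    max_concurrent_batches=2,
    execution_mode="slurm",
    slurm_config=SlurmConfig(time="08:00:00", mem=16),
    memory_mode="heavy",
    polars_engine="streaming",
)

The helper builds the CompareTaskGenerator and CompareRunner itself and calls asyncio.run() internally, so it is invoked as a plain synchronous function.
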
Source code in zipstrain/src/zipstrain/task_manager.py
def lazy_run_compares(
    run_dir: str | pathlib.Path,
    container_engine: Engine,
    comps_db: database.GenomeComparisonDatabase|None = None,
    tasks_per_batch: int = 10,
    max_concurrent_batches: int = 1,
    poll_interval: float = 5.0,
    execution_mode: str = "local",
    slurm_config: SlurmConfig | None = None,
    memory_mode: str = "heavy",
    chrom_batch_size: int = 10000,
    polars_engine: str = "streaming"
) -> None:
    """A helper function to quickly set up and run a CompareRunner with given parameters.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        container_engine (Engine): An instance of Engine to wrap task commands.
        comps_db (GenomeComparisonDatabase | None): An instance of GenomeComparisonDatabase containing comparison data.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
        execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
        slurm_config (SlurmConfig | None): Slurm configuration used when execution_mode is "slurm". Default is None.
        memory_mode (str): Memory mode passed to the CompareTaskGenerator. Default is "heavy".
        chrom_batch_size (int): Chromosome batch size passed to the CompareTaskGenerator. Default is 10000.
        polars_engine (str): Polars engine to use. Default is "streaming".
    """
    task_generator = CompareTaskGenerator(
        data=comps_db.to_complete_input_table(),
        yield_size=tasks_per_batch,
        container_engine=container_engine,
        comp_config=comps_db.config,
        memory_mode=memory_mode,
        polars_engine=polars_engine,
        chrom_batch_size=chrom_batch_size,
    )
    if execution_mode=="local":
        batch_type="local"
    elif execution_mode=="slurm":
        batch_type="slurm"
    else:
        raise ValueError(f"Unknown execution mode: {execution_mode}")
    runner = CompareRunner(
        run_dir=pathlib.Path(run_dir),
        task_generator=task_generator,
        container_engine=container_engine,
        max_concurrent_batches=max_concurrent_batches,
        poll_interval=poll_interval,
        tasks_per_batch=tasks_per_batch,
        batch_type=batch_type,
        slurm_config=slurm_config,
    )
    asyncio.run(runner.run())
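
A typical call looks like the sketch below. It is illustrative only: engine and comps_db stand in for an Engine and a populated GenomeComparisonDatabase built elsewhere (see their own sections), the run directory is a hypothetical path, and the zipstrain.task_manager import path is assumed from the source location above.

from zipstrain.task_manager import lazy_run_compares

# `engine` and `comps_db` are assumed to be an Engine and a populated
# GenomeComparisonDatabase created elsewhere; placeholders are used here.
engine = ...      # your container Engine instance
comps_db = ...    # your GenomeComparisonDatabase instance

lazy_run_compares(
    run_dir="runs/compare_run1",   # hypothetical output directory
    container_engine=engine,
    comps_db=comps_db,
    tasks_per_batch=20,
    max_concurrent_batches=2,
    poll_interval=10.0,
    execution_mode="local",        # use "slurm" together with slurm_config on a cluster
)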

lazy_run_gene_compares(run_dir, container_engine, comps_db=None, tasks_per_batch=10, max_concurrent_batches=1, poll_interval=5.0, execution_mode='local', slurm_config=None, polars_engine='streaming', ani_method='popani')

A helper function to quickly set up and run a GeneCompareRunner with given parameters.

Parameters:

  • run_dir (str | Path): Directory where the runner will operate. Required.

  • container_engine (Engine): An instance of Engine to wrap task commands. Required.

  • comps_db (GeneComparisonDatabase | None): An instance of GeneComparisonDatabase containing gene-level comparison data. Default is None, but a database must be supplied for the run to proceed.

  • tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.

  • max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.

  • poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.

  • execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".

  • slurm_config (SlurmConfig | None): Slurm configuration used when execution_mode is "slurm". Default is None.

  • polars_engine (str): Polars engine to use. Default is "streaming".

  • ani_method (str): ANI calculation method to use. Default is "popani".
Source code in zipstrain/src/zipstrain/task_manager.py
def lazy_run_gene_compares(
    run_dir: str | pathlib.Path,
    container_engine: Engine,
    comps_db: database.GeneComparisonDatabase | None = None,
    tasks_per_batch: int = 10,
    max_concurrent_batches: int = 1,
    poll_interval: float = 5.0,
    execution_mode: str = "local",
    slurm_config: SlurmConfig | None = None,
    polars_engine: str = "streaming",
    ani_method: str = "popani"
) -> None:
    """A helper function to quickly set up and run a GeneCompareRunner with given parameters.

    Args:
        run_dir (str | pathlib.Path): Directory where the runner will operate.
        container_engine (Engine): An instance of Engine to wrap task commands.
        comps_db (GeneComparisonDatabase | None): An instance of GeneComparisonDatabase containing gene-level comparison data.
        tasks_per_batch (int): Number of tasks to include in each batch. Default is 10.
        max_concurrent_batches (int): Maximum number of batches to run concurrently. Default is 1.
        poll_interval (float): Time interval in seconds to poll for batch status updates. Default is 5.0.
        execution_mode (str): Execution mode, either "local" or "slurm". Default is "local".
        slurm_config (SlurmConfig | None): Slurm configuration used when execution_mode is "slurm". Default is None.
        polars_engine (str): Polars engine to use. Default is "streaming".
        ani_method (str): ANI calculation method to use. Default is "popani".
    """
    task_generator = GeneCompareTaskGenerator(
        data=comps_db.to_complete_input_table(),
        yield_size=tasks_per_batch,
        container_engine=container_engine,
        comp_config=comps_db.config,
        polars_engine=polars_engine,
        ani_method=ani_method,
    )
    if execution_mode=="local":
        batch_type="local"
    elif execution_mode=="slurm":
        batch_type="slurm"
    else:
        raise ValueError(f"Unknown execution mode: {execution_mode}")
    runner = GeneCompareRunner(
        run_dir=pathlib.Path(run_dir),
        task_generator=task_generator,
        container_engine=container_engine,
        max_concurrent_batches=max_concurrent_batches,
        poll_interval=poll_interval,
        tasks_per_batch=tasks_per_batch,
        batch_type=batch_type,
        slurm_config=slurm_config,
    )
    asyncio.run(runner.run())
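
Usage mirrors lazy_run_compares, except that the database is a GeneComparisonDatabase and an ANI method can be selected. Again a hedged sketch with placeholder objects constructed elsewhere and a hypothetical run directory:

from zipstrain.task_manager import lazy_run_gene_compares

# `engine` and `gene_comps_db` are assumed to be an Engine and a populated
# GeneComparisonDatabase created elsewhere; placeholders are used here.
engine = ...
gene_comps_db = ...

lazy_run_gene_compares(
    run_dir="runs/gene_compare_run1",  # hypothetical output directory
    container_engine=engine,
    comps_db=gene_comps_db,
    tasks_per_batch=10,
    execution_mode="local",
    ani_method="popani",               # default, as documented above
)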