"""Repository utilities and helper functions."""
import subprocess
import tomllib
from pathlib import Path
from collections import defaultdict
import typer
[docs]
def get_config_path():
    """Return the location of the per-user configuration file.

    Returns:
        Path: ``<home>/.config/dbx-python-cli/config.toml`` for the current user.
    """
    return Path.home().joinpath(".config", "dbx-python-cli", "config.toml")
[docs]
def get_default_config_path():
    """Return the path of the default config file bundled with the package.

    The file is ``config.toml`` inside the parent of the directory that
    contains this module.
    """
    module_dir = Path(__file__).parent
    return module_dir.parent.joinpath("config.toml")
[docs]
def get_config():
    """Load the CLI configuration as a dict.

    The user config file takes precedence; when it does not exist the default
    config shipped with the package is read instead. Only the first existing
    file is loaded — the two are not merged.

    Returns:
        dict: Parsed TOML data, or an empty dict when neither file exists.
    """
    for candidate in (get_config_path(), get_default_config_path()):
        if not candidate.exists():
            continue
        with open(candidate, "rb") as handle:
            return tomllib.load(handle)
    # Neither the user nor the default config is present.
    return {}
[docs]
def get_base_dir(config):
    """Return the directory under which repositories are cloned.

    Reads ``repo.base_dir`` (defaulting to ``~/repos``) and expands a leading
    ``~`` to the user's home directory.
    """
    repo_settings = config.get("repo", {})
    configured = repo_settings.get("base_dir", "~/repos")
    return Path(configured).expanduser()
[docs]
def is_flat_mode(config):
    """Report whether the flat repository layout is enabled.

    When ``repo.flat`` is set, repositories sit directly under the base
    directory instead of inside per-group subdirectories. Defaults to False.
    """
    repo_settings = config.get("repo", {})
    return repo_settings.get("flat", False)
[docs]
def get_group_dir(base_dir, group, flat=False):
    """Return the directory holding *group*'s repositories.

    In flat mode every group shares *base_dir* itself; otherwise each group
    has its own ``base_dir/<group>`` subdirectory.
    """
    if flat:
        return base_dir
    return base_dir / group
[docs]
def get_repo_dir(base_dir, group, repo_name, flat=False):
    """Return the directory where *repo_name* is (or would be) cloned.

    ``base_dir/<repo_name>`` in flat mode, ``base_dir/<group>/<repo_name>``
    in the grouped layout.
    """
    parent = base_dir if flat else base_dir / group
    return parent / repo_name
[docs]
def get_projects_dir(base_dir, flat=False):
    """Return the directory where Django projects live.

    Grouped layout: ``base_dir/projects``. Flat layout: *base_dir* itself.
    """
    if not flat:
        return base_dir / "projects"
    return base_dir
[docs]
def get_repo_groups(config):
    """Return the ``repo.groups`` table from *config*, or an empty dict if unset."""
    repo_settings = config.get("repo", {})
    return repo_settings.get("groups", {})
[docs]
def get_global_groups(config):
    """
    Return the names of the configured global groups.

    Repositories in a global group are installed into every other group's
    venv automatically when running ``dbx install -g <group>``.

    Returns:
        list: The ``repo.global_groups`` entries, or an empty list when unset.
    """
    repo_settings = config.get("repo", {})
    return repo_settings.get("global_groups", [])
[docs]
def get_install_dirs(config, group_name, repo_name):
    """
    Return the subdirectories to install for a repository.

    Some repositories keep their installable packages in subdirectories;
    those are listed under the group's ``install_dirs`` table.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'langchain')
        repo_name: Name of the repository (e.g., 'langchain-mongodb')

    Returns:
        list: Install subdirectories, or None when the package is at the
        repository root (or the group is unknown).
    """
    groups = get_repo_groups(config)
    if group_name in groups:
        per_repo = groups[group_name].get("install_dirs", {})
        return per_repo.get(repo_name)
    return None
[docs]
def get_build_commands(config, group_name, repo_name):
    """
    Return the pre-install build commands for a repository.

    Used for repositories that need a build step (e.g. a cmake build) before
    they can be installed; the commands come from the group's
    ``build_commands`` table.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'django')
        repo_name: Name of the repository (e.g., 'libmongocrypt')

    Returns:
        list: Shell commands to run, or None if no build step is configured.
    """
    groups = get_repo_groups(config)
    if group_name in groups:
        per_repo = groups[group_name].get("build_commands", {})
        return per_repo.get(repo_name)
    return None
[docs]
def get_test_runner(config, group_name, repo_name):
    """
    Return the custom test runner configured for a repository.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'django')
        repo_name: Name of the repository (e.g., 'django')

    Returns:
        str: Runner path/command from the group's ``test_runner`` table, or
        None meaning "use the default pytest".
    """
    groups = get_repo_groups(config)
    if group_name in groups:
        per_repo = groups[group_name].get("test_runner", {})
        return per_repo.get(repo_name)
    return None
[docs]
def get_evergreen_project_name(config, repo_name):
    """Return the Evergreen project name configured for a repository.

    Reads ``evergreen.<repo_name>.project_name``.

    Args:
        config: Configuration dictionary
        repo_name: Name of the repository

    Returns:
        str: Evergreen project name, or None if not configured.
    """
    evergreen_settings = config.get("evergreen", {})
    repo_settings = evergreen_settings.get(repo_name, {})
    return repo_settings.get("project_name")
[docs]
def get_install_groups(config, group_name, repo_name):
    """
    Return the dependency groups installed by default for a repository.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'langchain')
        repo_name: Name of the repository (e.g., 'langchain-mongodb')

    Returns:
        list: Entries from the group's ``install_groups`` table, or an empty
        list when none are configured or the group is unknown.
    """
    groups = get_repo_groups(config)
    if group_name in groups:
        per_repo = groups[group_name].get("install_groups", {})
        return per_repo.get(repo_name, [])
    return []
[docs]
def should_skip_install(config, group_name, repo_name):
    """
    Report whether a repository is excluded from automatic installation.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'django-3p')
        repo_name: Name of the repository (e.g., 'django')

    Returns:
        bool: True when *repo_name* is listed in the group's ``skip_install``
        entry, False otherwise (including for unknown groups).
    """
    groups = get_repo_groups(config)
    if group_name in groups:
        skipped = groups[group_name].get("skip_install", [])
        return repo_name in skipped
    return False
[docs]
def get_test_runner_args(config, group_name, repo_name):
    """
    Return the default arguments passed to a repository's custom test runner.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'django')
        repo_name: Name of the repository (e.g., 'django')

    Returns:
        list: Entries from the group's ``test_runner_args`` table, or an empty
        list when none are configured or the group is unknown.
    """
    groups = get_repo_groups(config)
    if group_name in groups:
        per_repo = groups[group_name].get("test_runner_args", {})
        return per_repo.get(repo_name, [])
    return []
[docs]
def get_python_version(config, group_name=None):
    """
    Return the Python version to use for a group's virtual environment.

    When configured, ``dbx clone`` and ``dbx env init`` use this version when
    creating the group's venv. A group-level ``python_version`` wins over the
    global ``repo.python_version``; a missing or empty group value falls back
    to the global one.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'django'), or None for the default

    Returns:
        str: Python version string (e.g., '3.13'), or None to use the system default
    """
    fallback = config.get("repo", {}).get("python_version")
    if group_name is not None:
        groups = get_repo_groups(config)
        if group_name in groups:
            return groups[group_name].get("python_version") or fallback
    return fallback
[docs]
def get_editor(config, group_name=None, repo_name=None):
    """
    Return the editor command used to open repositories.

    Resolution order (first hit wins):
      1. ``groups.<group>.editor.<repo>`` when the group's ``editor`` is a table
      2. ``groups.<group>.editor`` when it is a plain string
      3. ``repo.editor``
      4. the ``EDITOR`` environment variable
      5. ``"vim"``

    Args:
        config: Configuration dictionary
        group_name: Optional name of the group (e.g., 'django')
        repo_name: Optional name of the repository (e.g., 'django')

    Returns:
        str: Editor command to use
    """
    import os

    if group_name:
        groups = get_repo_groups(config)
        if group_name in groups:
            group_config = groups[group_name]
            if repo_name:
                per_repo = group_config.get("editor", {})
                if isinstance(per_repo, dict) and repo_name in per_repo:
                    return per_repo[repo_name]
            group_editor = group_config.get("editor")
            if isinstance(group_editor, str):
                return group_editor

    configured = config.get("repo", {}).get("editor")
    if configured:
        return configured

    # An unset or empty EDITOR falls through to the hard-coded default.
    return os.environ.get("EDITOR") or "vim"
[docs]
def get_preferred_branch(config, group_name, repo_name):
    """
    Return the branch to switch to right after cloning a repository.

    When configured, ``dbx clone`` runs ``git switch <branch>`` immediately
    after a successful clone, so the working tree starts on the right branch.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'django')
        repo_name: Name of the repository (e.g., 'django')

    Returns:
        str: Branch name, or None if no preferred branch is configured
    """
    groups = get_repo_groups(config)
    if group_name not in groups:
        return None
    group_config = groups[group_name]
    # "preferred_branch" is the current key. The legacy "default_branch" table
    # is consulted only when the new one is missing or empty (backwards compat).
    branch_table = group_config.get("preferred_branch", {}) or group_config.get(
        "default_branch", {}
    )
    return branch_table.get(repo_name)
[docs]
def switch_to_branch(repo_path: Path, branch_name: str, verbose: bool = False) -> bool:
    """
    Run ``git switch <branch_name>`` inside *repo_path*.

    Any failure (non-zero exit or an exception while running git) is reported
    as a warning on stderr instead of being raised, so the caller's workflow
    continues.

    Args:
        repo_path: Path to the repository
        branch_name: Branch to switch to
        verbose: Whether to show verbose output

    Returns:
        True if the switch succeeded, False otherwise
    """
    if verbose:
        typer.echo(f" [verbose] Switching to branch '{branch_name}'")
    try:
        completed = subprocess.run(
            ["git", "-C", str(repo_path), "switch", branch_name],
            check=False,
            capture_output=True,
            text=True,
        )
        if completed.returncode != 0:
            reason = completed.stderr.strip() or "unknown error"
            typer.echo(
                f" ⚠️ Could not switch to branch '{branch_name}': {reason}",
                err=True,
            )
            return False
        typer.echo(f" 🔀 Switched to branch '{branch_name}'")
        return True
    except Exception as exc:
        # Deliberately broad: switching is best-effort and must not abort the caller.
        typer.echo(
            f" ⚠️ Could not switch to branch '{branch_name}': {exc}",
            err=True,
        )
        return False
[docs]
def get_test_env_vars(config, group_name, repo_name, base_dir):
    """
    Get environment variables for test runs.

    A group's ``test_env`` table supports two kinds of entries:
    scalar values directly under ``test_env`` apply to every repo in the
    group, and a ``test_env.<repo_name>`` table adds or overrides values for
    that one repo.

    When the repo's own group has no repo-specific entry, global groups
    (listed under ``repo.global_groups``) are checked in order, so that repos
    cloned from a global group into another group still pick up their test
    environment configuration:

    - If the own group produced no variables at all, the first global group
      that yields any variables is used.
    - Otherwise, only a global group that has a ``test_env.<repo_name>``
      entry is used. Its variables are merged on top of the own group's, so
      unrelated group-level variables of a global group do not leak in.

    Placeholders are always expanded with *group_name* (the group the repo is
    actually cloned into). Global groups are never cloned into their own
    directory, so expanding ``{group}`` to the global group's name would
    produce paths that do not exist.

    Args:
        config: Configuration dictionary
        group_name: Name of the group (e.g., 'pymongo')
        repo_name: Name of the repository (e.g., 'mongo-python-driver')
        base_dir: Base directory path for resolving relative paths

    Returns:
        dict: Dictionary of environment variable names to values
    """
    groups = get_repo_groups(config)

    def _group_test_env(grp_name):
        """Return *grp_name*'s ``test_env`` table, or {} if missing/not a table."""
        if grp_name not in groups:
            return {}
        table = groups[grp_name].get("test_env", {})
        return table if isinstance(table, dict) else {}

    def _collect_env(test_env):
        """Expand group-level vars, then repo-specific overrides, from one table."""
        result = {}
        for key, value in test_env.items():
            if not isinstance(value, dict):
                result[key] = _expand_env_var_value(value, base_dir, group_name)
        repo_env = test_env.get(repo_name, {})
        if isinstance(repo_env, dict):
            for key, value in repo_env.items():
                result[key] = _expand_env_var_value(value, base_dir, group_name)
        return result

    own_test_env = _group_test_env(group_name)
    env_vars = _collect_env(own_test_env)
    if isinstance(own_test_env.get(repo_name), dict):
        # The repo is configured directly in its own group; no fallback needed.
        return env_vars

    for gname in get_global_groups(config):
        if gname == group_name:
            continue
        global_test_env = _group_test_env(gname)
        has_repo_entry = isinstance(global_test_env.get(repo_name), dict)
        if env_vars and not has_repo_entry:
            # Own group already supplied vars; only borrow a global group's
            # config when it is specifically about this repo.
            continue
        fallback = _collect_env(global_test_env)
        if fallback:
            # Repo-specific config is more specific than the own group's
            # group-level vars, so the fallback wins on key conflicts.
            return {**env_vars, **fallback}
    return env_vars
def _expand_env_var_value(value, base_dir, group_name):
    """
    Expand special placeholders in an environment variable value.

    Supports:
    - {base_dir}: replaced with ``str(base_dir)``
    - {group}: replaced with the group name
    - a leading ``~`` / ``~user``: expanded to that user's home directory

    Non-string values (e.g. TOML integers or booleans) are converted with
    ``str()`` and returned without placeholder expansion.

    The value is deliberately NOT routed through ``pathlib.Path``. Path
    normalizes its input (collapsing ``//``, dropping a trailing ``/``), which
    corrupts non-path values such as ``mongodb://localhost:27017/``.

    Args:
        value: The environment variable value (usually a string)
        base_dir: Base directory path
        group_name: Name of the group

    Returns:
        str: Expanded value
    """
    import os

    if not isinstance(value, str):
        return str(value)
    expanded = value.replace("{base_dir}", str(base_dir))
    expanded = expanded.replace("{group}", group_name)
    # os.path.expanduser only touches a leading "~" and leaves the rest intact.
    return os.path.expanduser(expanded)
def _build_repo_group_map(config):
    """
    Map each configured repo name to the config group that owns it (flat mode).

    Global groups are ignored. Groups listed in ``repo.group_priority`` are
    scanned first, then the remaining groups in config order. The first group
    a repo name appears in wins.
    """
    groups = get_repo_groups(config)
    global_names = set(get_global_groups(config))
    priority = get_group_priority(config)

    remaining = [
        name for name in groups if name not in priority and name not in global_names
    ]
    owner = {}
    for name in priority + remaining:
        # Priority entries may name unknown or global groups; skip those.
        if name in global_names or name not in groups:
            continue
        for url in groups[name].get("repos", []):
            owner.setdefault(extract_repo_name_from_url(url), name)
    return owner
[docs]
def find_all_repos(base_dir, config=None):
    """
    Find all cloned repositories in the base directory.

    In flat mode (``config["repo"]["flat"] = true``) repos live directly
    under *base_dir*; Django projects there are recognised by a ``manage.py``
    and get an empty group. Otherwise the classic two-level layout is used:
    ``base_dir/<group>/<repo>``. In that layout, directories under
    ``projects/`` that have a ``pyproject.toml`` but no ``.git`` are also
    reported (group ``"projects"``).

    Directory entries are visited in sorted order in both layouts, so the
    result is deterministic. ``Path.iterdir`` order is filesystem-dependent,
    and callers such as ``find_repo_by_name`` fall back to "first match".

    Args:
        base_dir: Path to the base directory
        config: Optional configuration dictionary; enables flat-mode detection
            and group assignment from config when flat is true.

    Returns:
        list: List of dictionaries with 'name', 'path', and 'group' keys.
        Empty when *base_dir* does not exist or is not a directory.
    """
    repos = []
    if not base_dir.is_dir():
        return repos

    if config and is_flat_mode(config):
        # Flat layout: repos are direct children of base_dir.
        # Assign config group names so -g filtering still works.
        repo_to_group = _build_repo_group_map(config)
        for repo_dir in sorted(base_dir.iterdir()):
            if not repo_dir.is_dir():
                continue
            if (repo_dir / ".git").exists():
                group = repo_to_group.get(repo_dir.name, "")
                repos.append({"name": repo_dir.name, "path": repo_dir, "group": group})
            elif (repo_dir / "manage.py").exists():
                # Django projects: identified by manage.py, live directly in base_dir
                repos.append({"name": repo_dir.name, "path": repo_dir, "group": ""})
        return repos

    # Grouped layout: base_dir/<group>/<repo>
    for group_dir in sorted(base_dir.iterdir()):
        if not group_dir.is_dir():
            continue
        for repo_dir in sorted(group_dir.iterdir()):
            if not repo_dir.is_dir():
                continue
            if (repo_dir / ".git").exists():
                repos.append(
                    {"name": repo_dir.name, "path": repo_dir, "group": group_dir.name}
                )
            elif group_dir.name == "projects" and (repo_dir / "pyproject.toml").exists():
                # Projects have no .git but must still be found by the install command.
                repos.append(
                    {"name": repo_dir.name, "path": repo_dir, "group": "projects"}
                )
    return repos
[docs]
def get_group_priority(config):
    """
    Return the configured group priority list.

    Args:
        config: Configuration dictionary (may be None or empty)

    Returns:
        list: Group names from ``repo.group_priority``, highest priority
        first; an empty list when *config* is falsy or the key is unset.
    """
    if config:
        return config.get("repo", {}).get("group_priority", [])
    return []
[docs]
def find_repo_by_name(repo_name, base_dir, config=None):
    """
    Locate a cloned repository by name.

    When the same name is cloned in several groups, the copy in the
    highest-priority group (per ``repo.group_priority``) is returned. If none
    of the copies is in a prioritized group, the first one found is returned.

    Args:
        repo_name: Name of the repository to find
        base_dir: Path to the base directory containing group subdirectories
        config: Optional configuration dictionary for group priority

    Returns:
        dict: Dictionary with 'name', 'path', and 'group' keys, or None if not found
    """
    candidates = [
        repo for repo in find_all_repos(base_dir, config) if repo["name"] == repo_name
    ]
    if len(candidates) <= 1:
        return candidates[0] if candidates else None

    for preferred_group in get_group_priority(config):
        hit = next((repo for repo in candidates if repo["group"] == preferred_group), None)
        if hit is not None:
            return hit
    return candidates[0]
[docs]
def find_repo_by_path(path, base_dir, config=None):
    """
    Identify the known repository located at, or containing, *path*.

    *path* is resolved to an absolute path. A repo whose root resolves to
    exactly that path is preferred. Otherwise the first repo whose root is an
    ancestor of *path* is returned, which covers callers sitting in a
    subdirectory of a repo.

    Args:
        path: A :class:`pathlib.Path` (or anything accepted by ``Path()``)
            pointing at or inside the repository root.
        base_dir: Path to the base directory containing group subdirectories.
        config: Optional configuration dictionary.

    Returns:
        dict: Dictionary with ``'name'``, ``'path'``, and ``'group'`` keys,
        or ``None`` if no matching repository is found.
    """
    target = Path(path).resolve()
    known = find_all_repos(base_dir, config)

    exact = next((repo for repo in known if repo["path"].resolve() == target), None)
    if exact is not None:
        return exact
    return next(
        (repo for repo in known if target.is_relative_to(repo["path"].resolve())),
        None,
    )
[docs]
def find_all_repos_by_name(repo_name, base_dir, config=None):
    """
    Return every cloned copy of *repo_name*, across all groups.

    Args:
        repo_name: Name of the repository to find
        base_dir: Path to the base directory containing group subdirectories
        config: Optional configuration dictionary

    Returns:
        list: Dictionaries with 'name', 'path', and 'group' keys (possibly empty)
    """
    discovered = find_all_repos(base_dir, config)
    return [entry for entry in discovered if entry["name"] == repo_name]
[docs]
def _list_repos_sort_key(repo):
    """Sort key ordering repo dicts by group, then by name."""
    return (repo["group"], repo["name"])


def _list_repos_status(is_cloned, is_available, colored):
    """Return the status marker: ✓ cloned, ? cloned but not in config, ○ not cloned."""
    if is_cloned and is_available:
        symbol, color = "✓", typer.colors.GREEN
    elif is_cloned:
        symbol, color = "?", typer.colors.MAGENTA
    else:
        symbol, color = "○", typer.colors.YELLOW
    return typer.style(symbol, fg=color) if colored else symbol


def _list_repos_available(config):
    """Return (available_repos, global_group_names, global_repo_names) from config.

    ``available_repos`` maps each non-global group that lists at least one
    repo to its repo names. Every global repo name is appended to each of
    those groups, because global repos are cloned into other groups.
    """
    available_repos = {}
    global_group_names = set()
    global_repo_names = []
    if not config:
        return available_repos, global_group_names, global_repo_names

    groups = config.get("repo", {}).get("groups", {})
    global_group_names = set(get_global_groups(config))
    for gname in global_group_names:
        if gname in groups:
            for url in groups[gname].get("repos", []):
                global_repo_names.append(extract_repo_name_from_url(url))

    for group_name, group_config in groups.items():
        if group_name in global_group_names:
            # Global groups are not cloned to their own directory — skip them
            continue
        for url in group_config.get("repos", []):
            available_repos.setdefault(group_name, []).append(
                extract_repo_name_from_url(url)
            )

    for names in available_repos.values():
        for repo_name in global_repo_names:
            if repo_name not in names:
                names.append(repo_name)
    return available_repos, global_group_names, global_repo_names


def _list_repos_flat_tree(
    repos, config, available_repos, global_group_names, global_repo_names
):
    """Render the flat-layout tree: repos grouped under their config group(s)."""
    groups = config.get("repo", {}).get("groups", {})
    cloned_names = {r["name"] for r in repos}

    # repo name -> every non-global config group that lists it
    repo_to_groups = defaultdict(list)
    for gname, gconfig in groups.items():
        if gname in global_group_names:
            continue
        for url in gconfig.get("repos", []):
            repo_to_groups[extract_repo_name_from_url(url)].append(gname)

    all_groups = set(available_repos.keys()) - global_group_names
    # Global repos are already listed (with status) under every group. When any
    # group is rendered they must not ALSO be reported as ungrouped
    # "cloned but not in config" entries.
    shown_globals = set(global_repo_names) if all_groups else set()
    ungrouped_cloned = sorted(
        r for r in cloned_names if not repo_to_groups.get(r) and r not in shown_globals
    )
    if not all_groups and not ungrouped_cloned:
        return None

    lines = []
    sorted_groups = sorted(all_groups)
    # The ungrouped block, if present, is the last top-level section.
    total_sections = len(sorted_groups) + (1 if ungrouped_cloned else 0)
    for i, group in enumerate(sorted_groups):
        is_last_group = i == total_sections - 1
        group_prefix = "└──" if is_last_group else "├──"
        group_label = typer.style(f"{group}/", fg=typer.colors.CYAN, bold=True)
        lines.append(f"{group_prefix} {group_label}")
        available_in_group = set(available_repos.get(group, []))
        cloned_in_group = {
            r for r in cloned_names if group in repo_to_groups.get(r, [])
        }
        all_repos_in_group = sorted(available_in_group | cloned_in_group)
        continuation = " " if is_last_group else "│ "
        for j, repo_name in enumerate(all_repos_in_group):
            repo_prefix = "└──" if j == len(all_repos_in_group) - 1 else "├──"
            status = _list_repos_status(
                repo_name in cloned_names, repo_name in available_in_group, colored=True
            )
            lines.append(f"{continuation}{repo_prefix} {status} {repo_name}")

    for k, repo_name in enumerate(ungrouped_cloned):
        prefix = "└──" if k == len(ungrouped_cloned) - 1 else "├──"
        status = typer.style("?", fg=typer.colors.MAGENTA)
        lines.append(f"{prefix} {status} {repo_name}")
    return "\n".join(lines) if lines else None


def _list_repos_group_tree(
    repos, available_repos, global_group_names, show_status, colored
):
    """Render the grouped-layout tree (shared by the 'tree' and 'default' styles)."""
    cloned = defaultdict(list)
    for repo in sorted(repos, key=_list_repos_sort_key):
        cloned[repo["group"]].append(repo["name"])

    # Global groups are excluded: their repos are cloned into other groups'
    # directories, not into a directory of their own.
    sorted_groups = sorted(
        (set(cloned.keys()) | set(available_repos.keys())) - global_group_names
    )
    lines = []
    for i, group in enumerate(sorted_groups):
        is_last_group = i == len(sorted_groups) - 1
        group_prefix = "└──" if is_last_group else "├──"
        if colored:
            group_label = typer.style(f"{group}/", fg=typer.colors.CYAN, bold=True)
        else:
            group_label = f"{group}/"
        lines.append(f"{group_prefix} {group_label}")

        available_in_group = set(available_repos.get(group, []))
        cloned_in_group = set(cloned.get(group, []))
        all_repos_in_group = sorted(available_in_group | cloned_in_group)
        continuation = " " if is_last_group else "│ "
        for j, repo_name in enumerate(all_repos_in_group):
            repo_prefix = "└──" if j == len(all_repos_in_group) - 1 else "├──"
            if show_status:
                status = _list_repos_status(
                    repo_name in cloned_in_group,
                    repo_name in available_in_group,
                    colored,
                )
                lines.append(f"{continuation}{repo_prefix} {status} {repo_name}")
            else:
                lines.append(f"{continuation}{repo_prefix} {repo_name}")
    return "\n".join(lines)


def list_repos(base_dir, format_style="default", config=None):
    """
    List all repositories in a formatted way.

    In flat mode (``repo.flat = true`` in *config*) a colored tree grouped by
    config group is always rendered, regardless of *format_style*.

    Args:
        base_dir: Path to the base directory containing group subdirectories
        format_style: Output format style ('default', 'tree', 'grouped', or
            'simple'). Any unrecognised value renders like 'default', which is
            the 'tree' layout with colored group labels and status markers.
        config: Optional config dict to compare available vs cloned repos.
            When given, tree output marks each repo ✓ (cloned), ? (cloned but
            not in config) or ○ (in config but not cloned).

    Returns:
        str: Formatted list of repositories, or None when nothing is cloned
        and the config lists no repos.
    """
    repos = find_all_repos(base_dir, config)
    available_repos, global_group_names, global_repo_names = _list_repos_available(
        config
    )

    # If no repos cloned and no config, return None
    if not repos and not available_repos:
        return None

    if config and is_flat_mode(config):
        return _list_repos_flat_tree(
            repos, config, available_repos, global_group_names, global_repo_names
        )

    if format_style == "grouped":
        grouped = defaultdict(list)
        for repo in sorted(repos, key=_list_repos_sort_key):
            grouped[repo["group"]].append(repo["name"])
        lines = []
        for group in sorted(grouped.keys()):
            lines.append(f" [{group}]")
            lines.extend(f" • {repo_name}" for repo_name in grouped[group])
        return "\n".join(lines)

    if format_style == "simple":
        return "\n".join(
            f" • {repo['name']} ({repo['group']})"
            for repo in sorted(repos, key=_list_repos_sort_key)
        )

    # "tree" is the plain variant; "default" (and anything else) is colored.
    return _list_repos_group_tree(
        repos,
        available_repos,
        global_group_names,
        show_status=bool(config),
        colored=format_style != "tree",
    )