Python SDK: Fine-tuning

This guide covers the fine-tuning capabilities of the PaiTIENT Secure Model Service Python SDK, allowing you to customize AI models for healthcare-specific use cases while maintaining HIPAA and SOC 2 compliance.

Fine-tuning Overview

Fine-tuning allows you to adapt pre-trained models to your specific healthcare domain or use case, improving performance, accuracy, and adherence to guidelines.

The PaiTIENT Secure Model Service provides secure fine-tuning capabilities with:

  1. Data Security: All fine-tuning data is encrypted and processed in secure environments
  2. Compliance: The entire fine-tuning process adheres to HIPAA and SOC 2 requirements
  3. Efficient Training: Support for parameter-efficient fine-tuning methods such as LoRA
  4. Performance Tracking: Detailed metrics for fine-tuning progress and evaluation

Prerequisites

Before fine-tuning a model, ensure you have:

  1. Installed the PaiTIENT Python SDK
  2. Set up authentication credentials
  3. Prepared a fine-tuning dataset
  4. Selected a base model for fine-tuning
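
Once these are in place, it can help to confirm that credentials are picked up before starting any jobs. This is a minimal sketch that assumes the SDK reads an API key from the environment; the PAITIENT_API_KEY variable name is illustrative, so substitute whatever credential mechanism your deployment uses.

python
import os

from paitient_secure_model import Client

# Illustrative pre-flight check: fail fast if credentials are not
# configured. The environment variable name is an assumption, not a
# documented SDK contract.
if "PAITIENT_API_KEY" not in os.environ:
    raise RuntimeError("Set PAITIENT_API_KEY before starting fine-tuning jobs")

client = Client()  # Picks up credentials from the environment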

Dataset Preparation

Dataset Format

The PaiTIENT Secure Model Service supports various dataset formats for different fine-tuning approaches:

Instruction Dataset Format (JSONL)

For instruction fine-tuning, use a JSONL file with this structure:

jsonl
{"instruction": "Explain the mechanism of action for metformin", "response": "Metformin primarily works by reducing glucose production in the liver (hepatic gluconeogenesis), improving insulin sensitivity in peripheral tissues, and enhancing glucose uptake in skeletal muscle. It also decreases intestinal absorption of glucose."}
{"instruction": "What are the symptoms of diabetic ketoacidosis?", "response": "The main symptoms of diabetic ketoacidosis include excessive thirst, frequent urination, nausea and vomiting, abdominal pain, weakness or fatigue, shortness of breath, fruity-scented breath, and confusion. It may also present with high blood glucose levels, high ketone levels in the urine, and electrolyte imbalances."}

Conversation Dataset Format (JSONL)

For conversational fine-tuning, use a JSONL file with this structure:

jsonl
{"conversations": [{"role": "user", "content": "What are the side effects of atorvastatin?"}, {"role": "assistant", "content": "Atorvastatin may cause side effects including muscle pain and weakness, which in rare cases can lead to serious muscle damage (rhabdomyolysis). Other common side effects include joint pain, diarrhea, nausea, and elevated liver enzymes. Rare but serious side effects include liver damage, diabetes, and memory problems. Always consult your healthcare provider if you experience concerning symptoms while taking atorvastatin."}]}
{"conversations": [{"role": "user", "content": "I'm experiencing leg pain after starting atorvastatin. Should I be concerned?"}, {"role": "assistant", "content": "Leg pain after starting atorvastatin could be a sign of muscle-related side effects, which should be taken seriously. Please contact your healthcare provider immediately to discuss these symptoms. Your doctor might want to check your creatine kinase levels to rule out muscle damage. Don't stop taking the medication without medical advice, but this symptom warrants prompt medical attention."}]}

Dataset Validation

Validate your dataset before fine-tuning:

python
from paitient_secure_model import Client

client = Client()

# Validate dataset
validation_result = client.validate_dataset(
    file_path="clinical_guidelines_dataset.jsonl",
    format="instruction"  # or "conversation"
)

if validation_result.is_valid:
    print("Dataset validation successful")
    print(f"Number of examples: {validation_result.num_examples}")
    print(f"Total tokens: {validation_result.total_tokens}")
else:
    print("Dataset validation failed:")
    for error in validation_result.errors:
        print(f"  - {error}")

Dataset Upload

Upload your dataset to the secure environment:

python
# Upload dataset
dataset = client.upload_dataset(
    file_path="clinical_guidelines_dataset.jsonl",
    name="Clinical Guidelines Dataset",
    description="Fine-tuning dataset for clinical guideline adherence",
    format="instruction",
    tags={"domain": "cardiology", "source": "guidelines"}
)

print(f"Dataset ID: {dataset.id}")
print(f"Number of examples: {dataset.num_examples}")
print(f"Total tokens: {dataset.total_tokens}")

Basic Fine-tuning

The simplest way to fine-tune a model:

python
from paitient_secure_model import Client

# Initialize client
client = Client()

# Start a fine-tuning job
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1"
)

print(f"Fine-tuning job ID: {fine_tuning_job.id}")
print(f"Status: {fine_tuning_job.status}")

# Wait for fine-tuning to complete
fine_tuning_job.wait_until_complete()
print(f"Fine-tuning is now {fine_tuning_job.status}")
print(f"Fine-tuned model: {fine_tuning_job.fine_tuned_model}")

Fine-tuning Options

Training Parameters

Configure training parameters for your fine-tuning job:

python
# Fine-tuning with specific training parameters
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "learning_rate": 1e-4,
        "batch_size": 8,
        "epochs": 3,
        "warmup_steps": 100,
        "lora_rank": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05
    }
)

Validation Dataset

Use a validation dataset to monitor training progress:

python
# Upload validation dataset
validation_dataset = client.upload_dataset(
    file_path="clinical_guidelines_validation.jsonl",
    name="Clinical Guidelines Validation Dataset",
    format="instruction"
)

# Fine-tuning with validation
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    validation_dataset_id=validation_dataset.id,
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    validation_frequency=0.1  # Run validation every 10% of training steps
)

Compute Configuration

Configure compute resources for your fine-tuning job:

python
# Fine-tuning with compute configuration
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    compute_config={
        "instance_type": "g5.2xlarge",
        "instance_count": 1,
        "max_runtime_hours": 24
    }
)

Security Settings

Apply security settings to your fine-tuning job:

python
from paitient_secure_model import Client
from paitient_secure_model.security import SecuritySettings

# Initialize client
client = Client()

# Fine-tuning with security settings
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    security_settings=SecuritySettings(
        network_isolation=True,        # Enable network isolation
        encryption_level="maximum",    # Maximum encryption level
        audit_logging=True,            # Enable comprehensive audit logging
        compliance_mode="hipaa"        # Enable HIPAA compliance mode
    )
)

Advanced Fine-tuning Methods

LoRA Fine-tuning

Low-Rank Adaptation (LoRA) is the default fine-tuning method:

python
# LoRA fine-tuning with advanced parameters
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "lora_rank": 16,              # Rank of LoRA matrices
        "lora_alpha": 32,             # LoRA scaling factor
        "lora_dropout": 0.05,         # Dropout probability for LoRA layers
        "target_modules": ["q_proj", "v_proj"],  # Target modules for LoRA
        "learning_rate": 1e-4,
        "batch_size": 8,
        "epochs": 3
    }
)

QLoRA Fine-tuning

Quantized LoRA (QLoRA) for memory-efficient fine-tuning:

python
# QLoRA fine-tuning
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="qlora",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "quantization_bits": 4,       # 4-bit quantization
        "lora_rank": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "batch_size": 8,
        "epochs": 3
    }
)

Full Fine-tuning

Full parameter fine-tuning for maximum performance:

python
# Full fine-tuning (requires more compute resources)
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="full",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "learning_rate": 5e-5,
        "batch_size": 4,
        "epochs": 2,
        "warmup_steps": 100,
        "weight_decay": 0.01
    },
    compute_config={
        "instance_type": "g5.12xlarge",  # More powerful GPU instance
        "instance_count": 2,
        "max_runtime_hours": 48
    }
)

Fine-tuning Management

Check Fine-tuning Status

Monitor the status of your fine-tuning job:

python
# Get fine-tuning job status
job = client.get_fine_tuning_job("ft_12345abcde")
print(f"Status: {job.status}")
print(f"Created: {job.created_at}")
print(f"Updated: {job.updated_at}")
print(f"Base model: {job.base_model}")

# Get detailed job information
details = job.get_details()
print(f"Current epoch: {details.current_epoch}/{details.total_epochs}")
print(f"Training loss: {details.training_loss}")
print(f"Validation loss: {details.validation_loss}")
print(f"Training samples processed: {details.processed_samples}")

List Fine-tuning Jobs

Retrieve a list of all your fine-tuning jobs:

python
# List all fine-tuning jobs
jobs = client.list_fine_tuning_jobs()
for job in jobs:
    print(f"{job.id}: {job.model_name} - {job.status}")

# Filter fine-tuning jobs
completed_jobs = client.list_fine_tuning_jobs(
    filters={
        "status": "completed",
        "base_model": "ZimaBlueAI/HuatuoGPT-o1-8B"
    }
)
for job in completed_jobs:
    print(f"{job.id}: {job.model_name} - Completed")

Cancel Fine-tuning Job

Cancel a fine-tuning job that's no longer needed:

python
# Cancel fine-tuning job
client.cancel_fine_tuning_job("ft_12345abcde")

Fine-tuning Metrics

Retrieve metrics for your fine-tuning job:

python
# Get fine-tuning metrics
metrics = client.get_fine_tuning_metrics("ft_12345abcde")

# Plot training and validation loss
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.plot(metrics.steps, metrics.training_loss, label="Training Loss")
plt.plot(metrics.steps, metrics.validation_loss, label="Validation Loss")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Fine-tuning Progress")
plt.legend()
plt.grid(True)
plt.savefig("fine_tuning_progress.png")

Fine-tuned Model Deployment

Deploy your fine-tuned model:

python
# Get fine-tuned model ID
fine_tuning_job = client.get_fine_tuning_job("ft_12345abcde")
fine_tuned_model = fine_tuning_job.fine_tuned_model

# Deploy the fine-tuned model
deployment = client.create_deployment(
    model_name=fine_tuned_model,
    deployment_name="clinical-assistant-cardiology"
)

# Wait for deployment to complete
deployment.wait_until_ready()
print(f"Deployment is now {deployment.status}")
print(f"Endpoint: {deployment.endpoint}")

# Test the fine-tuned model
response = client.generate_text(
    deployment_id=deployment.id,
    prompt="What is the recommended first-line treatment for hypertension in diabetic patients?"
)

print(response.text)

Fine-tuning Evaluation

Evaluate your fine-tuned model against the base model:

python
from paitient_secure_model import Client
from paitient_secure_model.evaluation import ComparativeEvaluation

client = Client()

# Get fine-tuned model deployment
fine_tuned_deployment = client.get_deployment("dep_fine_tuned")

# Get base model deployment
base_model_deployment = client.get_deployment("dep_base_model")

# Create comparative evaluation
evaluation = ComparativeEvaluation(
    name="Clinical Guidelines Adherence Evaluation",
    deployments=[
        {"id": base_model_deployment.id, "name": "Base Model"},
        {"id": fine_tuned_deployment.id, "name": "Fine-tuned Model"}
    ],
    metrics=[
        "factuality",
        "guideline_adherence",
        "completeness",
        "clinical_accuracy"
    ]
)

# Run evaluation on test dataset
results = evaluation.run(
    dataset="clinical_test_cases.jsonl",
    num_samples=50
)

# Print comparative results
print("Comparative Evaluation Results:")
for metric in results.metrics:
    base_score = results.scores["Base Model"][metric]
    fine_tuned_score = results.scores["Fine-tuned Model"][metric]
    improvement = fine_tuned_score - base_score
    print(f"{metric}: {base_score:.2f}{fine_tuned_score:.2f} ({improvement:+.2f})")

Advanced Use Cases

Multi-task Fine-tuning

Fine-tune a model for multiple medical specialties:

python
# Upload datasets for different specialties
cardiology_dataset = client.upload_dataset(
    file_path="cardiology_guidelines.jsonl",
    name="Cardiology Guidelines Dataset",
    format="instruction"
)

endocrinology_dataset = client.upload_dataset(
    file_path="endocrinology_guidelines.jsonl",
    name="Endocrinology Guidelines Dataset",
    format="instruction"
)

oncology_dataset = client.upload_dataset(
    file_path="oncology_guidelines.jsonl",
    name="Oncology Guidelines Dataset",
    format="instruction"
)

# Combine datasets for multi-task fine-tuning
combined_dataset = client.merge_datasets(
    dataset_ids=[
        cardiology_dataset.id,
        endocrinology_dataset.id,
        oncology_dataset.id
    ],
    name="Multi-specialty Guidelines Dataset"
)

# Fine-tune on the combined dataset
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id=combined_dataset.id,
    fine_tuning_method="lora",
    model_name="clinical-assistant-multi-specialty-v1"
)

Continual Fine-tuning

Update your fine-tuned model with new data:

python
# Get existing fine-tuned model
existing_fine_tuned_model = "ft:ZimaBlueAI/HuatuoGPT-o1-8B:clinical-assistant-v1"

# Upload new dataset with updated guidelines
new_guidelines_dataset = client.upload_dataset(
    file_path="updated_guidelines_2023.jsonl",
    name="Updated Guidelines 2023",
    format="instruction"
)

# Fine-tune from the previously fine-tuned model
fine_tuning_job = client.create_fine_tuning_job(
    base_model=existing_fine_tuned_model,  # Start from previously fine-tuned model
    dataset_id=new_guidelines_dataset.id,
    fine_tuning_method="lora",
    model_name="clinical-assistant-v2",
    hyperparameters={
        "learning_rate": 5e-5,  # Lower learning rate for continued fine-tuning
        "epochs": 1             # Fewer epochs needed
    }
)

Domain Adaptation with Synthetic Data

Generate and use synthetic data for domain adaptation:

python
from paitient_secure_model import Client
from paitient_secure_model.data_generation import generate_synthetic_data

client = Client()

# Generate synthetic medical conversations
synthetic_data = generate_synthetic_data(
    generation_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    domain="neurology",
    template_file="neurology_templates.json",
    num_examples=500,
    output_file="synthetic_neurology_dataset.jsonl"
)

# Upload synthetic dataset
synthetic_dataset = client.upload_dataset(
    file_path="synthetic_neurology_dataset.jsonl",
    name="Synthetic Neurology Dataset",
    format="conversation"
)

# Fine-tune with synthetic data
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id=synthetic_dataset.id,
    fine_tuning_method="lora",
    model_name="neurology-assistant-v1"
)

Error Handling

Implement proper error handling for fine-tuning:

python
from paitient_secure_model import Client
from paitient_secure_model.exceptions import (
    FineTuningError,
    ResourceNotFoundError,
    QuotaExceededError,
    InvalidParameterError
)

client = Client()

try:
    fine_tuning_job = client.create_fine_tuning_job(
        base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
        dataset_id="ds_12345abcde",
        fine_tuning_method="lora",
        model_name="clinical-assistant-v1"
    )
except ResourceNotFoundError as e:
    print(f"Resource not found: {e}")
except InvalidParameterError as e:
    print(f"Invalid parameter: {e}")
except QuotaExceededError as e:
    print(f"Quota exceeded: {e}")
except FineTuningError as e:
    print(f"Fine-tuning failed: {e}")
    print(f"Job ID: {e.job_id}")
    print(f"Status: {e.status}")
    print(f"Reason: {e.reason}")
    
    # Get detailed error information
    if e.job_id:
        logs = client.get_fine_tuning_logs(
            job_id=e.job_id,
            limit=10
        )
        print("Error logs:")
        for log in logs:
            print(f"  {log.message}")

Best Practices

Dataset Quality

Follow these best practices for dataset preparation:

  1. Ensure diversity in your examples
  2. Focus on domain-specific content relevant to your use case
  3. Include edge cases common in clinical scenarios
  4. Balance response length appropriate for your application
  5. Use high-quality, verified medical information
  6. De-identify any patient data according to HIPAA guidelines
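
As a quick, illustrative check on points 1 and 4, a short standard-library script can surface response-length imbalance before you upload:

python
import json
import statistics

# Summarize response lengths in an instruction-format JSONL file so
# heavily skewed datasets can be caught before fine-tuning.
lengths = []
with open("clinical_guidelines_dataset.jsonl", encoding="utf-8") as f:
    for line in f:
        example = json.loads(line)
        lengths.append(len(example["response"].split()))

print(f"Examples: {len(lengths)}")
print(f"Response length (words): min={min(lengths)}, "
      f"median={statistics.median(lengths)}, max={max(lengths)}")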

Hyperparameter Optimization

Optimize hyperparameters for your specific use case:

  1. Start with defaults for your first fine-tuning job
  2. Experiment with learning rates between 1e-5 and 5e-4
  3. Adjust batch size based on available memory
  4. Try different LoRA ranks (8, 16, 32) to balance efficiency and effectiveness
  5. Use early stopping to prevent overfitting
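
One illustrative way to act on points 2 and 4 is to launch a small sweep, one job per candidate LoRA rank, and compare validation losses afterwards. Scaling lora_alpha with the rank (alpha = 2 x rank) is a common convention, not an SDK requirement, and the same pattern extends to learning rates. The validation dataset ID below is a placeholder.

python
from paitient_secure_model import Client

client = Client()

# Launch one LoRA job per candidate rank so validation losses can be
# compared once all jobs finish.
jobs = {}
for rank in (8, 16, 32):
    jobs[rank] = client.create_fine_tuning_job(
        base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
        dataset_id="ds_12345abcde",
        validation_dataset_id="ds_67890fghij",  # Placeholder; needed for validation_loss
        fine_tuning_method="lora",
        model_name=f"clinical-assistant-lora-r{rank}",
        hyperparameters={"lora_rank": rank, "lora_alpha": 2 * rank},
    )

for rank, job in jobs.items():
    job.wait_until_complete()
    print(f"rank={rank}: validation loss {job.get_details().validation_loss}")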

Infrastructure Considerations

Optimize your infrastructure for fine-tuning:

  1. Choose appropriate instance types based on model size
  2. Monitor GPU memory usage to avoid out-of-memory errors
  3. Implement checkpointing for long-running jobs
  4. Use distributed training for very large models
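
As a rough starting point for point 1, a helper like the sketch below maps model size and fine-tuning method onto the instance types used elsewhere in this guide. The thresholds are illustrative only; verify them against observed GPU memory usage.

python
# Illustrative heuristic, not an SDK feature: pick a GPU instance
# type from parameter count and fine-tuning method.
def suggest_instance_type(num_params_billions: float, method: str) -> str:
    if method in ("lora", "qlora") and num_params_billions <= 8:
        return "g5.2xlarge"   # A single mid-size GPU often suffices for PEFT
    return "g5.12xlarge"      # Full fine-tuning or larger models need more memory

print(suggest_instance_type(8, "lora"))   # g5.2xlarge
print(suggest_instance_type(8, "full"))   # g5.12xlarge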

Evaluation Strategy

Develop a robust evaluation strategy:

  1. Create a dedicated test set not used in training
  2. Define domain-specific metrics relevant to clinical performance
  3. Compare against baseline models
  4. Involve domain experts in qualitative evaluation
  5. Test on realistic scenarios from your specific healthcare domain
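
For point 1, you can make a reproducible held-out split with the standard library before uploading anything; this sketch reserves 10% of an instruction-format file as a test set:

python
import json
import random

random.seed(42)  # Reproducible split

with open("clinical_guidelines_dataset.jsonl", encoding="utf-8") as f:
    examples = [json.loads(line) for line in f]

random.shuffle(examples)
split = int(0.9 * len(examples))

# Write the training portion and the held-out test portion separately;
# only the training file should ever be uploaded for fine-tuning.
for path, subset in [("train.jsonl", examples[:split]),
                     ("test.jsonl", examples[split:])]:
    with open(path, "w", encoding="utf-8") as f:
        for example in subset:
            f.write(json.dumps(example) + "\n")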
