Python SDK: Fine-tuning
This guide covers the fine-tuning capabilities of the PaiTIENT Secure Model Service Python SDK, allowing you to customize AI models for healthcare-specific use cases while maintaining HIPAA/SOC2 compliance.
Fine-tuning Overview
Fine-tuning allows you to adapt pre-trained models to your specific healthcare domain or use case, improving performance, accuracy, and adherence to guidelines.
The PaiTIENT Secure Model Service provides secure fine-tuning capabilities with:
- Data Security: All fine-tuning data is encrypted and processed in secure environments
- Compliance: The entire fine-tuning process adheres to HIPAA and SOC2 requirements
- Efficient Training: Parameter-efficient fine-tuning methods like LoRA
- Performance Tracking: Detailed metrics for fine-tuning progress and evaluation
Prerequisites
Before fine-tuning a model, ensure you have:
- Installed the PaiTIENT Python SDK
- Set up authentication credentials (a minimal setup sketch follows this list)
- Prepared a fine-tuning dataset
- Selected a base model for fine-tuning
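If you are starting from scratch, the minimal sketch below covers installation and client setup. The PyPI package name and the PAITIENT_API_KEY environment variable are assumptions for illustration; consult your organization's onboarding materials for the actual credential mechanism.

# Minimal setup sketch (assumed package name and env var, not confirmed):
#   pip install paitient-secure-model
#   export PAITIENT_API_KEY="your-key-here"

from paitient_secure_model import Client

# Client() with no arguments follows the examples in this guide; we assume it
# reads credentials from the environment.
client = Client()

# Smoke test: list existing fine-tuning jobs (covered later in this guide)
for job in client.list_fine_tuning_jobs():
    print(f"{job.id}: {job.status}")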
Dataset Preparation
Dataset Format
The PaiTIENT Secure Model Service supports several dataset formats for different fine-tuning approaches. Each is a JSON Lines (JSONL) file with one JSON object per line; a sketch for producing such a file follows, and the per-format structures are shown after it.
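As a minimal sketch using only the Python standard library, here is one way to produce a JSONL dataset from a list of dicts (field names follow the instruction format shown in the next subsection):

import json

examples = [
    {
        "instruction": "Explain the mechanism of action for metformin",
        "response": "Metformin primarily works by reducing hepatic glucose production...",
    },
    # ... add more examples ...
]

# JSONL: one JSON object per line, UTF-8 encoded
with open("clinical_guidelines_dataset.jsonl", "w", encoding="utf-8") as f:
    for example in examples:
        f.write(json.dumps(example, ensure_ascii=False) + "\n")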
Instruction Dataset Format (JSONL)
For instruction fine-tuning, use a JSONL file with this structure:
{"instruction": "Explain the mechanism of action for metformin", "response": "Metformin primarily works by reducing glucose production in the liver (hepatic gluconeogenesis), improving insulin sensitivity in peripheral tissues, and enhancing glucose uptake in skeletal muscle. It also decreases intestinal absorption of glucose."}
{"instruction": "What are the symptoms of diabetic ketoacidosis?", "response": "The main symptoms of diabetic ketoacidosis include excessive thirst, frequent urination, nausea and vomiting, abdominal pain, weakness or fatigue, shortness of breath, fruity-scented breath, and confusion. It may also present with high blood glucose levels, high ketone levels in the urine, and electrolyte imbalances."}Conversation Dataset Format (JSONL)
For conversational fine-tuning, use a JSONL file with this structure:
{"conversations": [{"role": "user", "content": "What are the side effects of atorvastatin?"}, {"role": "assistant", "content": "Atorvastatin may cause side effects including muscle pain and weakness, which in rare cases can lead to serious muscle damage (rhabdomyolysis). Other common side effects include joint pain, diarrhea, nausea, and elevated liver enzymes. Rare but serious side effects include liver damage, diabetes, and memory problems. Always consult your healthcare provider if you experience concerning symptoms while taking atorvastatin."}]}
{"conversations": [{"role": "user", "content": "I'm experiencing leg pain after starting atorvastatin. Should I be concerned?"}, {"role": "assistant", "content": "Leg pain after starting atorvastatin could be a sign of muscle-related side effects, which should be taken seriously. Please contact your healthcare provider immediately to discuss these symptoms. Your doctor might want to check your creatine kinase levels to rule out muscle damage. Don't stop taking the medication without medical advice, but this symptom warrants prompt medical attention."}]}Dataset Validation
Validate your dataset before fine-tuning:
from paitient_secure_model import Client
client = Client()
# Validate dataset
validation_result = client.validate_dataset(
    file_path="clinical_guidelines_dataset.jsonl",
    format="instruction"  # or "conversation"
)

if validation_result.is_valid:
    print("Dataset validation successful")
    print(f"Number of examples: {validation_result.num_examples}")
    print(f"Total tokens: {validation_result.total_tokens}")
else:
    print("Dataset validation failed:")
    for error in validation_result.errors:
        print(f"  - {error}")

Dataset Upload
Upload your dataset to the secure environment:
# Upload dataset
dataset = client.upload_dataset(
    file_path="clinical_guidelines_dataset.jsonl",
    name="Clinical Guidelines Dataset",
    description="Fine-tuning dataset for clinical guideline adherence",
    format="instruction",
    tags={"domain": "cardiology", "source": "guidelines"}
)

print(f"Dataset ID: {dataset.id}")
print(f"Number of examples: {dataset.num_examples}")
print(f"Total tokens: {dataset.total_tokens}")

Basic Fine-tuning
The simplest way to fine-tune a model:
from paitient_secure_model import Client
# Initialize client
client = Client()
# Start a fine-tuning job
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1"
)

print(f"Fine-tuning job ID: {fine_tuning_job.id}")
print(f"Status: {fine_tuning_job.status}")

# Wait for fine-tuning to complete
fine_tuning_job.wait_until_complete()

print(f"Fine-tuning is now {fine_tuning_job.status}")
print(f"Fine-tuned model: {fine_tuning_job.fine_tuned_model}")

Fine-tuning Options
Training Parameters
Configure training parameters for your fine-tuning job:
# Fine-tuning with specific training parameters
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "learning_rate": 1e-4,
        "batch_size": 8,
        "epochs": 3,
        "warmup_steps": 100,
        "lora_rank": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05
    }
)

Validation Dataset
Use a validation dataset to monitor training progress:
# Upload validation dataset
validation_dataset = client.upload_dataset(
    file_path="clinical_guidelines_validation.jsonl",
    name="Clinical Guidelines Validation Dataset",
    format="instruction"
)

# Fine-tuning with validation
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    validation_dataset_id=validation_dataset.id,
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    validation_frequency=0.1  # Run validation after every 10% of training steps
)

Compute Configuration
Configure compute resources for your fine-tuning job:
# Fine-tuning with compute configuration
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    compute_config={
        "instance_type": "g5.2xlarge",
        "instance_count": 1,
        "max_runtime_hours": 24
    }
)

Security Settings
Apply security settings to your fine-tuning job:
from paitient_secure_model import Client
from paitient_secure_model.security import SecuritySettings
# Initialize client
client = Client()
# Fine-tuning with security settings
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    security_settings=SecuritySettings(
        network_isolation=True,      # Enable network isolation
        encryption_level="maximum",  # Maximum encryption level
        audit_logging=True,          # Enable comprehensive audit logging
        compliance_mode="hipaa"      # Enable HIPAA compliance mode
    )
)

Advanced Fine-tuning Methods
LoRA Fine-tuning
Low-Rank Adaptation (LoRA) is the default fine-tuning method:
# LoRA fine-tuning with advanced parameters
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="lora",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "lora_rank": 16,                         # Rank of the LoRA matrices
        "lora_alpha": 32,                        # LoRA scaling factor
        "lora_dropout": 0.05,                    # Dropout probability for LoRA layers
        "target_modules": ["q_proj", "v_proj"],  # Target modules for LoRA
        "learning_rate": 1e-4,
        "batch_size": 8,
        "epochs": 3
    }
)

QLoRA Fine-tuning
Quantized LoRA (QLoRA) for memory-efficient fine-tuning:
# QLoRA fine-tuning
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="qlora",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "quantization_bits": 4,  # 4-bit quantization
        "lora_rank": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "batch_size": 8,
        "epochs": 3
    }
)

Full Fine-tuning
Full parameter fine-tuning for maximum performance:
# Full fine-tuning (requires more compute resources)
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="full",
    model_name="clinical-assistant-cardiology-v1",
    hyperparameters={
        "learning_rate": 5e-5,
        "batch_size": 4,
        "epochs": 2,
        "warmup_steps": 100,
        "weight_decay": 0.01
    },
    compute_config={
        "instance_type": "g5.12xlarge",  # More powerful GPU instance
        "instance_count": 2,
        "max_runtime_hours": 48
    }
)

Fine-tuning Management
Check Fine-tuning Status
Monitor the status of your fine-tuning job:
# Get fine-tuning job status
job = client.get_fine_tuning_job("ft_12345abcde")

print(f"Status: {job.status}")
print(f"Created: {job.created_at}")
print(f"Updated: {job.updated_at}")
print(f"Base model: {job.base_model}")

# Get detailed job information
details = job.get_details()

print(f"Current epoch: {details.current_epoch}/{details.total_epochs}")
print(f"Training loss: {details.training_loss}")
print(f"Validation loss: {details.validation_loss}")
print(f"Training samples processed: {details.processed_samples}")

List Fine-tuning Jobs
Retrieve a list of all your fine-tuning jobs:
# List all fine-tuning jobs
jobs = client.list_fine_tuning_jobs()

for job in jobs:
    print(f"{job.id}: {job.model_name} - {job.status}")

# Filter fine-tuning jobs
completed_jobs = client.list_fine_tuning_jobs(
    filters={
        "status": "completed",
        "base_model": "ZimaBlueAI/HuatuoGPT-o1-8B"
    }
)

for job in completed_jobs:
    print(f"{job.id}: {job.model_name} - Completed")

Cancel Fine-tuning Job
Cancel a fine-tuning job that's no longer needed:
# Cancel fine-tuning job
client.cancel_fine_tuning_job("ft_12345abcde")

Fine-tuning Metrics
Retrieve metrics for your fine-tuning job:
# Get fine-tuning metrics
metrics = client.get_fine_tuning_metrics("ft_12345abcde")
# Plot training and validation loss
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.plot(metrics.steps, metrics.training_loss, label="Training Loss")
plt.plot(metrics.steps, metrics.validation_loss, label="Validation Loss")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Fine-tuning Progress")
plt.legend()
plt.grid(True)
plt.savefig("fine_tuning_progress.png")

Fine-tuned Model Deployment
Deploy your fine-tuned model:
# Get fine-tuned model ID
fine_tuning_job = client.get_fine_tuning_job("ft_12345abcde")
fine_tuned_model = fine_tuning_job.fine_tuned_model
# Deploy the fine-tuned model
deployment = client.create_deployment(
    model_name=fine_tuned_model,
    deployment_name="clinical-assistant-cardiology"
)

# Wait for deployment to complete
deployment.wait_until_ready()

print(f"Deployment is now {deployment.status}")
print(f"Endpoint: {deployment.endpoint}")

# Test the fine-tuned model
response = client.generate_text(
    deployment_id=deployment.id,
    prompt="What is the recommended first-line treatment for hypertension in diabetic patients?"
)

print(response.text)

Fine-tuning Evaluation
Evaluate your fine-tuned model against the base model:
from paitient_secure_model import Client
from paitient_secure_model.evaluation import ComparativeEvaluation
client = Client()
# Get fine-tuned model deployment
fine_tuned_deployment = client.get_deployment("dep_fine_tuned")
# Get base model deployment
base_model_deployment = client.get_deployment("dep_base_model")
# Create comparative evaluation
evaluation = ComparativeEvaluation(
    name="Clinical Guidelines Adherence Evaluation",
    deployments=[
        {"id": base_model_deployment.id, "name": "Base Model"},
        {"id": fine_tuned_deployment.id, "name": "Fine-tuned Model"}
    ],
    metrics=[
        "factuality",
        "guideline_adherence",
        "completeness",
        "clinical_accuracy"
    ]
)

# Run evaluation on test dataset
results = evaluation.run(
    dataset="clinical_test_cases.jsonl",
    num_samples=50
)

# Print comparative results
print("Comparative Evaluation Results:")
for metric in results.metrics:
    base_score = results.scores["Base Model"][metric]
    fine_tuned_score = results.scores["Fine-tuned Model"][metric]
    improvement = fine_tuned_score - base_score
    print(f"{metric}: {base_score:.2f} → {fine_tuned_score:.2f} ({improvement:+.2f})")

Advanced Use Cases
Multi-task Fine-tuning
Fine-tune a model for multiple medical specialties:
# Upload datasets for different specialties
cardiology_dataset = client.upload_dataset(
    file_path="cardiology_guidelines.jsonl",
    name="Cardiology Guidelines Dataset",
    format="instruction"
)

endocrinology_dataset = client.upload_dataset(
    file_path="endocrinology_guidelines.jsonl",
    name="Endocrinology Guidelines Dataset",
    format="instruction"
)

oncology_dataset = client.upload_dataset(
    file_path="oncology_guidelines.jsonl",
    name="Oncology Guidelines Dataset",
    format="instruction"
)

# Combine datasets for multi-task fine-tuning
combined_dataset = client.merge_datasets(
    dataset_ids=[
        cardiology_dataset.id,
        endocrinology_dataset.id,
        oncology_dataset.id
    ],
    name="Multi-specialty Guidelines Dataset"
)

# Fine-tune on the combined dataset
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id=combined_dataset.id,
    fine_tuning_method="lora",
    model_name="clinical-assistant-multi-specialty-v1"
)

Continual Fine-tuning
Update your fine-tuned model with new data:
# Get existing fine-tuned model
existing_fine_tuned_model = "ft:ZimaBlueAI/HuatuoGPT-o1-8B:clinical-assistant-v1"
# Upload new dataset with updated guidelines
new_guidelines_dataset = client.upload_dataset(
    file_path="updated_guidelines_2023.jsonl",
    name="Updated Guidelines 2023",
    format="instruction"
)

# Fine-tune from the previously fine-tuned model
fine_tuning_job = client.create_fine_tuning_job(
    base_model=existing_fine_tuned_model,  # Start from the previously fine-tuned model
    dataset_id=new_guidelines_dataset.id,
    fine_tuning_method="lora",
    model_name="clinical-assistant-v2",
    hyperparameters={
        "learning_rate": 5e-5,  # Lower learning rate for continued fine-tuning
        "epochs": 1             # Fewer epochs needed
    }
)

Domain Adaptation with Synthetic Data
Generate and use synthetic data for domain adaptation:
from paitient_secure_model import Client
from paitient_secure_model.data_generation import generate_synthetic_data
client = Client()
# Generate synthetic medical conversations
synthetic_data = generate_synthetic_data(
    generation_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    domain="neurology",
    template_file="neurology_templates.json",
    num_examples=500,
    output_file="synthetic_neurology_dataset.jsonl"
)

# Upload synthetic dataset
synthetic_dataset = client.upload_dataset(
    file_path="synthetic_neurology_dataset.jsonl",
    name="Synthetic Neurology Dataset",
    format="conversation"
)

# Fine-tune with synthetic data
fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id=synthetic_dataset.id,
    fine_tuning_method="lora",
    model_name="neurology-assistant-v1"
)

Error Handling
Implement proper error handling for fine-tuning:
from paitient_secure_model import Client
from paitient_secure_model.exceptions import (
    FineTuningError,
    ResourceNotFoundError,
    QuotaExceededError,
    InvalidParameterError
)

client = Client()

try:
    fine_tuning_job = client.create_fine_tuning_job(
        base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
        dataset_id="ds_12345abcde",
        fine_tuning_method="lora",
        model_name="clinical-assistant-v1"
    )
except InvalidParameterError as e:
    print(f"Invalid parameter: {e}")
except ResourceNotFoundError as e:
    print(f"Resource not found: {e}")
except QuotaExceededError as e:
    print(f"Quota exceeded: {e}")
except FineTuningError as e:
    print(f"Fine-tuning failed: {e}")
    print(f"Job ID: {e.job_id}")
    print(f"Status: {e.status}")
    print(f"Reason: {e.reason}")

    # Get detailed error information
    if e.job_id:
        logs = client.get_fine_tuning_logs(
            job_id=e.job_id,
            limit=10
        )
        print("Error logs:")
        for log in logs:
            print(f"  {log.message}")

Best Practices
Dataset Quality
Follow these best practices for dataset preparation:
- Ensure diversity in your examples
- Focus on domain-specific content relevant to your use case
- Include edge cases common in clinical scenarios
- Balance response lengths appropriately for your application (see the quality-check sketch after this list)
- Use high-quality, verified medical information
- De-identify any patient data according to HIPAA guidelines
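As a starting point for the diversity and length checks above, here is a minimal, SDK-independent sketch that scans an instruction-format JSONL file; the length threshold is illustrative:

import json
from collections import Counter

MIN_RESPONSE_CHARS = 40  # illustrative threshold, tune for your application
instruction_counts = Counter()

with open("clinical_guidelines_dataset.jsonl", encoding="utf-8") as f:
    for line_number, line in enumerate(f, start=1):
        example = json.loads(line)
        instruction_counts[example["instruction"].strip().lower()] += 1
        if len(example["response"]) < MIN_RESPONSE_CHARS:
            print(f"Line {line_number}: response may be too short")

duplicates = [text for text, count in instruction_counts.items() if count > 1]
print(f"Duplicated instructions: {len(duplicates)}")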
Hyperparameter Optimization
Optimize hyperparameters for your specific use case:
- Start with defaults for your first fine-tuning job
- Experiment with learning rates between 1e-5 and 5e-4
- Adjust batch size based on available memory
- Try different LoRA ranks (8, 16, 32) to balance efficiency and effectiveness (see the sweep sketch after this list)
- Use early stopping to prevent overfitting
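A sketch of the rank comparison suggested above, reusing the create_fine_tuning_job call and client from earlier sections; whether your quota permits running the jobs in parallel is deployment-specific:

# Launch one LoRA job per candidate rank; alpha = 2x rank is a common heuristic
sweep_jobs = []
for rank in (8, 16, 32):
    job = client.create_fine_tuning_job(
        base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
        dataset_id="ds_12345abcde",
        fine_tuning_method="lora",
        model_name=f"clinical-assistant-cardiology-r{rank}",
        hyperparameters={
            "lora_rank": rank,
            "lora_alpha": rank * 2,
            "learning_rate": 1e-4,
            "batch_size": 8,
            "epochs": 3
        }
    )
    sweep_jobs.append(job)

# Compare final validation loss once all jobs complete
for job in sweep_jobs:
    job.wait_until_complete()
    details = job.get_details()
    print(f"{job.model_name}: validation loss {details.validation_loss}")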
Infrastructure Considerations
Optimize your infrastructure for fine-tuning:
- Choose appropriate instance types based on model size and fine-tuning method (a sizing sketch follows this list)
- Monitor GPU memory usage to avoid out-of-memory errors
- Implement checkpointing for long-running jobs
- Use distributed training for very large models
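A sketch of selecting a compute_config by fine-tuning method, mirroring the instance types used earlier in this guide; the mapping is an illustrative heuristic, not an official sizing table:

def pick_compute_config(fine_tuning_method: str) -> dict:
    """Illustrative sizing heuristic based on the examples in this guide."""
    if fine_tuning_method == "full":
        # Full fine-tuning updates every weight and needs the most GPU memory
        return {"instance_type": "g5.12xlarge", "instance_count": 2, "max_runtime_hours": 48}
    # LoRA/QLoRA train far fewer parameters and fit on a smaller instance
    return {"instance_type": "g5.2xlarge", "instance_count": 1, "max_runtime_hours": 24}

fine_tuning_job = client.create_fine_tuning_job(
    base_model="ZimaBlueAI/HuatuoGPT-o1-8B",
    dataset_id="ds_12345abcde",
    fine_tuning_method="qlora",
    model_name="clinical-assistant-cardiology-v1",
    compute_config=pick_compute_config("qlora")
)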
Evaluation Strategy
Develop a robust evaluation strategy:
- Create a dedicated test set not used in training (see the split sketch after this list)
- Define domain-specific metrics relevant to clinical performance
- Compare against baseline models
- Involve domain experts in qualitative evaluation
- Test on realistic scenarios from your specific healthcare domain
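For the held-out test set recommended above, a minimal sketch that splits a JSONL dataset before upload; the 90/10 ratio and fixed seed are illustrative:

import json
import random

random.seed(42)  # fixed seed keeps the split reproducible

with open("clinical_guidelines_dataset.jsonl", encoding="utf-8") as f:
    examples = [json.loads(line) for line in f]

random.shuffle(examples)
split = int(len(examples) * 0.9)  # 90% train, 10% held-out test

for path, subset in [("train.jsonl", examples[:split]),
                     ("clinical_test_cases.jsonl", examples[split:])]:
    with open(path, "w", encoding="utf-8") as f:
        for example in subset:
            f.write(json.dumps(example, ensure_ascii=False) + "\n")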
Next Steps
- Learn about Text Generation
- Explore Deployment Options
- Understand Model Evaluation
- Review our Security Best Practices